Tensorflow에서 배치 훈련


11

현재 큰 CSV 파일 (60 백만 행 이상 70GB 이상)에서 모델을 훈련하려고합니다. 그렇게하기 위해 tf.contrib.learn.read_batch_examples를 사용하고 있습니다. 이 함수가 실제로 데이터를 읽는 방법을 이해하는 데 어려움을 겪고 있습니다. 예를 들어 배치 크기가 50.000 인 경우 파일의 처음 50.000 줄을 읽습니까? 전체 파일을 반복하려면 (1 epoch) estimator.fit 메소드에 num_rows / batch_size = 1.200 단계 수를 사용해야합니까?

다음은 현재 사용중인 입력 기능입니다.

def input_fn(file_names, batch_size):
    # Read csv files and create examples dict
    examples_dict = read_csv_examples(file_names, batch_size)

    # Continuous features
    feature_cols = {k: tf.string_to_number(examples_dict[k],
                                           out_type=tf.float32) for k in CONTINUOUS_COLUMNS}

    # Categorical features
    feature_cols.update({
                            k: tf.SparseTensor(
                                indices=[[i, 0] for i in range(examples_dict[k].get_shape()[0])],
                                values=examples_dict[k],
                                shape=[int(examples_dict[k].get_shape()[0]), 1])
                            for k in CATEGORICAL_COLUMNS})

    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32)

    return feature_cols, label


def read_csv_examples(file_names, batch_size):
    def parse_fn(record):
        record_defaults = [tf.constant([''], dtype=tf.string)] * len(COLUMNS)

        return tf.decode_csv(record, record_defaults)

    examples_op = tf.contrib.learn.read_batch_examples(
        file_names,
        batch_size=batch_size,
        queue_capacity=batch_size*2.5,
        reader=tf.TextLineReader,
        parse_fn=parse_fn,
        #read_batch_size= batch_size,
        #randomize_input=True,
        num_threads=8
    )

    # Important: convert examples to dict for ease of use in `input_fn`
    # Map each header to its respective column (COLUMNS order
    # matters!
    examples_dict_op = {}
    for i, header in enumerate(COLUMNS):
        examples_dict_op[header] = examples_op[:, i]

    return examples_dict_op

다음은 모델 학습에 사용하는 im 코드입니다.

def train_and_eval():
"""Train and evaluate the model."""

m = build_estimator(model_dir)
m.fit(input_fn=lambda: input_fn(train_file_name, batch_size), steps=steps)

동일한 input_fn을 사용하여 fit 함수를 다시 호출하면 어떻게됩니까? 파일의 시작 부분에서 다시 시작합니까, 아니면 마지막에 중지 된 행을 기억합니까?


나는 medium.com/@ilblackdragon/…을 발견 했다. tensorflow input_fn
fistynuts의

yau가 이미 확인 했습니까? stackoverflow.com/questions/37091899/…
Frankstr

답변:


1

아직 답변이 없으므로 적어도 유용한 답변을 제공하려고합니다. 상수 정의를 포함하면 제공된 코드를 이해하는 데 도움이됩니다.

일반적으로 배치는 n 번 레코드 또는 항목을 사용합니다. 항목을 정의하는 방법은 문제에 따라 다릅니다. 텐서 플로우에서 배치는 텐서의 첫 번째 차원으로 인코딩됩니다. csv 파일을 사용하는 경우 줄 단위 ( reader=tf.TextLineReader) 일 수 있습니다 . 열별로 배울 수는 있지만 이것이 코드에서 발생한다고 생각하지 않습니다. 전체 데이터 세트 (= 하나의 에포크 )를 사용하여 학습 하려면을 사용하여 수행 할 수 있습니다 numBatches=numItems/batchSize.

당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.