TensorFlow示例代码中迭代器的用法

TensorFlow示例代码中迭代器的用法,tensorflow,Tensorflow,我正在学习TensorFlow(TF),这仅仅是一天,所以如果我的疑问太基本而无法提出,我提前道歉。 我在TF官方网站上学习 作者定义了一个名为input\u fun的函数来读取数据。功能如下: def input_fn(data_file, num_epochs, shuffle, batch_size): """Generate an input function for the Estimator.""" assert tf.gfile.Exists(data_file), (

我正在学习TensorFlow(TF),这仅仅是一天,所以如果我的疑问太基本而无法提出,我提前道歉。 我在TF官方网站上学习

作者定义了一个名为
input\u fun
的函数来读取数据。功能如下:

def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels
我不能理解最后一行的第二行。一次性迭代器只调用一次
get_next()
,但不应该对数据进行多次迭代(即行数次)以提取行,如?

因此,这里,get_next()基本上是一个出列操作。数据在队列中,当您使用(使用/调用)get_next()调用的元素时,它将从队列中删除,并且下一个图像/标签会移动到它的位置,下次调用它时,它将退出队列

因此,目前,这个函数只返回元素的tensorflow op,您可以在训练循环中使用它