Warning: file_get_contents(/data/phpspider/zhask/data//catemap/1/typo3/2.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 如何恢复使用Dataset API训练的Tensorflow模型?_Python_Tensorflow_Restore_Tensorflow Datasets - Fatal编程技术网

Python 如何恢复使用Dataset API训练的Tensorflow模型?

Python 如何恢复使用Dataset API训练的Tensorflow模型?,python,tensorflow,restore,tensorflow-datasets,Python,Tensorflow,Restore,Tensorflow Datasets,我正在使用带有可反馈迭代器的Dataset API来训练我的模型,就像在导入数据教程中一样。 问题是,在恢复模型时。它还将使用训练后的形状恢复控制柄占位符。这意味着它希望得到一个示例和一个标签 def loadTFRecord(filenames): dataset = tf.data.TFRecordDataset([filenames]) dataset = dataset.map(extract_img_func) datas

我正在使用带有可反馈迭代器的Dataset API来训练我的模型,就像在导入数据教程中一样。 问题是,在恢复模型时。它还将使用训练后的形状恢复控制柄占位符。这意味着它希望得到一个示例和一个标签

    def loadTFRecord(filenames):          
      dataset = tf.data.TFRecordDataset([filenames])
      dataset = dataset.map(extract_img_func)
      dataset = dataset.batch(batchsize)
      handle = tf.placeholder(tf.string, shape=[])
      iterator = tf.data.Iterator.from_string_handle(handle, dataset.output_types, dataset.output_shapes)
      training_iterator = dataset.make_one_shot_iterator()
      next_element = iterator.get_next() 
      training_handle = self.sess.run(training_iterator.string_handle())
      return next_element #next_element[0] is the example img, next_element[1] is the label

    def model_fn(images, labels=None, train=False):
      input_layer = images
      ...
      predictions = last_layer
      if train:
        return predictions

      # Calculate loss
      loss = tf.losses.mean_squared_error(labels, predictions)
      learning_rate = tf.train.exponential_decay(learning_rate=learningRate, staircase=True)
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
      train_op = optimizer.minimize(
          loss=loss,
          global_step=global_step)

      return train_op, predictions, loss
通过这一点,我创建了我的培训模型:

examples, labels = loadTFRecord("path/to/tfrecord")
model_fn(examples, labels=labels)
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=0.5)
... #training here
saver.save(sess, "path/to/")
现在的问题是,当我想恢复模型进行推理时。 我想做的是恢复模型,并传入另一个feedable迭代器,该迭代器从磁盘加载一些.png文件。我这样做类似于加载TFRecord文件

def load_images(filenames):
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  dataset = dataset.map(lambda x: tf.image.resize_images(self.normalize(tf.image.decode_png(tf.read_file(x), channels = 3)), [IM_WIDTH, IM_HEIGHT]))
  dataset = dataset.batch(1)
  iterator = tf.data.Iterator.from_string_handle(handle, dataset.output_types, dataset.output_shapes)
  iterator = dataset.make_one_shot_iterator()
  next_img = iterator.get_next()
  training_handle = sess.run(iterator.string_handle())
  return next_img
现在的问题是,将其传递给恢复的模型时,如下所示:

  saver = tf.train.import_meta_graph(modelbasepath + ".meta")
  saver.restore(sess, modelbasepath)
  ... # restore operations here
  # finally run predictions, error occurs here!
  predictions = sess.run([predictions], feed_dict={handle: training_handle})
我得到了这个错误:

Number of components does not match: expected 2 types but got 1.
 [[Node: IteratorFromStringHandle_2 = IteratorFromStringHandle[output_shapes=[[?,80,80,3], [?,80,80,?]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_0)]]
这告诉我,它也希望得到一个标签,而我只是提供一个图像来预测


我怎样才能克服这个问题?有没有办法改变占位符的形状,或者如何实现这一点,以便能够恢复使用datatset API和feedable dicts训练过的模型?

我遇到了同样的问题。然而,我无法想出一个干净的解决方案。最后,我为加载图像时返回的标签创建了一个虚拟张量。也许有更好的方法可以做到这一点,但是这个解决方案现在应该允许您运行模型

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(load_images)

def load_images(x):
    image = tf.image.decode_png(tf.read_file(x), channels = 3))
    image = self.normalize(image)
    image = tf.image.resize_images(image, [IM_WIDTH, IM_HEIGHT])

    # Assuming label is one channel, can slice image to get correct dims
    label = tf.zeros_like(image[:, :, 0:1]) 

    return image, label

谢谢你的建议。这是一个很好的解决办法。然而,我想知道这是如何做到正确的方式!