Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何正确组合TensorFlow和#x27;s数据集API和Keras?_Tensorflow_Keras - Fatal编程技术网

如何正确组合TensorFlow和#x27;s数据集API和Keras?

如何正确组合TensorFlow和#x27;s数据集API和Keras?,tensorflow,keras,Tensorflow,Keras,Keras'fit_generator()model方法需要一个生成形状元组(输入、目标)的生成器,其中两个元素都是NumPy数组。这似乎意味着,如果我简单地将a包装在生成器中,并确保将张量转换为NumPy数组,我就可以开始了。但是,此代码给了我一个错误: import numpy as np import os import keras.backend as K from keras.layers import Dense, Input from keras.models import Mod

Keras'
fit_generator()
model方法需要一个生成形状元组(输入、目标)的生成器,其中两个元素都是NumPy数组。这似乎意味着,如果我简单地将a包装在生成器中,并确保将张量转换为NumPy数组,我就可以开始了。但是,此代码给了我一个错误:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

datagen = create_data_generator()

input_vals = Input(shape=(1,))
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mean_squared_error')
model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                    verbose=2, max_queue_size=2)
下面是我得到的错误:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./datagen_test.py", line 25, in create_data_generator
    yield sess.run(next_val)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./datagen_test.py", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
    generator_output = next(output_generator)
StopIteration
使用TensorFlow后端。
纪元1/5
线程1中的异常:
回溯(最近一次呼叫最后一次):
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第270行,在__
fetch,allow_tensor=True,allow_operation=True)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/framework/ops.py”,第2708行,在as_graph_元素中
返回self.\u as\u graph\u element\u locked(对象、允许张量、允许操作)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/framework/ops.py”,第2787行,在“作为图形”元素中
raise VALUERROR(“张量%s不是此图的元素。”%obj)
ValueError:Tensor Tensor(“IteratorGetNext:0”,shape=(?,1),dtype=int64)不是此图的元素。
在处理上述异常期间,发生了另一个异常:
回溯(最近一次呼叫最后一次):
文件“/home/jsaporta/anaconda3/lib/python3.6/threading.py”,第916行,在内部引导中
self.run()
文件“/home/jsaporta/anaconda3/lib/python3.6/threading.py”,第864行,运行中
自我目标(*自我参数,**自我参数)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/keras/utils/data_utils.py”,第568行,在数据生成器任务中
发电机输出=下一个(自发电机)
文件“/datagen_test.py”,第25行,在创建数据生成器中
产量评估运行(下一个值)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第895行,正在运行
运行_元数据_ptr)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1109行,正在运行
self.\u图形、回迁、馈送\u dict\u张量、馈送\u句柄=馈送\u句柄)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第413行,在__
self.\u fetch\u mapper=\u FetchMapper.for\u fetch(fetches)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第233行,for_fetch
return\u ListFetchMapper(fetch)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第340行,在__
self._mappers=[_FetchMapper.for_fetch(fetch)for fetch in fetches]
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第340行,在
self._mappers=[_FetchMapper.for_fetch(fetch)for fetch in fetches]
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第241行,for_fetch
return\u ElementFetchMapper(fetches,contraction\u fn)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/tensorflow/python/client/session.py”,第277行,在__
张量。(%s)'(提取,str(e)))
ValueError:无法将Fetch参数解释为张量。(Tensor Tensor(“IteratorGetNext:0”,shape=(?,1),dtype=int64)不是此图的元素。)
回溯(最近一次呼叫最后一次):
文件“/datagen_test.py”,第34行,在
详细信息=2,最大队列大小=2)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/keras/legacy/interfaces.py”,第87行,在包装器中
返回函数(*args,**kwargs)
文件“/home/jsaporta/anaconda3/lib/python3.6/site packages/keras/engine/training.py”,第2011行,在fit_generator中
发电机输出=下一个(输出发电机)
停止迭代
奇怪的是,在初始化
datagen
后直接添加一行包含
next(datagen)
的代码会使代码运行正常,没有错误


为什么我的原始代码不起作用?为什么当我把那一行添加到代码中时它就开始工作了?有没有一种更有效的方法可以将TensorFlow的数据集API与Keras结合使用,而不需要将张量转换为NumPy数组,然后再转换回来?

确实有一种更有效的方法可以使用
数据集,而无需将张量转换为NumPy数组。然而,官方文件中没有(尚未?)这一点。从发行说明来看,这是Keras 2.0.7中引入的一个特性。您可能必须安装keras>=2.0.7才能使用它

x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()

y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)
几个不同之处:

  • 张量
    参数提供给
    输入
    层。Keras将从该张量中读取值,并将其用作输入以拟合模型
  • Model.compile()
    提供
    target\u张量
    参数
  • 请记住将x和y都转换为
    float32
    。在正常使用情况下,Keras将为您进行此转换。但现在你得自己动手了
  • 批大小是在构建
    数据集
    期间指定的。使用
    每个历元的步数
    历元
    控制何时停止模型拟合
  • 简而言之,如果要从张量读取数据,请使用
    输入(张量=…)
    模型。编译(目标张量=…)
    模型。拟合(x=None,y=None,…)

    更新日期:2018年6月9日
    • 从Tensorflow 1.9开始,可以将
      tf.data.Dataset
      对象直接传递到
      keras.Model.fit()
      ,其作用类似于
      fit\u生成器
    • 可以在此上找到完整的示例
    #加载列表训练数据
    (x_列,y_列),u=tf.keras.datasets.mnist.load_data()
    训练集=tfdata\U生成器(x\U列,y\U列,is\U训练=真)
    model=#这里是您的keras模型
    模型拟合(
    训练集。生成一个迭代器(),
    每历元步数=len(x\u列)//128,
    纪元=5,
    详细=1)
    
    • def _get_input_data_for_dataset(file_name): df_input=pd.read_csv(file_name.decode(),usecols=['Wind_MWh']) X_data = df_input.as_matrix() return X_data.astype('float32', copy=False) X_dataset = tf.data.Dataset.from_tensor_slices(file_names) X_dataset = X_dataset.flat_map(lambda file_name: tf.data.Dataset.from_tensor_slices( tf.reshape(tf.py_func(_get_input_data_for_dataset,[file_name], tf.float32),[-1,1]))) X_dataset = X_dataset.batch(5) X_iter = X_dataset.make_one_shot_iterator() X_batch = X_iter.get_next() input_X1 = Input(tensor= X_batch ,name='input_X1') y1 = Dense(units=64, activation='relu',kernel_initializer=tf.keras.initializers.Constant(1),name='layer_FC1')(input_X1)
    estimator = tf.keras.estimator.model_to_estimator(keras_model=model,
                                                      model_dir=model_dir)
    input_name = model.layers[0].input.op.name
    
    def input_fn(dataset):
        dataset = dataset.map(lambda X,y: {input_name: X}, y)
        return dataset.make_one_shot_iterator().get_next()
    
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: input_fn(train_set), max_steps=100)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: input_fn(test_set))
    
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    
    def create_generator_tf_dataset(self, images, onehots, batch_size):
        # Get shapes
        img_size = images.shape
        img_size = (None, img_size[1], img_size[2], img_size[3])
        onehot_size = onehots.shape
        onehot_size = (None, onehot_size[1])
    
        # Placeholders
        images_tensor = tf.placeholder(tf.float32, shape=img_size)
        onehots_tensor = tf.placeholder(tf.float32, shape=onehot_size)
    
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((images_tensor, onehots_tensor))
        # Map function (e.g. augmentation)
        if map_fn is not None:
            dataset = dataset.map(lambda x, y: (map_fn(x), y), num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # Combined shuffle and infinite repeat
        dataset = dataset.apply(
            tf.data.experimental.shuffle_and_repeat(len(images), None))  
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(1)
    
        # Make the iterator
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        next_val = iterator.get_next()
    
        with K.get_session().as_default() as sess:
            sess.run(init_op, feed_dict={images_tensor: images, onehots_tensor: onehots})
            while True:
                inputs, labels = sess.run(next_val)
                yield inputs, labels