Python: tf.keras fit_generator() gets stuck on validation_data


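I am using a DataGenerator derived from tf.keras.utils.Sequence to load data in batches. The data generator returns numpy arrays of images and masks. When I call fit_generator(), the model appears to fit on the training data but then gets stuck on the validation data. If I set validation_data=None and run it again, no error occurs. I am using tensorflow 1.14 and tf.keras 2.2.4.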

Here is the code snippet:

from tensorflow.keras.optimizers import Adam

# create_model, loss, dice_coefficient and DataGenerator are defined elsewhere
model = create_model()
optimizer = Adam(lr=0.001)
model.compile(loss=loss, optimizer=optimizer, metrics=[dice_coefficient])

train_gen = DataGenerator(X_train, batch_size=1, predict=False, shuffle=True)
val_gen = DataGenerator(X_val, batch_size=1, predict=False, shuffle=True)

# checkpoint, reduce_lr and stop are Keras callbacks created earlier
model.fit_generator(train_gen, validation_data=val_gen,
                    callbacks=[checkpoint, reduce_lr, stop], epochs=1, verbose=1)

Here is the error I get:

Use tf.where in 2.0, which has the same broadcast rule as np.where
19/20 [===========================>..] - ETA: 7s - loss: 3.7508 - dice_coefficient: 0.1282 
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-18-9915a3b43b57> in <module>
      1 model.fit_generator(generator=train_gen, validation_data=val_gen, epochs=1, 
      2                     callbacks = [checkpoint, reduce_lr, stop],
----> 3                     shuffle=True, verbose=1)
      4 

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1431         shuffle=shuffle,
   1432         initial_epoch=initial_epoch,
-> 1433         steps_name='steps_per_epoch')
   1434 
   1435   def evaluate_generator(self,

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    320           verbose=0,
    321           mode=ModeKeys.TEST,
--> 322           steps_name='validation_steps')
    323 
    324       if not isinstance(val_results, list):

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    262 
    263       is_deferred = not model._is_compiled
--> 264       batch_outs = batch_function(*batch_data)
    265       if not isinstance(batch_outs, list):
    266         batch_outs = [batch_outs]

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1245       self._update_sample_weight_modes(sample_weights=sample_weights)
   1246       self._make_test_function()
-> 1247       outputs = self.test_function(inputs)  # pylint: disable=not-callable
   1248 
   1249     if reset_metrics:

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
   3290 
   3291     fetched = self._callable_fn(*array_vals,
-> 3292                                 run_metadata=self.run_metadata)
   3293     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3294     output_structure = nest.pack_sequence_as(

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: You must feed a value for placeholder tensor 'reshape_target' with dtype float and shape [?,?,?]
     [[{{node reshape_target}}]]
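
The traceback shows that the target placeholder was not fed during test_on_batch, but not why. One way to rule out the generator itself is to check that every validation batch is an (inputs, targets) pair. The following is only a diagnostic sketch: `find_bad_batches` is a hypothetical helper, it assumes a single input array and a single target array per batch, and a plain list of batches stands in for `val_gen`.

```python
import numpy as np


def find_bad_batches(sequence):
    """Return (index, reason) pairs for batches that are not a usable (inputs, targets) pair."""
    problems = []
    for i in range(len(sequence)):
        batch = sequence[i]
        # Keras expects (inputs, targets) or (inputs, targets, sample_weights)
        if not isinstance(batch, tuple) or len(batch) not in (2, 3):
            problems.append((i, "batch is not an (inputs, targets[, sample_weights]) tuple"))
            continue
        x, y = batch[0], batch[1]
        if y is None:
            problems.append((i, "targets are None"))
        elif len(np.asarray(x)) != len(np.asarray(y)):
            problems.append((i, "inputs and targets differ in batch size"))
    return problems


# Hypothetical data: image batches of shape (1, 4, 4, 1) and mask batches of shape (1, 4, 4)
good = (np.zeros((1, 4, 4, 1), dtype=np.float32), np.zeros((1, 4, 4), dtype=np.float32))
no_targets = (np.zeros((1, 4, 4, 1), dtype=np.float32), None)
print(find_bad_batches([good, no_targets, good]))
```

Because a Sequence supports `len()` and indexing, the same call should work on the real generator as `find_bad_batches(val_gen)`. An empty result means the batches look well-formed and the cause lies elsewhere.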

It is currently not possible to use a generator for the validation data. Please see

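For validation, I don't really see the point of using a generator at all, since validation mainly measures how well the network generalizes after a given epoch. If you pass the validation data in as a whole, let it be split into batches, and limit the number of steps, the network will handle it efficiently.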
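
Passing the validation data "as a whole" could look roughly like the sketch below, assuming the validation set fits in memory. `materialize_batches` and `ToySequence` are hypothetical names; `ToySequence` only stands in for the question's DataGenerator and produces dummy images and masks.

```python
import numpy as np


def materialize_batches(sequence):
    """Concatenate every (inputs, targets) batch of a Sequence-like object into two arrays."""
    xs, ys = [], []
    for i in range(len(sequence)):
        x, y = sequence[i]
        xs.append(np.asarray(x))
        ys.append(np.asarray(y))
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


class ToySequence:
    """Hypothetical stand-in for DataGenerator: yields (images, masks) batches."""

    def __init__(self, n_batches, batch_size=1):
        self.n_batches = n_batches
        self.batch_size = batch_size

    def __len__(self):
        return self.n_batches

    def __getitem__(self, idx):
        images = np.full((self.batch_size, 4, 4, 1), idx, dtype=np.float32)
        masks = np.zeros((self.batch_size, 4, 4), dtype=np.float32)
        return images, masks


x_val, y_val = materialize_batches(ToySequence(n_batches=3))
print(x_val.shape, y_val.shape)  # (3, 4, 4, 1) (3, 4, 4)
```

With the real generator, the arrays from `materialize_batches(val_gen)` could then be passed as `validation_data=(x_val, y_val)` in the fit_generator() call, which accepts a tuple of arrays as well as a generator.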

This does not answer the question: is it possible to use a generator for the validation data?