TensorFlow TFRecords dataset distributed across multiple GPUs for a Keras model


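I am trying to use a TFRecord dataset as the input to a Keras model. The network appears to start training, but then I get an error message. Here is the code I use to build and fit the model: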

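import tensorflow as tf
from tensorflow import keras

# ResidualUnit, train_input_fn, validation_input_fn and the upper-case
# constants are defined elsewhere in the script and are not shown here.
mirrored_strategy = tf.distribute.MirroredStrategy()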
with mirrored_strategy.scope():
    model = keras.models.Sequential()
    model.add(keras.layers.Conv3D(64, (7,7,7), strides=(2,2,2), padding="same",
                                  use_bias=False, input_shape=[91,109,91,1]))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Activation("relu"))
    model.add(keras.layers.MaxPool3D(pool_size=(3,3,3), strides=(2,2,2), padding="same"))
    prev_filters = 64
    for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
        strides = 1 if filters == prev_filters else 2
        model.add(ResidualUnit(filters, strides=strides))
        prev_filters = filters

    model.add(keras.layers.GlobalAveragePooling3D())
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1, activation="sigmoid"))

    model = keras.utils.multi_gpu_model(model, gpus=2)

    es = keras.callbacks.EarlyStopping(monitor='val_accuracy', mode='auto', patience=20)
    mc = keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', save_best_only=True)
    tb = keras.callbacks.TensorBoard(log_dir='./logs', write_images=True, write_graph=True)

    model.compile(loss="binary_crossentropy",
                  optimizer=keras.optimizers.Adam(learning_rate=0.001),
                  metrics=['accuracy'])


    training_set = train_input_fn('train.tfrecords', batch_size=BATCH_SIZE, num_epochs=N_EPOCHS)
    validation_set = validation_input_fn('test.tfrecords', batch_size=BATCH_SIZE)

    history = model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH_TRAINING,
                        epochs=N_EPOCHS, validation_data=validation_set,
                        validation_steps=STEPS_PER_EPOCH_VALIDATION, callbacks=[tb, es, mc])
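
For context on how the dataset above is consumed: under `tf.distribute.MirroredStrategy`, Keras treats the batch size of the dataset passed to `fit` as the global batch size and splits each batch across the replicas. The arithmetic can be sketched as below. The helper names are hypothetical and the numbers are placeholders, not values from this post.

```python
import math


def per_replica_batch_size(global_batch_size, num_replicas):
    """Slice of each global batch that one replica receives.

    Assumes the global batch divides evenly across the replicas.
    """
    return global_batch_size // num_replicas


def steps_per_epoch(num_examples, global_batch_size):
    """Number of global batches needed to see every example once."""
    return math.ceil(num_examples / global_batch_size)


# Placeholder numbers for illustration only (not taken from the question).
print(per_replica_batch_size(64, 2))
print(steps_per_epoch(960, 64))
```

If `STEPS_PER_EPOCH_TRAINING` were computed from a per-GPU batch size instead of the global one, the step count would be off by a factor equal to the number of replicas, so it may be worth checking which convention `train_input_fn` uses.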
Here is the output (there is no standard output about the distribution):

Train for 15 steps, validate for 3 steps
WARNING: Logging before flag parsing goes to stderr.
W1028 12:25:08.884782 140602772264768 summary_ops_v2.py:1110] Model failed to serialize as JSON. Ignoring... Layers with arguments in `__init__` must override `get_config`.
Epoch 1/1000
2019-10-28 12:25:29.438754: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
2019-10-28 12:25:30.317121: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2019-10-28 12:25:34.050879: I tensorflow/core/profiler/lib/profiler_session.cc:184] Profiler session started.
2019-10-28 12:25:34.053985: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcupti.so.10.0
 1/15 [=>............................] - ETA: 6:04 - loss: 0.6008 - accuracy: 0.75002019-10-28 12:25:35.924105: I tensorflow/core/platform/default/device_tracer.cc:588] Collecting 6788 kernel records, 994 memcpy records.
W1028 12:25:39.252643 140602772264768 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (2.169159). Check your callbacks.
 2/15 [===>..........................] - ETA: 3:17 - loss: 0.7089 - accuracy: 0.7969W1028 12:25:39.688849 140602772264768 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.885054). Check your callbacks.
 3/15 [=====>........................] - ETA: 2:03 - loss: 0.7249 - accuracy: 0.8333W1028 12:25:40.090722 140602772264768 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.442553). Check your callbacks.
 4/15 [=======>......................] - ETA: 1:25 - loss: 0.7012 - accuracy: 0.77342019-10-28 12:25:40.205127: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: train/image.  Can't parse serialized Example.
 7/15 [=============>................] - ETA: 37s - loss: 0.6582 - accuracy: 0.75452019-10-28 12:25:41.299845: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_parse_record_6895}} Key: train/image.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
     [[replica_1/Fill_32/_947]]
2019-10-28 12:25:41.299846: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_parse_record_6895}} Key: train/image.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
     [[OptionalHasValue/_10]]
2019-10-28 12:25:41.300453: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_parse_record_6895}} Key: train/image.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
W1028 12:25:41.317599 140602772264768 callbacks.py:1250] Early stopping conditioned on metric `val_accuracy` which is not available. Available metrics are: loss,accuracy
W1028 12:25:41.317952 140602772264768 callbacks.py:990] Can save best model only with val_accuracy available, skipping.
Traceback (most recent call last):
  File "DAT_resnet34.py", line 80, in <module>
    validation_steps=STEPS_PER_EPOCH_VALIDATION, callbacks=[tb, es, mc])
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 324, in fit
    total_epochs=epochs)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 123, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 86, in execution_function
    distributed_function(input_fn))
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
    result = self._call(*args, **kwds)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 487, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1823, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1141, in _filtered_call
    self.captured_inputs)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/home/willi3by/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: 3 root error(s) found.
  (0) Invalid argument:   Key: train/image.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
     [[OptionalHasValue/_10]]
  (1) Invalid argument:   Key: train/image.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
     [[replica_1/Fill_32/_947]]
  (2) Invalid argument:   Key: train/image.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
0 successful operations.
0 derived errors ignored. [Op:__inference_distributed_function_40587]

Function call stack:
distributed_function -> distributed_function -> distributed_function -> distributed_function -> distributed_function -> distributed_function
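
The log shows several steps completing before `ParseSingleExample` fails on the key `train/image`, which suggests (but does not prove) that only some records are rejected by the feature spec. One way to narrow this down is to scan the file outside the training loop and note which records fail. The sketch below is a minimal version of that idea. `find_bad_records` and `scan_tfrecord` are hypothetical helper names, and `feature_spec` is assumed to be the same dictionary used inside the `parse_record` function named in the log, which is not shown in this post.

```python
def find_bad_records(serialized_records, parse_fn):
    """Return (index, error message) for every record that parse_fn rejects."""
    bad = []
    for index, raw in enumerate(serialized_records):
        try:
            parse_fn(raw)
        except Exception as err:  # in eager TensorFlow this is tf.errors.InvalidArgumentError
            bad.append((index, str(err)))
    return bad


def scan_tfrecord(path, feature_spec):
    """Scan one TFRecord file eagerly (requires TensorFlow 2.x)."""
    import tensorflow as tf  # imported lazily so the helper above has no dependencies

    dataset = tf.data.TFRecordDataset(path)
    return find_bad_records(
        dataset, lambda raw: tf.io.parse_single_example(raw, feature_spec))
```

Calling `scan_tfrecord('train.tfrecords', feature_spec)` on a single device would list the indices of the offending records, if any. An empty result would instead point toward the input pipeline or the distribution setup rather than the file contents.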