Python: Issue with tf.data.Dataset for a Keras multi-input model

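I built a multi-input model with the Keras functional API. The idea is to classify a text together with its metadata. The model works fine with NumPy-format inputs, but fails with tf.data.Dataset: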

UnimplementedError:  Cast string to int32 is not supported
     [[node functional_5/Cast (defined at <ipython-input-3-8e2b230c1da3>:17) ]] [Op:__inference_train_function_24120]

Function call stack:
train_function

Dataset

A dummy dataset containing 5 texts and their respective metadata:

import tensorflow as tf
from transformers import DistilBertTokenizer

# input meta
dict_meta = {
    "Organization": [
        ["BNS", "NA"],
        ["ECB", "PAD"],
        ["NA", "PAD"],
        ["NA", "PAD"],
        ["NA", "PAD"],
    ],
    "Sector": [
        ["BANK", "PAD", "PAD"],
        ["ASS", "PAD", "NA"],
        ["MARKET", "NA", "NA"],
        ["NA", "PAD", "NA"],
        ["NA", "PAD", "NA"],
    ],
    "Content_type": [
        ["NOTES", "PAD"],
        ["PAPER", "UNK"],
        ["LAW", "PAD"],
        ["LAW", "PAD"],
        ["LAW", "NOTES"],
    ],
    "Geography": [
        ["UK", "FR"],
        ["DE", "CH"],
        ["US", "ES"],
        ["ES", "PAD"],
        ["NA", "PAD"],
    ],
    "Themes": [["A", "B"], ["B", "C"], ["C", "PAD"], ["C", "PAD"], ["G", "PAD"]],
}

# input text
list_text = [
    "Trump in denial over election defeat as Biden gears up to fight Covid",
    "Feds seize $1 billion in bitcoins they say were stolen from Silk Road",
    "Kevin de Bruyne misses penalty as Manchester City and Liverpool draw",
    "United States nears 10 million coronavirus cases",
    "Fiji resort offers the ultimate in social distancing",
]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
params = {
    "max_length": MAX_LEN,  # MAX_LEN is defined elsewhere in the asker's code (value not shown)
    "padding": "max_length",
    "truncation": True,
}
tokenized = tokenizer(list_text, **params)
dict_text = tokenized.data

# input label
label = [[1], [0], [1], [0], [1]]

Training with NumPy format

ds_meta = tf.data.Dataset.from_tensor_slices((dict_meta))
ds_meta = ds_meta.batch(5)
example_meta = next(iter(ds_meta))

ds_text = tf.data.Dataset.from_tensor_slices((dict_text))
ds_text = ds_text.batch(5)
example_text = next(iter(ds_text))

ds_label = tf.data.Dataset.from_tensor_slices((label))
ds_label = ds_label.batch(5)
example_label = next(iter(ds_label))

model.fit([example_text, example_meta], example_label)

Training with tf.data.Dataset

ds = tf.data.Dataset.from_tensor_slices(
    (
        {
            "attention_mask": dict_text["attention_mask"],
            "input_ids": dict_text["input_ids"],
            "Content_type": dict_meta["Organization"],
            "Geography": dict_meta["Geography"],
            "Organization": dict_meta["Organization"],
            "Sector": dict_meta["Sector"],
            "Themes": dict_meta["Themes"],
        },
        {"class_output": label},
    )
)


ds = ds.batch(5)
model.fit(ds, epochs=1)

2020-11-10 14:52:47.502445: W tensorflow/core/framework/op_kernel.cc:1744] OP_REQUIRES failed at cast_op.cc:124 : Unimplemented: Cast string to int32 is not supported
Traceback (most recent call last):

  File "<ipython-input-10-a894466398cd>", line 1, in <module>
    model.fit(ds, epochs=1)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 807, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)

  File "/opt/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)

UnimplementedError:  Cast string to int32 is not supported
     [[node functional_5/Cast (defined at <ipython-input-3-8e2b230c1da3>:17) ]] [Op:__inference_train_function_24120]

Function call stack:
train_function

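You can use the zip function to combine the datasets. The zip function can take a nested structure of datasets as an argument, so we just need to reproduce the way the data is fed to the fit function with the NumPy arrays: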

ds_meta = tf.data.Dataset.from_tensor_slices((dict_meta))
ds_text = tf.data.Dataset.from_tensor_slices((dict_text))
ds_label = tf.data.Dataset.from_tensor_slices((label))
combined_dataset = tf.data.Dataset.zip(((ds_text, ds_meta), ds_label))
combined_dataset = combined_dataset.batch(5)
Running it:

>>> model.fit(combined_dataset)
1/1 [==============================] - 0s 212us/step - loss: 2.2895
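
One way to check that the zipped dataset really has the ((text, meta), label) layout is to inspect its element_spec before calling fit. The snippet below is only a minimal sketch: the token ids and the reduced set of metadata keys are made-up stand-ins, not the data from the question, and no model is involved.

```python
import tensorflow as tf

# Tiny stand-ins for the question's inputs (illustrative values only).
dict_text = {
    "input_ids": [[101, 2023, 102], [101, 2008, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
}
dict_meta = {
    "Geography": [["UK", "FR"], ["DE", "CH"]],
    "Themes": [["A", "B"], ["B", "C"]],
}
label = [[1], [0]]

ds_text = tf.data.Dataset.from_tensor_slices(dict_text)
ds_meta = tf.data.Dataset.from_tensor_slices(dict_meta)
ds_label = tf.data.Dataset.from_tensor_slices(label)

# Nest the two input dicts in one tuple so every element is ((text, meta), label).
combined = tf.data.Dataset.zip(((ds_text, ds_meta), ds_label)).batch(2)

# element_spec mirrors the nesting, so it can be unpacked the same way.
(text_spec, meta_spec), label_spec = combined.element_spec
for name, spec in {**text_spec, **meta_spec}.items():
    print(name, spec.dtype.name, spec.shape)
print("label", label_spec.dtype.name, label_spec.shape)
```

The integer token features and the string metadata features stay in separate dicts here, which makes it easy to see which part of the structure carries which dtype.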

Thanks. Much appreciated. So the nested structure was the answer.