Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/356.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181

Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/cmake/2.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 带Bert(huggingface)分类器的tf.keras模型_Python_Tensorflow2.0_Huggingface Transformers - Fatal编程技术网

Python 带Bert(huggingface)分类器的tf.keras模型

Python 带Bert(huggingface)分类器的tf.keras模型,python,tensorflow2.0,huggingface-transformers,Python,Tensorflow2.0,Huggingface Transformers,我正在训练一个使用Bert(huggingface)的二进制分类器。模型如下所示: def get_model(lr=0.00001): inp_bert = Input(shape=(512), dtype="int32") bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0] doc_encodings = tf.squeeze(bert[:, 0:1, :],

我正在训练一个使用Bert(huggingface)的二进制分类器。模型如下所示:

def get_model(lr=0.00001):
    inp_bert = Input(shape=(512), dtype="int32")
    bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0]
    doc_encodings = tf.squeeze(bert[:, 0:1, :], axis=1)
    out = Dense(1, activation="sigmoid")(doc_encodings)
    model = Model(inp_bert, out)
    adam = optimizers.Adam(lr=lr)
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])
    return model
在对分类任务进行微调之后,我想保存模型

model.save("best_model.h5")
但是,这会引发一个NotImplementedError:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-55-8c5545f0cd9b> in <module>()
----> 1 model.save("best_spam.h5")
      2 # import transformers

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    110           'or using `save_weights`.')
    111     hdf5_format.save_model_to_hdf5(
--> 112         model, filepath, overwrite, include_optimizer)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
     97 
     98   try:
---> 99     model_metadata = saving_utils.model_metadata(model, include_optimizer)
    100     for k, v in model_metadata.items():
    101       if isinstance(v, (dict, list, tuple)):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    163   except NotImplementedError as e:
    164     if require_config:
--> 165       raise e
    166 
    167   metadata = dict(

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    160   model_config = {'class_name': model.__class__.__name__}
    161   try:
--> 162     model_config['config'] = model.get_config()
    163   except NotImplementedError as e:
    164     if require_config:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError: 
---------------------------------------------------------------------------
NotImplementedError回溯(最后一次调用)
在()
---->1.保存模式(“最佳垃圾邮件.h5”)
2#进口变压器
保存中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/engine/network.py(self、filepath、overwrite、include\u优化器、保存格式、签名、选项)
973     """
974保存。保存\u模型(自我、文件路径、覆盖、包含\u优化器、保存\u格式、,
-->975(签名、选项)
976
977 def保存权重(self、filepath、overwrite=True、save_format=None):
保存模型中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/saving/save.py(模型、文件路径、覆盖、包括优化器、保存格式、签名、选项)
110'或使用“保存权重”。)
111 hdf5\格式。将\模型\保存到\ hdf5(
-->112型号,文件路径,覆盖,包括(优化器)
113其他:
114保存的\u模型\u保存。保存(模型、文件路径、覆盖、包含\u优化器、,
将模型保存到hdf5中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/saving/hdf5\u format.py(模型、文件路径、覆盖、包含优化器)
97
98尝试:
--->99 model\u metadata=保存utils.model\u元数据(model,include\u优化器)
100表示模型_元数据中的k,v.items():
101如果存在(v,(dict,list,tuple)):
模型元数据中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/saving/saving\u utils.py(模型,包括优化器,需要配置)
163除未实施的错误外,错误为e:
164如果需要配置:
-->165升e
166
167元数据=dict(
模型元数据中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/saving/saving\u utils.py(模型,包括优化器,需要配置)
160 model_config={'class_name':model.\u class_.\u name_}
161尝试:
-->162 model_config['config']=model.get_config()
163除未实施的错误外,错误为e:
164如果需要配置:
获取配置(self)中的~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py
885如果不是自。\是\图形\网络:
886升起未执行错误
-->887返回副本.deepcopy(获取网络配置(自))
888
889@classmethod
获取网络配置中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/engine/network.py(网络,序列化层)
1940筛选的\u入站\u节点。追加(节点\u数据)
1941
->1942层\配置=序列化\层\ fn(层)
1943 layer_config['name']=layer.name
1944层\u配置['inbound\u nodes']=过滤的\u inbound\u nodes
序列化对象(实例)中的~/anaconda3/envs/tensorflow\u p36/lib/python3.6/site-packages/tensorflow\u core/python/keras/utils/generic\u utils.py
138如果hasattr(实例“get_config”):
139返回序列化\u keras\u class\u和\u config(实例.\uuuuuuu class\uuuuuuuu.\uuuuuuuuuu名称\uuuuuuuuu,
-->140实例。获取_config())
141如果hasattr(实例“名称”):
142返回实例。\u\u名称__
获取配置(self)中的~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py
884 def get_配置(自身):
885如果不是自。\是\图形\网络:
-->886升起未执行错误
887返回副本.deepcopy(获取网络配置(自))
888
未实现错误:

我知道huggingface为TFBertModel提供了一个model.save_pretrained()方法,但我更喜欢将其包装在tf.keras.model中,因为我计划向该网络添加其他组件/功能。有人能提出保存当前模型的解决方案吗?

这确实是tensorflow 2.0的问题

请使用:
model.save(“model\u name”,save\u format='tf')


或者,您也可以尝试升级或降级tensorflow。

从tensorflow的GIT页面上的一些讨论中,我认为这是tensorflow 2.0的问题,请尝试升级/降级tensorflow。此外,
model.save(“model\u name”,save\u format='tf')
应该是workmodel.save(“model\u name”,save\u format='tf'))已经解决了我的问题。谢谢!如果你把你的评论作为回答,我会接受的。