Tensorflow 由于TF2中的tfVariable出现问题,简单RNN模型无法工作

Tensorflow 由于TF2中的tfVariable出现问题,简单RNN模型无法工作,tensorflow,Tensorflow,我试图建立一个简单的模型,并保存未经训练的图层。(我以后会想训练它)。我正在尝试使用tensorflow核心API,而不依赖Keras层,这样我可以更直接地控制我使用的内容,并最大限度地提高与TFLite的兼容性 import numpy as np import tensorflow as tf class BasicModel(tf.Module): def __init__(self): self.const = None @tf.function(in

我试图建立一个简单的模型,并保存未经训练的图层。(我以后会想训练它)。我正在尝试使用tensorflow核心API,而不依赖Keras层,这样我可以更直接地控制我使用的内容,并最大限度地提高与TFLite的兼容性

import numpy as np
import tensorflow as tf

class BasicModel(tf.Module):
    def __init__(self):
        self.const = None

    @tf.function(input_signature=[
            tf.TensorSpec(shape=[None,20],dtype=tf.int32),
    ])
    def rnn(self, captions):
        # ENCODER
        weights = tf.Variable(tf.random.normal([10000, 724]))#, shape=[vocab_size,embedding_dimension], name="embedding_weights")
        embedding_output = tf.nn.embedding_lookup(weights,captions)
        #activation is tanh for GRUCell
        sequence = tf.unstack(embedding_output,num=20, axis=1) 
        cell = tf.compat.v1.nn.rnn_cell.GRUCell(20)
        print(sequence)
        gru_layer = tf.compat.v1.nn.static_rnn(cell, sequence, dtype=tf.float32)
        return gru_layer

root = BasicModel()
concrete_function = root.rnn.get_concrete_function()
tf.saved_model.save(root,"model",concrete_function)
我希望有一个未经培训的模型可以保存,但我得到了一个错误:

Traceback (most recent call last):
  File "model_tensorflow_2.py", line 24, in <module>
    concrete_function = root.rnn.get_concrete_function()#tf.constant(images), tf.constant(captions), tf.constant(cap_lens))
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 782, in get_concrete_function
    return self._stateless_fn.get_concrete_function(*args, **kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1891, in get_concrete_function
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2658, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 905, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

    model_tensorflow_2.py:13 rnn  *
        weights = tf.Variable(tf.random.normal([10000, 724]))#, shape=[vocab_size,embedding_dimension], name="embedding_weights")
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /Users/t.capes/miniconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.
回溯(最近一次呼叫最后一次):
文件“model_tensorflow_2.py”,第24行,中
concrete_function=root.rnn.get_concrete_function()
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow_core/python/eager/def_function.py”,第782行,在get_concrete_函数中
返回self.\u无状态\u fn.get\u具体函数(*args,**kwargs)
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow\u core/python/eager/function.py”,第1891行,在get\u concrete\u函数中
图形函数,args,kwargs=self.\u可能定义函数(args,kwargs)
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow\u core/python/eager/function.py”,第2150行,在定义函数中
graph\u function=self.\u create\u graph\u function(args,kwargs)
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow\u core/python/eager/function.py”,第2041行,在“创建图”函数中
按值捕获=自身。_按值捕获),
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow_core/python/framework/func_graph.py”,第915行,func_graph_from_py_func
func_outputs=python_func(*func_args,**func_kwargs)
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow_core/python/eager/def_function.py”,第358行,包装为
返回弱_-wrapped_-fn()
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow\u core/python/eager/function.py”,第2658行,在绑定方法包装中
退货包装单(*args,**kwargs)
文件“/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow_core/python/framework/func_graph.py”,第905行,在包装器中
将e.ag\u错误\u元数据引发到\u异常(e)
ValueError:在转换的代码中:
模型张流2.py:13 rnn*
权重=tf.Variable(tf.random.normal([10000,724]))#,shape=[vocab_size,embedding_dimension],name=“embedding_weights”)
/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow\u core/python/ops/variables.py:260\u调用__
返回cls.\u变量\u v2\u调用(*args,**kwargs)
/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow\u core/python/ops/variables.py:254\u variable\u v2\u调用
形状=形状)
/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow_core/python/ops/variables.py:65 getter
返回捕获的获取者(捕获的前一个,**kwargs)
/Users/t.capes/miniconda3/lib/python3.7/site packages/tensorflow_core/python/eager/def_function.py:413无效的\u创建者\u范围
“tf.function-尝试创建的装饰函数”
ValueError:tf.function-decorated函数试图在非第一次调用时创建变量。

tf.函数
不允许在非首次调用时创建变量,因为其语义不清楚:是否应在每次调用时重新创建变量?是否应该隐式缓存它们?(参见2019年tf峰会的“tf.function和签名”演讲)

一种常见的解决方法是使用帮助函数创建变量,并确保每个实例最多调用一次:

class BasicModel(tf.Module):
    # ...

    def _create_parameters(self, ...):
        self._weights = tf.Variable(...)
        self._parameters_created = True

    def rnn(self, ...):
        if not self._parameters_created:
            self._create_parameters(...)
        ...

你可能想考虑使用十四行诗2。code>snt.Module只是
tf.Module
之上的一个精简抽象,它添加了自动名称范围和变量跟踪。十四行诗2还附带了一些内置模块,包括RNN。有关更多详细信息,请参阅。