TensorFlow 2.0: LSTM state and input size

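For a specific reinforcement learning problem (inspired by a linked reference), I am using an RNN with LSTM cells. It is fed data of shape (batch_size, time_steps, features) = (1,1,1) for L data points, and then a "cycle" ends. I am using lstm.stateful=True, and after L feeds to the network I call lstm.reset_states().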

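Between one cycle and the next, right after calling lstm.reset_states(), I want to evaluate the network's output on input data of shape (batch_size, time_steps, features) = (L,1,1), and then go back to using the RNN with inputs of batch_size = 1.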

In addition, I want the code to be as optimized as possible, so I am trying to use AutoGraph through @tf.function decorators.

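The problem is that I get an error, which can be reproduced with the example below. Note that if @tf.function is removed, everything works, and I do not understand why.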

import tensorflow as tf
import numpy as np


class Actor(tf.keras.Model):
    def __init__(self):
        super(Actor,self).__init__()
        self.lstm = tf.keras.layers.LSTM(5, return_sequences=True, stateful=True, input_shape=(None,None,1))#, input_shape=(None,None,1))

    def call(self, inputs):
        feat= self.lstm(inputs)
        return feat

actor = Actor()

@tf.function
def g(actor):
    context1 = tf.reshape(np.array([0.]*10),(10,1,1))
    actor(context1)
    actor.reset_states()
    actor.lstm.stateful=False
    context = tf.reshape(np.array([0.]),(1,1,1))
    actor(context)

g(actor)    



---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-28-4487772bee64> in <module>
     23     actor(context)
     24 
---> 25 g(actor)

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    625       # This is the first call of __call__, so we have to initialize.
    626       initializers = []
--> 627       self._initialize(args, kwds, add_initializers_to=initializers)
    628     finally:
    629       # At this point we know that the initialization is complete (or less

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    504     self._concrete_stateful_fn = (
    505         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 506             *args, **kwds))
    507 
    508     def invalid_creator_scope(*unused_args, **unused_kwds):

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2444       args, kwargs = None, None
   2445     with self._lock:
-> 2446       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2447     return graph_function
   2448 

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args, kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function, args, kwargs

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2665             arg_names=arg_names,
   2666             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667             capture_by_value=self._capture_by_value),
   2668         self._function_attributes,
   2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/.local/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    439         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    440         # the function a weak reference to itself to avoid a reference cycle.
--> 441         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    442     weak_wrapped_fn = weakref.ref(wrapped_fn)
    443 

~/.local/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

ValueError: in user code:

    <ipython-input-28-4487772bee64>:23 g  *
        actor(context)
    <ipython-input-28-4487772bee64>:11 call  *
        feat= self.lstm(inputs)
    /home/cooper-cooper/.local/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent.py:654 __call__  **
        return super(RNN, self).__call__(inputs, **kwargs)
    /home/cooper-cooper/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:886 __call__
        self.name)
    /home/cooper-cooper/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/input_spec.py:227 assert_input_compatibility
        ', found shape=' + str(shape))

    ValueError: Input 0 is incompatible with layer lstm_7: expected shape=(10, None, 1), found shape=[1, 1, 1]
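
The shapes in the error message can be read as follows: an LSTM keeps a hidden state and a cell state of shape (batch_size, units), so a stateful layer that was first called with batch_size = 10 holds state for 10 sequences and then expects inputs of shape (10, None, 1). The following is a minimal NumPy sketch of a single LSTM step, meant only to show how the state shape follows the batch size. The helper `lstm_step`, the weights `W`, `U`, `b`, and the gate ordering are illustrative assumptions, not Keras internals.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def lstm_step(x, h, c, W, U, b):
    """One LSTM time step.

    x: (batch, features); h and c: (batch, units).
    """
    z = x @ W + h @ U + b                # (batch, 4 * units)
    i, f, g, o = np.split(z, 4, axis=1)  # four gates, each (batch, units)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new


units, features = 5, 1  # same sizes as in the question
rng = np.random.default_rng(0)
W = rng.normal(size=(features, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)


def run(batch_size):
    # The state has to be allocated with the same batch size as the input.
    x = np.zeros((batch_size, features))
    h = np.zeros((batch_size, units))
    c = np.zeros((batch_size, units))
    return lstm_step(x, h, c, W, U, b)


h10, c10 = run(10)
h1, c1 = run(1)
print(h10.shape, h1.shape)  # (10, 5) (1, 5)
```

A state allocated for 10 sequences does not match a single-sequence input, which is the mismatch the layer's input check reports.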

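One workaround is to re-create the LSTM state variables and the layer's input spec for the new batch size, and to use a separate tf.function for each batch size:

import tensorflow as tf
import numpy as np


class Actor(tf.keras.Model):
    def __init__(self):
        super(Actor, self).__init__()
        self.lstm = tf.keras.layers.LSTM(5, return_sequences=True, stateful=True, input_shape=(1, 1))

    def call(self, inputs):
        feat = self.lstm(inputs)
        return feat

    def reset_states_workaround(self, new_batch_size=1):
        # Replace the hidden and cell state with new variables of shape (new_batch_size, units).
        self.lstm.states = [tf.Variable(tf.random.uniform((new_batch_size, 5))),
                            tf.Variable(tf.random.uniform((new_batch_size, 5)))]
        # Update the input spec so the layer accepts the new batch size.
        self.lstm.input_spec = [tf.keras.layers.InputSpec(shape=(new_batch_size, None, 1), ndim=3)]


actor = Actor()


# First function: evaluate the network on a batch of 10 inputs.
@tf.function
def g(actor):
    context1 = tf.reshape(np.array([0.] * 10), (10, 1, 1))
    preds = actor(context1)
    return preds


g(actor)

# Switch the layer to batch_size = 1 before tracing the second function.
actor.reset_states_workaround(new_batch_size=1)


# Second function: continue with a single input per call.
@tf.function
def g2(actor):
    context1 = tf.reshape(np.array([0.] * 1), (1, 1, 1))
    preds = actor(context1)
    return preds


g2(actor)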
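
An alternative that may be worth trying, if carrying the state by hand is acceptable, is to leave stateful at its default and pass the state explicitly through the `initial_state` argument together with `return_state=True`. The sketch below is not the approach used above. The names `run_lstm` and `zero_state` are hypothetical helpers, and the all-zero inputs are placeholders for real data.

```python
import tensorflow as tf

UNITS = 5  # same number of LSTM units as in the question

# return_state=True makes the layer return (outputs, h, c), so the caller
# keeps the state; stateful is left at its default (False).
lstm = tf.keras.layers.LSTM(UNITS, return_sequences=True, return_state=True)


@tf.function
def run_lstm(x, h, c):
    # x: (batch_size, time_steps, features); h and c: (batch_size, UNITS)
    outputs, h_next, c_next = lstm(x, initial_state=[h, c])
    return outputs, h_next, c_next


def zero_state(batch_size):
    # Hypothetical helper: a fresh all-zero hidden state and cell state.
    return tf.zeros((batch_size, UNITS)), tf.zeros((batch_size, UNITS))


# Evaluate on a batch of L = 10 inputs, starting from a zero state.
h, c = zero_state(10)
out_batch, _, _ = run_lstm(tf.zeros((10, 1, 1)), h, c)

# Continue with batch_size = 1, carrying the state from one call to the next.
h, c = zero_state(1)
for _ in range(3):
    out_step, h, c = run_lstm(tf.zeros((1, 1, 1)), h, c)

print(out_batch.shape, out_step.shape)
```

Because the state is an ordinary function argument, the layer itself is not tied to one batch size. With default settings, tf.function would be expected to trace `run_lstm` once per distinct input shape, so the two batch sizes should cost two traces.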