Python 如何在Tensorflow中生成条件语句

Python 如何在Tensorflow中生成条件语句,python,tensorflow,Python,Tensorflow,我正在使用Tensorflow 1.14.0并试图编写一个非常简单的函数,其中包含Tensorflow的条件语句。它的常规(非Tenslorflow)版本是: def u(x): if x<7: y=x+x else: y=x**2 return y 我将得到如下错误: x=tf.Variable(3,name='x') sess=tf.Session() sess.run(x.initializer) result=sess.ru

我正在使用Tensorflow 1.14.0并试图编写一个非常简单的函数,其中包含Tensorflow的条件语句。它的常规(非Tenslorflow)版本是:

def u(x):
    if x<7:
        y=x+x
    else:
        y=x**2
    return y
我将得到如下错误:

x=tf.Variable(3,name='x')
sess=tf.Session()
sess.run(x.initializer)
result=sess.run(u(x))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-789531cde07a> in <module>
      2 sess=tf.Session()
      3 sess.run(x.initializer)
----> 4 result=sess.run(u(x))
      5 # print(result)

<ipython-input-23-39f85f34465a> in uu(x)
      2 
      3 def u(x):
----> 4     if x<7:
      5         y=x+x
      6     else:

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py in __bool__(self)
    688       `TypeError`.
    689     """
--> 690     raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
    691                     "Use `if t is not None:` instead of `if t:` to test if a "
    692                     "tensor is defined, and use TensorFlow ops such as "

TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
---------------------------------------------------------------------------
TypeError回溯(最近一次调用上次)
在里面
2 sess=tf.Session()
3个sess.run(x.初始值设定项)
---->4结果=sess运行(u(x))
5#打印(结果)
在uu(x)中
2.
3 def u(x):
---->4如果x这是TF中的错误(相关Github问题:)。例如,以下场景可以工作

什么有效 将
tf.Variable
更改为
tf.constant
加常数 什么不起作用 但是
tf.add(x,x)
x+x
失败。造成这种情况的原因是
tf.add
在使用
tf.Variable
类型时遇到问题,但在
tf.Tensor
类型时效果良好。我有一种预感,可以在源代码中找到一些见解。当我发现任何东西时会更新

解决方案(适用于TF
1.15
) 您需要启用
tf.cond
的版本2,该版本显然已修复此问题。你可以这样做。不幸的是,这不适用于
1.14

在Jupyter上使用魔法 使用Python
这应该会给你想要的结果。

与你的问题没有直接关系,但是为什么你似乎不使用空格?@AMC为什么以及我应该怎么做?这并不难,只要一些简单的调整就可以帮助你:
如果x
如果x<7:
,或者
y=x+x
->
y=x+x
,也可以是
y=2*x
。它使事情更容易阅读@这就是为什么我在我的问题上被否决的原因吗?我似乎否决了你的帖子,因为我不记得了,也许这是个意外。我忘了提到你不应该把这个导入放在函数里面。谢谢。我使用Jupyter,但这并没有解决问题,使用“%env TF\u ENABLE\u COND\u V2='1'”是否尝试了另一个
os.environ[…]=
?没关系,上面的解决方案已经足够了,我可以使用multiply@Albert你说得对
TF\u ENABLE\u COND\u V2
仅适用于
1.15+
。不幸的是,TF1.14不起作用。我将相应地更新我的解决方案。
def u(x):
    import tensorflow as tf
    y=tf.cond(x < 7, lambda: tf.add(x, x), lambda: tf.square(x))
    return y 
InvalidArgumentError                      Traceback (most recent call last)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1355     try:
-> 1356       return fn(*args)
   1357     except errors.OpError as e:

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1340       return self._call_tf_sessionrun(
-> 1341           options, feed_dict, fetch_list, target_list, run_metadata)
   1342 

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1428         self._session, options, feed_dict, fetch_list, target_list,
-> 1429         run_metadata)
   1430 

InvalidArgumentError: Retval[0] does not have value

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-27-06e1605182c1> in <module>
      2 sess=tf.Session()
      3 sess.run(x.initializer)
----> 4 result=sess.run(u(x))
      5 # print(result)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
    948     try:
    949       result = self._run(None, fetches, feed_dict, options_ptr,
--> 950                          run_metadata_ptr)
    951       if run_metadata:
    952         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1171     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1172       results = self._do_run(handle, final_targets, final_fetches,
-> 1173                              feed_dict_tensor, options, run_metadata)
   1174     else:
   1175       results = []

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1348     if handle is None:
   1349       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1350                            run_metadata)
   1351     else:
   1352       return self._do_call(_prun_fn, handle, feeds, fetches)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1368           pass
   1369       message = error_interpolation.interpolate(message, self._graph)
-> 1370       raise type(e)(node_def, op, message)
   1371 
   1372   def _extend_graph(self):

InvalidArgumentError: Retval[0] does not have value
x=tf.constant(3,name='x')
def u(x):    
    y=tf.cond(x < 7, lambda: tf.add(x, x), lambda: tf.square(x))
    return y
def u(x):    
    y=tf.cond(x < 7, lambda: tf.multiply(x, x), lambda: tf.square(x))
    return y
def u(x):    
    y=tf.cond(x < 7, lambda: tf.add(x, 2), lambda: tf.square(x))
    return y 
def u(x):    
    y=tf.cond(x < 7, lambda: tf.math.add(tf.identity(x), tf.identity(x)), lambda: tf.square(x))
    return y 
%env TF_ENABLE_COND_V2='1'
os.environ['TF_ENABLE_COND_V2'] = '1'