Python 3.x TensorFlow 1.15中的Autograph在一个分支中赋值的条件语句中给出TypeError

Python 3.x TensorFlow 1.15中的Autograph在一个分支中赋值的条件语句中给出TypeError,python-3.x,tensorflow,conditional-statements,Python 3.x,Tensorflow,Conditional Statements,系统信息 海关代码:是 操作系统平台:Windows 10 PC TensorFlow版本:1.15.0 Python版本:3.6 当前行为 如果在条件语句的一个分支中为变量指定了一个常量值,则会为该变量推断该值的数据类型,该变量不需要与导致TypeError的变量的预期数据类型对齐。问题在于,在使用autograph创建图形和分配占位符数据类型之前,需要了解python源代码中的任何赋值(及其类型),这并不总是实用的。请参阅下面的示例代码 重现问题的代码 从TensorFlow文档中修改平

系统信息

  • 海关代码:是
  • 操作系统平台:Windows 10 PC
  • TensorFlow版本:1.15.0
  • Python版本:3.6
当前行为

如果在条件语句的一个分支中为变量指定了一个常量值,则会为该变量推断该值的数据类型,该变量不需要与导致TypeError的变量的预期数据类型对齐。问题在于,在使用autograph创建图形和分配占位符数据类型之前,需要了解python源代码中的任何赋值(及其类型),这并不总是实用的。请参阅下面的示例代码

重现问题的代码

从TensorFlow文档中修改平方法

import tensorflow as tf
from tensorflow import autograph as ag

#minimal code for method to demonstrate issue
def foo(x):
    if x > 0:
        y = x * x
    else:
        y = 0.0
    return y

#graph construction
mdl = tf.Graph()
with mdl.as_default():
    converted_foo = ag.to_graph(foo)
    print(ag.to_code(foo))
    x = tf.placeholder(tf.double, name='x')
    y = converted_foo(x)
错误消息是:

TypeError: "y" has dtype float64 in the TRUE branch, but dtype=float32 in the FALSE branch. TensorFlow control flow requires that they are the same.
请参阅下面的详细回溯

我们如何修改代码或签名行为才能使代码成功工作

一个(不需要的)解决方法是将
x
定义为:
x=tf.placeholder(tf.float32,name='x')

但是,如果
foo
是:

#minimal code for method to demonstrate issue
def foo(x):
    if x > 0:
        y = x * x
    else:
        y = 0
    return y
新的错误是:

TypeError: "y" has dtype float32 in the TRUE branch, but dtype=int32 in the FALSE branch. TensorFlow control flow requires that they are the same.
是否有更合适的解决方法

其他信息/日志 错误日志:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-47-9455269f4d6f> in <module>
     16     print(ag.to_code(foo))
     17     x = tf.placeholder(tf.double, name='x')
---> 18     y = converted_foo(x)

C:\Users\212613~1\AppData\Local\Temp\tmp3yjohro5.py in tf__foo(x)
     23           return y
     24         cond = x > 0
---> 25         y = ag__.if_stmt(cond, if_true, if_false, get_state, set_state, ('y',), ())
     26         do_return = True
     27         retval_ = foo_scope.mark_return_value(y)

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py in if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names, composite_symbol_names)
    891   if tensors.is_dense_tensor(cond):
    892     return tf_if_stmt(cond, body, orelse, get_state, set_state,
--> 893                       basic_symbol_names, composite_symbol_names)
    894   else:
    895     return _py_if_stmt(cond, body, orelse)

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py in tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names, composite_symbol_names)
    929 
    930   final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
--> 931                                                   error_checking_orelse)
    932 
    933   set_state(final_state)

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   1233     try:
   1234       context_f.Enter()
-> 1235       orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
   1236       if orig_res_f is None:
   1237         raise ValueError("false_fn must have a return value.")

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py in BuildCondBranch(self, fn)
   1059     """Add the subgraph defined by fn() to the graph."""
   1060     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-> 1061     original_result = fn()
   1062     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1063     if len(post_summaries) > len(pre_summaries):

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py in error_checking_orelse()
    925     if result[body_branch] is not None:
    926       _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
--> 927                            basic_symbol_names, composite_symbol_names)
    928     return result[orelse_branch]
    929 

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py in _verify_tf_cond_vars(body_outputs, orelse_outputs, basic_symbol_names, composite_symbol_names)
    259 
    260     nest.map_structure(
--> 261         functools.partial(_check_same_type, name), body_output, orelse_output)
    262 
    263 

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\util\nest.py in map_structure(func, *structure, **kwargs)
    534 
    535   return pack_sequence_as(
--> 536       structure[0], [func(*x) for x in entries],
    537       expand_composites=expand_composites)
    538 

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\util\nest.py in <listcomp>(.0)
    534 
    535   return pack_sequence_as(
--> 536       structure[0], [func(*x) for x in entries],
    537       expand_composites=expand_composites)
    538 

~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site-packages\tensorflow_core\python\autograph\operators\control_flow.py in _check_same_type(name, body_output_var, orelse_output_var)
    256             ' branch. TensorFlow control flow requires that they are the'
    257             ' same.'.format(name, body_output_var.dtype.name,
--> 258                             orelse_output_var.dtype.name))
    259 
    260     nest.map_structure(

TypeError: "y" has dtype float64 in the TRUE branch, but dtype=float32 in the FALSE branch. TensorFlow control flow requires that they are the same.
---------------------------------------------------------------------------
TypeError回溯(最近一次调用上次)
在里面
16打印(公司代码(foo))
17 x=tf.placeholder(tf.double,name='x')
--->18 y=转换的_foo(x)
C:\Users\212613~1\AppData\Local\Temp\tmp3yjohro5.py在tf\uuufoo(x)中
23返回y
24秒=x>0
--->25 y=ag_uuu.if_stmt(cond,if_true,if_false,get_state,set_state,('y',),())
26 do_return=True
27返回=foo\u范围。标记返回值(y)
if stmt中的~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\autograph\operators\control\u flow.py(cond、body、orelse、get\u state、set\u state、basic\u symbol\u name、composite\u symbol\u name)
891如果张量是稠密张量(cond):
892返回tf_if_stmt(cond、body、orelse、get_state、set_state、,
-->893基本符号名称、复合符号名称)
894其他:
895如果测试(条件、身体或身体)返回
tf\u if stmt中的~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\autograph\operators\control\u flow.py(cond、body、orelse、get\u state、set\u state、basic\u symbol\u name、composite\u symbol\u name)
929
930最终变量,最终状态=控制流量操作条件(条件,错误检查主体,
-->931错误检查(orelse)
932
933设置_状态(最终_状态)
新函数中的~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\util\deprecation.py(*args,**kwargs)
505“在未来版本中”如果日期不是其他日期(“在%s“%date”之后),
506(说明)
-->507返回函数(*args,**kwargs)
508
509 doc=\u添加\u不推荐的\u参数\u通知\u到\u docstring(
~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\ops\control\u flow\u ops.py in cond(pred,true\u fn,false\u fn,strict,name,fn1,fn2)
1233尝试:
1234上下文输入()
->1235 orig_res_f,res_f=context_f.BuildCondBranch(false_fn)
1236如果原始资源为无:
1237 raise value ERROR(“false\u fn必须有返回值”)
BuildCondBranch中的~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\ops\control\u flow\u ops.py(self,fn)
1059“将fn()定义的子图添加到图中。”“”
1060 pre_summaries=ops.get_collection(ops.GraphKeys._SUMMARY_collection)#pylint:disable=protected access
->1061原始结果=fn()
1062 post_summaries=ops.get_collection(ops.GraphKeys._SUMMARY_collection)#pylint:disable=protected access
1063如果len(后总结)>len(前总结):
~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\autograph\operators\control\u flow.py出错\u checking\u orelse()
925如果结果[body_branch]不是无:
926验证条件变量(结果[body_分支]、结果[orelse_分支],
-->927基本符号名称、复合符号名称)
928返回结果[orelse_分支机构]
929
~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\autograph\operators\control\u flow.py in\u verify\u tf\u cond\u vars(body\u输出、orelse\u输出、基本符号名称、复合符号名称)
259
260 nest.map_结构(
-->261 functools.partial(_check_same_type,name),body_输出,或lse_输出)
262
263
映射结构中的~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\util\nest.py(func,*structure,**kwargs)
534
535返回包\u序列\u组件(
-->536结构[0],[func(*x)表示条目中的x],
537扩展_复合材料=扩展_复合材料)
538
~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\util\nest.py in(.0)
534
535返回包\u序列\u组件(
-->536结构[0],[func(*x)表示条目中的x],
537扩展_复合材料=扩展_复合材料)
538
~\AppData\Local\Continuum\anaconda3\envs\kchain\lib\site packages\tensorflow\u core\python\autograph\operators\control\u flow.py in\u check\u same\u type(名称、主体输出变量、或逻辑输出变量)
256'分支。TensorFlow控制流要求它们是'
257“相同”。。格式(名称、正文\u输出\u var.dtype.name、,
-->258 orelse_output_var.dtype.name))
259
x = tf.placeholder(tf.double, name='x')
x = tf.placeholder(tf.float32, name='x')
if x > 0.0:
        y = x * x
    else:
        y = 0.0