
TensorFlow tfcompile of tf.cond of constants error


I create a graph containing a `tf.cond` with the following sample code:

from __future__ import absolute_import

import tensorflow as tf

from tensorflow.compiler.tf2xla.tf2xla_pb2 import Config, Feed, Fetch, TensorId
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto


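def tf2xla_config_feed( feed ):
  # A Feed identifies an input node and its static shape.
  # 'a:0' -> 'a': the output index is dropped, so TensorId.output_index
  # keeps its default of 0.
  name = feed.name.split( ':' )[ 0 ]
  pb_id = TensorId( node_name = name )
  pb_dim = [ TensorShapeProto.Dim( size = x.value ) for x in feed.shape ]
  pb_tensor_shape_proto = TensorShapeProto( dim = pb_dim )
  pb_feed = Feed( id = pb_id, shape = pb_tensor_shape_proto )
  return pb_feed


def tf2xla_config_fetch( fetch ):
  # A Fetch identifies an output node; output_index again defaults to 0.
  name = fetch.name.split( ':' )[ 0 ]
  pb_id = TensorId( node_name = name )
  pb_fetch = Fetch( id = pb_id )
  return pb_fetch


def tf2xla_config( feeds, fetches ):
  # Build the tf2xla Config proto that tfcompile reads via --config.
  # List comprehensions pass concrete lists rather than lazy map objects
  # (map() returns an iterator under Python 3).
  pb_feeds = [ tf2xla_config_feed( x ) for x in feeds ]
  pb_fetches = [ tf2xla_config_fetch( x ) for x in fetches ]
  return Config( feed = pb_feeds, fetch = pb_fetches )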


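# Input: a length-2 float64 vector; the predicate compares its two elements.
a = tf.placeholder( tf.float64, shape = ( 2, ), name = 'a' )

a1 = a[ 0 ]
a2 = a[ 1 ]

# Both branches return int32 constants, so the result depends on the
# input only through the predicate.
one = tf.constant( 1 )
two = tf.constant( 2 )

res = tf.cond( a1 < a2, lambda: one, lambda: two )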

# Serialize the graph and the tf2xla config as binary protos for tfcompile.
with open( 'test_graph.pb', 'wb' ) as f:
  f.write( res.graph.as_graph_def().SerializeToString() )

with open( 'test_config.pb', 'wb' ) as f:
  f.write( tf2xla_config( [ a ], [ res ] ).SerializeToString() )
Compiling these files with tfcompile results in the following error:

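2017-11-29 20:40:26.725164: F tensorflow/compiler/aot/tfcompile_main.cc:140] Non-OK-status: status status: Unimplemented: Conversion from TensorFlow graph to XLA resulted in 1 constant results. The configuration of the output args (i.e. fetch ids) is probably wrong.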

Is this error unwarranted, or am I doing something wrong?

The tfcompile invocation used:

tfcompile --graph=test_graph.pb --config=test_config.pb --entry_point=test_func --cpp_class=test --out_object=test_func.o --out_header=test.hpp
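To see what `tf2xla_config` writes into `test_config.pb` without needing TensorFlow, here is a minimal sketch that renders the same feeds and fetches as a tf2xla `Config` in protobuf text format. The helper names `split_tensor_name` and `config_pbtxt` are hypothetical, and the fetch name `cond/Merge:0` is an assumption about how TF 1.x names the `tf.cond` output (print `res.name` to confirm). It also makes explicit that `split( ':' )[ 0 ]` in the code above keeps only the node name, leaving `output_index` at its default of 0.

```python
from typing import List, Sequence, Tuple


def split_tensor_name(tensor_name: str) -> Tuple[str, int]:
    """Split a tensor name such as 'cond/Merge:0' into ('cond/Merge', 0).

    A bare node name without ':N' is treated as output 0.
    """
    node, sep, index = tensor_name.rpartition(":")
    if not sep:
        return tensor_name, 0
    return node, int(index)


def config_pbtxt(feeds: Sequence[Tuple[str, Sequence[int]]],
                 fetches: Sequence[str]) -> str:
    """Render (tensor name, shape) feeds and fetch tensor names as a
    tf2xla Config in protobuf text format (hypothetical helper)."""
    lines: List[str] = []
    for tensor_name, shape in feeds:
        node, index = split_tensor_name(tensor_name)
        lines.append("feed {")
        lines.append(f'  id {{ node_name: "{node}" output_index: {index} }}')
        lines.append("  shape {")
        for size in shape:
            lines.append(f"    dim {{ size: {size} }}")
        lines.append("  }")
        lines.append("}")
    for tensor_name in fetches:
        node, index = split_tensor_name(tensor_name)
        lines.append("fetch {")
        lines.append(f'  id {{ node_name: "{node}" output_index: {index} }}')
        lines.append("}")
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Assumed names: check a.name and res.name in your own graph.
    print(config_pbtxt([("a:0", [2])], ["cond/Merge:0"]))
```

Comparing this text against the tensor names in the actual graph is a quick way to rule out a wrong fetch id, which is what the error message suggests checking.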