Python 构造一个在传递给tf.concat()时表现为空的TensorFlow张量
我正在构造一个计算图,它的拓扑结构根据一些超参数而变化。在某个时刻,会发生连接:Python 构造一个在传递给tf.concat()时表现为空的TensorFlow张量,python,tensorflow,Python,Tensorflow,我正在构造一个计算图,它的拓扑结构根据一些超参数而变化。在某个时刻,会发生连接: c = tf.concat([a, b], axis=-1) 张量a具有形状(无,m)。 张量b具有形状(无,n),其中n取决于超参数。对于超参数的一个值,张量b在概念上应为空,例如,我们希望c和a相同 我可以通过以下方法成功构建图形: b = tf.placeholder(tf.float32, (None, 0), name="Empty") 但是,如果我运行一个会话,TensorFlow会引发一个Inva
c = tf.concat([a, b], axis=-1)
张量a
具有形状(无,m)
。
张量b
具有形状(无,n)
,其中n
取决于超参数。对于超参数的一个值,张量b
在概念上应为空,例如,我们希望c
和a
相同
我可以通过以下方法成功构建图形:
b = tf.placeholder(tf.float32, (None, 0), name="Empty")
但是,如果我运行一个会话,TensorFlow会引发一个InvalidArguMemError
声明:
You must feed a value for placeholder tensor 'Empty' with dtype float and shape [?,0]
有没有办法构造一个在concat
操作中表现为空的张量,但不需要输入虚假输入
显然,我知道我可以在构建图的代码中添加一个特例、包装器等。我希望避免这种情况
完整代码:
import tensorflow as tf
import numpy as np
a = tf.placeholder(tf.float32, (None, 10))
b = tf.placeholder(tf.float32, (None, 0), name="Empty")
c = tf.concat([a, b], axis=-1)
assert c.shape.as_list() == [None, 10]
with tf.Session() as sess:
a_feed = np.zeros((100, 10))
c = sess.run(c, {a : a_feed})
您可以使用不需要输入占位符的
import tensorflow as tf
import numpy as np
# Hparams
batch_size = 100
a_dim = 10
b_dim = 0
# Placeholder for a which is required to be fed.
a = tf.placeholder(tf.float32, (None, a_dim))
# Placeholder for b, which doesn't have to be fed.
b_default = np.zeros((batch_size, b_dim), dtype=np.float32)
b = tf.placeholder_with_default(
b_default, (None, b_dim), name="Empty"
)
c = tf.concat([a, b], axis=-1)
assert c.shape.as_list() == [None, a_dim + b_dim]
with tf.Session() as sess:
a_feed = np.zeros((batch_size, a_dim))
b_feed = np.ones((batch_size, b_dim))
c_out = sess.run(c, {a : a_feed})
# You can optionally feed in b:
# c_out = sess.run(c, {a : a_feed, b : b_feed})
print(c_out)
如果您不是使用
tf.placeholder()
来提供数据,而是使用tf.Estimator
,那么解决方案很简单,因为您可以定义:
b = tf.zeros([a.shape[0].value, 0])
如果a的形状已知
c = tf.concat([a,b],axis=-1)
assert c.shape == a.shape
将始终成功。如消息所示,您还必须为b提供反馈
with tf.Session() as sess:
a_feed = np.zeros((100, 10))
b_feed = np.zeros((100, 0))
c = sess.run(c, {a : a_feed, b: b_feed})
它在我的电脑上传递 为什么不在b为空时将一个无意义的numpy传递给b?因为此代码的使用者不需要知道这个奇怪的空占位符,它的唯一目的是使
concat
在特殊情况下工作。在问题中,我问:“有没有办法构造一个在concat
操作中表现为空的张量,但不需要输入虚假的输入?”-因此我特别试图避免您描述的解决方案。