Python 如何从张量中得到值
以下是我的设置:Python 如何从张量中得到值,python,tensorflow,Python,Tensorflow,以下是我的设置: indices = tf.placeholder(tf.int32, shape=[2]) weights = tf.Variable(tf.random_normal([100000, 3], stddev=0.35)) def objective(indices, weights): idx1 = indices[0]; idx2 = indices[1] #extract two indices mask = np.zeros(weights.shape
indices = tf.placeholder(tf.int32, shape=[2])
weights = tf.Variable(tf.random_normal([100000, 3], stddev=0.35))
def objective(indices, weights):
idx1 = indices[0]; idx2 = indices[1] #extract two indices
mask = np.zeros(weights.shape.as_list()[0]) #builds a mask for some tensor "weights"
mask[idx1] = 1 # don't ask why I want to do this. I just do.
mask[idx2] = 1
obj = tf.reduce_sum(tf.multiply(weights[idx1], weights[idx2]))
return obj
optimizer = tf.train.GradientDescentOptimizer(0.01)
obj = objective(indices, weights)
trainer = optimizer.minimize(obj)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run([trainer, obj], feed_dict={indices=[100, 1000]})
重点是我有一些张量,我取一个切片,它对应于我的掩码中的一个索引。该索引是一个tf.stridded\u切片
。我想用idx1
和idx2
为我的掩码编制索引,因为两者的计算结果都是int
但是idx1
和idx2
不是整数而是张量,因此obj=objective(index,weights)
调用会导致错误
如何使代码工作?您可以使用tf.SparseTensor
和tf.sparse\u tensor\u to\u dense
的组合来实现您想要的:
import numpy as np
import tensorflow as tf
indices = tf.placeholder(tf.int64, shape=[2])
weights = tf.Variable(tf.random_normal([5, 3], stddev=0.35))
def objective(indices, weights):
idx1 = indices[0]; idx2 = indices[1] #extract two indices
mask = np.zeros(weights.shape.as_list()[0]) #builds a mask for some tensor "weights"
mask_ones = tf.SparseTensor(tf.reshape(indices, [-1,1]), [1, 1], mask.shape) # Stores the 1s used in the mask
mask = mask + tf.sparse_tensor_to_dense(mask_ones) # Set the mask
obj = tf.reduce_sum(tf.multiply(weights[idx1], weights[idx2]))
return obj, mask
obj, mask = objective(indices, weights)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([weights, obj, mask], feed_dict={indices:[0, 4]}))
[array([[…],dtype=float32),0.0068909675,array([1,0,0,0,1.],dtype=int32)]
如果您在会话中计算张量,例如通过调用eval(),您可以使用其值对NumPy数组进行索引。您没有提供足够的代码来复制错误,因此很难更具体地描述。我忽略了更多的模板开销。特别是,这是在传递到优化器(例如tf.GradientDescentOptimizer)的目标
函数中实现的。我将更新代码,以便更清楚我的问题是什么,以及为什么您的解决方案不起作用。好的,我已经更新了代码,以显示它在哪里中断。具体地说,在图形定义过程中,它在我的objective
调用中中断。解决方案无法在会话中运行eval()
,因为错误发生在我定义图形时,而不是在执行图形时。这确实有效!我将其标记为正确,因为它对应于我所问的问题,尽管它实际上不是我最终使用的。