Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow tf.one_hot()是否支持SparSetSensor作为索引参数?_Tensorflow_Tflearn - Fatal编程技术网

Tensorflow tf.one_hot()是否支持SparSetSensor作为索引参数?

Tensorflow tf.one_hot()是否支持SparSetSensor作为索引参数?,tensorflow,tflearn,Tensorflow,Tflearn,我想问一下,函数是否支持SparSetSensor作为“索引”参数。我想做一个多标签分类(每个例子都有多个标签),这需要计算交叉熵损失 我试图直接将SparseTensor置于“Indexs”参数中,但它会引发以下错误: TypeError:无法将类型的对象转换为Tensor。内容:SparseTensor(索引=张量(“读取批处理特征/fifo队列输出:106”,形状=(?,2),数据类型=int64,设备=/job:worker),值=张量(“字符串到索引查找:0”,形状=(?,),数据类型

我想问一下,函数是否支持SparSetSensor作为“索引”参数。我想做一个多标签分类(每个例子都有多个标签),这需要计算交叉熵损失

我试图直接将SparseTensor置于“Indexs”参数中,但它会引发以下错误:

TypeError:无法将类型的对象转换为Tensor。内容:SparseTensor(索引=张量(“读取批处理特征/fifo队列输出:106”,形状=(?,2),数据类型=int64,设备=/job:worker),值=张量(“字符串到索引查找:0”,形状=(?,),数据类型=int64,设备=/job:worker),密集形状=张量(“读取批处理特征/fifo队列输出:108”,形状=(2,),数据类型=int64,设备=/job:worker))。将铸造元素考虑为支持类型。

对可能的原因有什么建议吗


谢谢。

one\u hot不支持将SparseTensor作为索引参数。不过,您可以将稀疏张量的索引/值张量作为索引参数传递,这可能会解决您的问题。

您可以从初始SparseTensor构建另一个形状为
(批大小,num类)
。例如,如果将类保留在单个字符串要素列中(用空格分隔),则可以使用以下选项:

将tensorflow导入为tf
所有类=[“1类”、“2类”、“3类”]
classes_列=[“class1 class3”、“class1 class2”、“class2”、“class3”]
table=tf.contrib.lookup.index\u table\u from\u tensor(
映射=tf.常数(所有类)
)
classes=tf.常数(classes\U列)
class=tf.string\u split(类)
idx=table.lookup(classes)#SparseTensor的形状(4,2),因为4行中的每一行最多有2个类
num_items=tf.cast(tf.shape(idx)[0],tf.int64)#批处理中的项目数
num#entries=tf.shape(idx.index)[0]#num非零项
y=tf.SparseTensor(
索引=tf.stack([idx.index[:,0],idx.values],axis=1),
value=tf.ones(shape=(num_条目,),dtype=tf.int32),
密集形状=(num\u项,len(所有类)),
)
y=tf.稀疏张量到稠密(y,验证指数=False)
使用tf.Session()作为sess:
tf.tables_initializer().run()
打印(sess.run(y))
#产出:
# [[1 0 1]
#  [1 1 0]
#  [0 1 0]
#  [0 0 1]]
这里的
idx
是一个稀疏传感器。其索引的第一列
idx.indexs[:,0]
包含批次的行号,其值
idx.values
包含相关类id的索引。我们将这两列结合起来创建新的
y.indexs

有关多标签分类的完整实现,请参见