Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Tensorflow在使用tf.map_fn()迭代到张量时获取维度信息_Python_Tensorflow - Fatal编程技术网

Python Tensorflow在使用tf.map_fn()迭代到张量时获取维度信息

Python Tensorflow在使用tf.map_fn()迭代到张量时获取维度信息,python,tensorflow,Python,Tensorflow,假设我有一个形状[s1,s2,s3]的张量ts,我想用tf.map\u fn迭代它: tf.map_fn(lambda dim1: tf.map_fn(lambda dim2: do_sth(dim, idx1, idx2) ,dim) ,ts) 上面的idx1和idx2是dou sth()当前所在的ts维度0和维度1的索引。我怎样才能得到这些?我想把它当作我正在做的事情: for idx1 in range(s1): for idx2 in range

假设我有一个形状
[s1,s2,s3]
的张量
ts
,我想用
tf.map\u fn
迭代它:

tf.map_fn(lambda dim1:
    tf.map_fn(lambda dim2:
        do_sth(dim, idx1, idx2)
    ,dim)
,ts)
上面的
idx1
idx2
dou sth()
当前所在的
ts
维度0和维度1的索引。我怎样才能得到这些?我想把它当作我正在做的事情:

for idx1 in range(s1):
    for idx2 in range(s2):
        tensor = ts[idx1][idx2]
        do_sth(tensor, idx1, idx2)
我不能这样做的原因是大部分时间
s1、s2、s3
是未知的(即
ts
的形状
(?,?,t3)
或类似)


这可能吗?

我建议您在运行命令之前,在末尾添加一个额外的维度,并用索引填充它

该代码(已测试)将指数分别添加到X和Y指数乘以10和100的值中:

from __future__ import print_function
import tensorflow as tf
import numpy as np

r = 3

a = tf.reshape( tf.constant( range( r * r ) ), ( r, r ) )
x = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ None, : ], tf.int32 ), [ r, 1 ] )
y = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ :, None ], tf.int32 ), [ 1, r ] )

b = tf.stack( [ a, x, y ], axis = -1 )

c = tf.map_fn( lambda y: tf.map_fn( lambda x: 
        x[ 0 ] + 10 * x[ 1 ] + 100 * x[ 2 ]
    , y ), b )

with tf.Session() as sess:
    res = sess.run( [ c ] )
    for x in res:
        print()
        print( x )
产出:

[[0 11 22]
[103 114 125]
[206 217 228]]