Python 能否使用tf.scatter\u update或tf.scatter\u nd\u update来更新张量的列切片?

Python 能否使用tf.scatter\u update或tf.scatter\u nd\u update来更新张量的列切片?,python,tensorflow,Python,Tensorflow,我想实现一个函数,它接受一个变量作为输入,对它的一些行或列进行变异,并将它们替换回原始变量中。我可以使用tf.gather和tf.scatter\u update为行切片实现它,但无法为列切片实现它,因为tf.scatter\u update显然只更新行切片,没有轴功能。我不是tensorflow方面的专家,因此我可能遗漏了一些东西。有人能帮忙吗 def matrix_reg(t, percent_t, beta): ''' Takes a variable tensor t

我想实现一个函数,它接受一个变量作为输入,对它的一些行或列进行变异,并将它们替换回原始变量中。我可以使用tf.gather和tf.scatter\u update为行切片实现它,但无法为列切片实现它,因为tf.scatter\u update显然只更新行切片,没有轴功能。我不是tensorflow方面的专家,因此我可能遗漏了一些东西。有人能帮忙吗

def matrix_reg(t, percent_t, beta):
    
    ''' Takes a variable tensor t as input and regularizes some of its rows.
    The number of rows to be regularized are specified by the percent_t. Returns the original tensor by updating its rows indexed by row_ind.
    
    Arguments:
        t -- input tensor
        percent_t -- percentage of the total rows
        beta -- the regularization factor
    Output:
        the regularized tensor
        '''
    row_ind = np.random.choice(int(t.shape[0]), int(percent_t*int(t.shape[0])), replace = False)
    t_ = tf.gather(t,row_ind)
    t_reg = (1+beta)*t_-beta*(tf.matmul(tf.matmul(t_,tf.transpose(t_)),t_))
    return tf.scatter_update(t, row_ind, t_reg)

下面是一个如何更新行或列的小演示。其思想是指定变量的行和列索引,以便更新中的每个元素都在这些索引中结束。这很容易做到

输出:

行已更新:
[[0. 0. 0.]
[4. 5. 6.]
[0. 0. 0.]
[1. 2. 3.]]
更新的栏目:
[[1. 0. 5.]
[2. 5. 6.]
[3. 0. 7.]
[4. 2. 8.]]

有关tf.变量,请参阅Tensorflow2文档

\uuuu获取项目\uuuuu
( 变量,切片(规格)

在给定变量的情况下创建切片辅助对象

这允许从当前内容的一部分创建子张量 一个变量。有关切片的详细示例,请参见tf.Tensor.getitem

此外,此功能还允许分配到切片范围。 这类似于Python中的
\uuuuuu setitem\uuuuu
功能。但是, 语法不同,因此用户可以捕获分配 用于分组或传递给sess.run()的操作。比如说,

下面是一个简单的工作示例:

将tensorflow导入为tf
将numpy作为np导入
var=tf.变量(np.rand.rand(3,3,3))
打印(var)
#将三(3x3)个矩阵的最后一列更新为随机整数值
#请注意,更新值需要具有相同的形状
#自TF2起不支持广播
var[:,:,2].assign(np.random.randint(10,size=(3,3)))
打印(var)

您可以发布行的工作代码吗?请查看更新的问题。
import tensorflow as tf

var = tf.get_variable('var', [4, 3], tf.float32, initializer=tf.zeros_initializer())
updates = tf.placeholder(tf.float32, [None, None])
indices = tf.placeholder(tf.int32, [None])
# Update rows
var_update_rows = tf.scatter_update(var, indices, updates)
# Update columns
col_indices_nd = tf.stack(tf.meshgrid(tf.range(tf.shape(var)[0]), indices, indexing='ij'), axis=-1)
var_update_cols = tf.scatter_nd_update(var, col_indices_nd, updates)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('Rows updated:')
    print(sess.run(var_update_rows, feed_dict={updates: [[1, 2, 3], [4, 5, 6]], indices: [3, 1]}))
    print('Columns updated:')
    print(sess.run(var_update_cols, feed_dict={updates: [[1, 5], [2, 6], [3, 7], [4, 8]], indices: [0, 2]}))