Python Tensorflow:自定义操作需要定义哪些渐变?

Python Tensorflow:自定义操作需要定义哪些渐变?,python,tensorflow,keras,gradient,Python,Tensorflow,Keras,Gradient,虽然有很多参考资料显示了如何注册渐变,但我仍然不太清楚到底需要定义什么样的渐变 一些类似的主题: 好的,我的问题来了: 我有一个正向函数y=f(a,B),其中每个函数的大小为: y: (batch_size, m, n) A: (batch_size, a, a) B: (batch_size, b, b) 假设我可以写下y的每个元素对A和B的每个元素的数学偏导数。dy/dA,dy/dB。我的问题是,我应该在梯度函数中返回什么 @ops.RegisterGradient("f") de

虽然有很多参考资料显示了如何注册渐变,但我仍然不太清楚到底需要定义什么样的渐变

一些类似的主题:


好的,我的问题来了:

我有一个正向函数
y=f(a,B)
,其中每个函数的大小为:

y: (batch_size, m, n)
A: (batch_size, a, a)
B: (batch_size, b, b)

假设我可以写下y的每个元素对A和B的每个元素的数学偏导数。
dy/dA,dy/dB
。我的问题是,我应该在梯度函数中返回什么

@ops.RegisterGradient("f")
def f_grad(op, grad):
    ...
    return ???, ???
表示梯度函数的结果必须是表示每个输入的梯度的张量对象列表

y
是标量,
A
B
是矩阵时,很容易理解要定义的梯度。但是当
y
是矩阵,
A
B
也是矩阵时,梯度应该是什么?

计算每个输出张量之和相对于输入张量中每个值的梯度。渐变操作接收要计算其渐变的op,
op
,以及在此点累积的渐变,
grad
。在您的示例中,
grad
将是一个与
y
形状相同的张量,每个值将是
y
中相应值的梯度-也就是说,如果
grad[0,0]==2
,这意味着将
y[0,0]
增加1将使输出张量之和增加2(我知道,你可能已经很清楚了)。现在你必须为
A
B
计算相同的东西。假设你计算出将
A[2,3]
增加1将增加
y[0,0]
增加3,并且对
y
中的任何其他值没有影响。这意味着输出值的总和将增加3×2=6,因此
A[2,3]
的梯度将为6

例如,让我们以矩阵乘法的梯度(op
MatMul
)为例,您可以在以下内容中找到:

我们将重点讨论
transpose_a
transpose_b
都是
False
,因此我们在第一个分支中,
如果不是t_a,也不是t_b:
(也忽略
conj
,它是用于复数值的)“a”和“b”是这里的操作数,如前所述,
grad
具有输出总和相对于乘法结果中每个值的梯度。那么,如果我增加
a[0,0],情况会发生什么变化
增加1?基本上,乘积矩阵第一行中的每个元素都会增加
b
第一行中的值。因此
a[0,0]的梯度
b
的第一行和
grad
的第一行的点积-也就是说,我将增加多少输出值乘以每个值的累积梯度。如果你想一想,行
grad\u a=gen\u math\u ops.mat\u mul(grad,b,transpose\u b=True)
正是这样做的。
grad\u a[0,0]
将是
grad
第一行和
b
第一行的点积(因为我们在这里转置
b
),一般来说,
grad\u a[i,j]
将是
i
-grad的第i行和
j
-b的第i行的点积。您也可以遵循类似的推理


编辑:

作为示例,请参见和注册的渐变如何相互关联:

import tensorflow as tf
# Import gradient registry to lookup gradient functions
from tensorflow.python.framework.ops import _gradient_registry

# Gradient function for matrix multiplication
matmul_grad = _gradient_registry.lookup('MatMul')
# A matrix multiplication
a = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
b = tf.constant([[6, 7, 8], [9, 10, 11]], dtype=tf.float32)
c = tf.matmul(a, b)
# Gradient of sum(c) wrt each element of a
grad_c_a_1, = tf.gradients(c, a)
# The same is obtained by backpropagating an all-ones matrix
grad_c_a_2, _ = matmul_grad(c.op, tf.ones_like(c))
# Multiply each element of c by itself, but stopping the gradients
# This should scale the gradients by the values of c
cc = c * tf.stop_gradient(c)
# Regular gradients computation
grad_cc_a_1, = tf.gradients(cc, a)
# Gradients function called with c as backpropagated gradients
grad_cc_a_2, _ = matmul_grad(c.op, c)
with tf.Session() as sess:
    print('a:')
    print(sess.run(a))
    print('b:')
    print(sess.run(b))
    print('c = a * b:')
    print(sess.run(c))
    print('tf.gradients(c, a)[0]:')
    print(sess.run(grad_c_a_1))
    print('matmul_grad(c.op, tf.ones_like(c))[0]:')
    print(sess.run(grad_c_a_2))
    print('tf.gradients(c * tf.stop_gradient(c), a)[0]:')
    print(sess.run(grad_cc_a_1))
    print('matmul_grad(c.op, c)[0]:')
    print(sess.run(grad_cc_a_2))
输出:

a:
[[1. 2.]
 [3. 4.]]
b:
[[ 6.  7.  8.]
 [ 9. 10. 11.]]
c = a * b:
[[24. 27. 30.]
 [54. 61. 68.]]
tf.gradients(c, a)[0]:
[[21. 30.]
 [21. 30.]]
matmul_grad(c.op, tf.ones_like(c))[0]:
[[21. 30.]
 [21. 30.]]
tf.gradients(c * tf.stop_gradient(c), a)[0]:
[[ 573.  816.]
 [1295. 1844.]]
matmul_grad(c.op, c)[0]:
[[ 573.  816.]
 [1295. 1844.]]

谢谢!这是否意味着在自定义梯度函数中,我需要返回与
tf.gradients
应该给出的结果相同的结果,其中每个元素都是dy/dx的偏导数之和?@NathanExplosion是的,听起来不错。我添加了一个片段(我希望)演示和梯度函数如何相互关联。我尝试了
tf.gradients(c[0,0],a)
,它将返回
dc[0,0]/da
。但是如果我们将返回的梯度定义为部分梯度的总和,它怎么能导出单个梯度呢?@NathanExplosion在这种情况下,流程是这样的。您有一个切片操作,它给您一个标量,因此您使用标量1开始梯度计算,即
dc[0,0]/dc[0,0]
。然后你计算
dc[0,0]/dc
,这是一个形状像
c
的矩阵,梯度为
c[0,0]
wrt每个元素-因此它是一个矩阵,
grad\u c
,除了第一行第一个值中的1之外,所有0。然后你可以得到
dc[0,0]/da
,即
(dc[0,0]/dc)*(dc/da)
。我们看到它最终是
grad_c*b.T
,因此得到一个大小为
a
的矩阵,其中第一行是
b
的第一列,所有其他行都是0。
a:
[[1. 2.]
 [3. 4.]]
b:
[[ 6.  7.  8.]
 [ 9. 10. 11.]]
c = a * b:
[[24. 27. 30.]
 [54. 61. 68.]]
tf.gradients(c, a)[0]:
[[21. 30.]
 [21. 30.]]
matmul_grad(c.op, tf.ones_like(c))[0]:
[[21. 30.]
 [21. 30.]]
tf.gradients(c * tf.stop_gradient(c), a)[0]:
[[ 573.  816.]
 [1295. 1844.]]
matmul_grad(c.op, c)[0]:
[[ 573.  816.]
 [1295. 1844.]]