Python 2.7 如何在TensorFlow中提供自定义渐变

Python 2.7 如何在TensorFlow中提供自定义渐变,python-2.7,tensorflow,Python 2.7,Tensorflow,我试图理解如何使用TensorFlow 1.7中提供的@tf.custom_gradient函数来提供向量相对于向量的自定义梯度。下面的代码是解决以下问题的最小工作示例,以获得dz/dx y=Ax z=| y | 2 如果我不使用@tf.custom_gradient,那么TensorFlow会给出预期的解决方案。我的问题是如何为y=Ax提供自定义渐变?我们知道,dy/dx=A^T如上述附件所示,该附件显示了与TensorFlow输出匹配的计算步骤 import tensorflow as t

我试图理解如何使用TensorFlow 1.7中提供的
@tf.custom_gradient
函数来提供向量相对于向量的自定义梯度。下面的代码是解决以下问题的最小工作示例,以获得
dz/dx

y=Ax
z=| y | 2

如果我不使用
@tf.custom_gradient
,那么TensorFlow会给出预期的解决方案。我的问题是如何为y=Ax提供自定义渐变?我们知道,
dy/dx=A^T
如上述附件所示,该附件显示了与TensorFlow输出匹配的计算步骤

import tensorflow as tf

#I want to write custom gradient for this function f1
def f1(A,x):
    y=tf.matmul(A,x,name='y')
    return y

#for y= Ax, the derivative is: dy/dx= transpose(A)
@tf.custom_gradient
def f2(A,x):
    y=f1(A,x)
    def grad(dzByDy): # dz/dy = 2y reaches here correctly.
        dzByDx=tf.matmul(A,dzByDy,transpose_a=True) 
        return dzByDx
    return y,grad


x= tf.constant([[1.],[0.]],name='x')
A= tf.constant([ [1., 2.], [3., 4.]],name='A')

y=f1(A,x) # This works as desired
#y=f2(A,x) #This line gives Error


z=tf.reduce_sum(y*y,name='z')

g=tf.gradients(ys=z,xs=x)

with tf.Session() as sess:
    print sess.run(g)

由于函数
f2()
有两个输入,因此必须提供一个梯度以返回到每个输入。您看到的错误是:

为op name生成的Num渐变2:“IdentityN”[…]与Num输入3不匹配

不过,不可否认的是,它相当神秘。假设您永远不想计算dy/dA,您可以返回None,dzByDx。下面的代码(已测试):

产出:

[数组([[20.], [28.]],dtype=32]


根据需要。

当然,我很乐意帮忙!:)
import tensorflow as tf

#I want to write custom gradient for this function f1
def f1(A,x):
    y=tf.matmul(A,x,name='y')
    return y

#for y= Ax, the derivative is: dy/dx= transpose(A)
@tf.custom_gradient
def f2(A,x):
    y=f1(A,x)
    def grad(dzByDy): # dz/dy = 2y reaches here correctly.
        dzByDx=tf.matmul(A,dzByDy,transpose_a=True) 
        return None, dzByDx
    return y,grad

x= tf.constant([[1.],[0.]],name='x')
A= tf.constant([ [1., 2.], [3., 4.]],name='A')

#y=f1(A,x) # This works as desired
y=f2(A,x) #This line gives Error

z=tf.reduce_sum(y*y,name='z')

g=tf.gradients(ys=z,xs=x)

with tf.Session() as sess:
    print sess.run( g )