Python tf.GradientTape()不';不能在切片输出上工作
下面是我试图运行的一段代码:Python tf.GradientTape()不';不能在切片输出上工作,python,tensorflow,gradienttape,Python,Tensorflow,Gradienttape,下面是我试图运行的一段代码: import tensorflow as tf a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32) b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32) with tf.GradientTape() as tape1, tf.GradientTape() as tape2: tape1.watch(a) tape2.watch(a)
import tensorflow as tf
a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
tape1.watch(a)
tape2.watch(a)
c = a * b
grad1 = tape1.gradient(c, a)
grad2 = tape2.gradient(c[:, 0], a)
print(grad1)
print(grad2)
这是输出:
tf.Tensor(
[[1. 2.]
[2. 3.]], shape=(2, 2), dtype=float32)
None
正如您所观察到的,tf.GradientTape()无法处理切片输出。有什么办法吗?是的,对张量所做的一切都需要在磁带上下文中进行。您可以这样相对轻松地修复它:
import tensorflow as tf
a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
tape1.watch(a)
tape2.watch(a)
c = a * b
c_sliced = c[:, 0]
grad1 = tape1.gradient(c, a)
grad2 = tape2.gradient(c_sliced, a)
print(grad1)
print(grad2)