Python Tensorflow:使用子/超对角线上的输入创建对角线矩阵
我有以下代码:Python Tensorflow:使用子/超对角线上的输入创建对角线矩阵,python,tensorflow,Python,Tensorflow,我有以下代码: import tensorflow as tf N = 10 X = tf.ones([N,], dtype=tf.float64) D = tf.linalg.diag(X, k=1, num_rows=N+1, num_cols=N+1) print(D) 基于,我希望返回一个11x11张量,在第一个超对角线上插入X(即使没有可选的num_rows和num_cols参数)。然而,结果是 tf.Tensor( [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
import tensorflow as tf
N = 10
X = tf.ones([N,], dtype=tf.float64)
D = tf.linalg.diag(X, k=1, num_rows=N+1, num_cols=N+1)
print(D)
基于,我希望返回一个11x11张量,在第一个超对角线上插入X(即使没有可选的num_rows
和num_cols
参数)。然而,结果是
tf.Tensor(
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]], shape=(10, 10), dtype=float64)
有什么明显的地方我遗漏了吗?我可以告诉你为什么这不起作用,但我不知道修复方法是什么。可能会引发github问题 如果您查看
数组中的行_ops.py
。它执行兼容性检查tf.compat.forward\u compatible
,查看兼容性窗口是否已过期。返回False
(对于TF 2.0.0和2.1.0rc0)。由于这个原因,它执行
return gen\u array\u ops.matrix\u diag(对角线=对角线,name=名称)
您可以看到,调用时没有使用
k
,num\u行
,num\u列
。因此,如果tf.compat.forward\u compatible
检查失败,该方法目前完全不考虑这些参数。谢谢!现在,我只是转到源文件并对其进行了相应的编辑。我最终会在GitHub上提交一个问题