Python Tensorflow:从任意长度的复张量中提取连续面片

Python Tensorflow:从任意长度的复张量中提取连续面片,python,tensorflow,tensorflow2.0,tensorflow-datasets,Python,Tensorflow,Tensorflow2.0,Tensorflow Datasets,我试图找出如何从长度可变的复数张量中提取序列面片。提取是作为tf.data管道的一部分执行的 如果张量不复杂,我会使用tf.image.extract\u image\u patches,如下所示 但是,该函数不适用于复张量。我试过下面的方法,但失败了,因为张量的长度未知 def extract_sequential_补丁(图像): 图像长度=tf.shape(图像)[0] num\u patches=图像长度//(128//4) 补丁=[] 对于范围内的i(num_补丁): 开始=i*128

我试图找出如何从长度可变的复数张量中提取序列面片。提取是作为
tf.data
管道的一部分执行的

如果张量不复杂,我会使用
tf.image.extract\u image\u patches
,如下所示

但是,该函数不适用于复张量。我试过下面的方法,但失败了,因为张量的长度未知

def extract_sequential_补丁(图像):
图像长度=tf.shape(图像)[0]
num\u patches=图像长度//(128//4)
补丁=[]
对于范围内的i(num_补丁):
开始=i*128
结束=开始+128
append(图像[开始:结束,…])
返回tf.stack(补丁)
但是我得到了一个错误:

InaccessibleTensorError: The tensor 'Tensor("strided_slice:0", shape=(None, 512, 2), dtype=complex64)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_2100, id=140313967335120)

我尝试过用
@tf.function

进行自由修饰,我认为您需要调整索引的计算,以确保它们不会超出范围,但撇开细节不谈,您的代码几乎是
tf.function
所期望的,除了使用Python列表;你需要改用Tensorary

类似的方法应该会奏效(指数计算可能不完全正确):

您可以找到更多关于Python列表目前无法在中工作的详细信息


也就是说,如果extract_image_补丁的内核得到优化,那么使用real/imag技巧可能会更快。我建议对这两种方法进行测试。

你能提供一个可复制的例子(作为函数输入的复数张量)吗?你不能用/来分割数字的实部和虚部,对每一个数字应用实值的解,然后用@jdehesa将它们连接起来吗?我想到了。我担心所有的拆分和重组都会增加很多开销,但也许这就是解决方法。再看看你的代码,你想要一个补丁的“滑动窗口”,就像链接的答案一样,还是像在代码中那样,只是将数据“拆分”成128个元素的一部分?因为我认为你的代码做的和重塑一样?我想要一个滑动窗口。为了简单起见,我刚刚输入了128。我现在就改。
@tf.function
def extract_sequential_patches(image, size, stride):
    image_length = tf.shape(image)[0]
    num_patches = (image_length - size) // stride + 1
    patches = tf.TensorArray(image.dtype, size=num_patches)
    for i in range(num_patches):
        start = i * stride
        end = start + size
        patches = patches.write(i, image[start:end, ...])
    return patches.stack()