Python: how to disable size-changing broadcasting when assigning to a NumPy array slice?


Let NoBroadcastArray be a subclass of np.ndarray. If x is an instance of NoBroadcastArray and arr is an np.ndarray, then I want x[slice] = arr to succeed if and only if arr.size matches the size of the slice:

x[1] = 1  # should succeed
x[1:3] = 1  # should fail - scalar doesn't have size 2
x[1:3] = [1,2]  # should succeed
x[1:3] = np.array([[1,2]])  # should succeed - shapes don't match but sizes do.
x[1:3, 3:5] = np.array([1,2])  # should fail - 2-element array doesn't have the same size as the 2x2 slice
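In other words, the assignment should succeed only if the RHS does not have to change size to fit the LHS slice. I don't mind if it changes shape, e.g. if it goes from an array of shape 1x2 to an array of shape 2x1x1.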
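For context, here is a minimal sketch of the default behaviour the question wants to rule out: a plain np.ndarray silently broadcasts a scalar or a smaller array across the whole slice. The array x and its shape (4, 6) are made up for illustration.

```python
import numpy as np

# Illustrative array; the shape (4, 6) is an arbitrary choice.
x = np.zeros((4, 6))

# A scalar (size 1) is broadcast over a slice of size 2 * 6 = 12.
x[1:3] = 1
filled_by_scalar = int((x == 1).sum())

# A 2-element array is broadcast over a 2x2 slice (size 4).
x[1:3, 3:5] = np.array([7, 8])
block = x[1:3, 3:5].copy()

# Only a shape that cannot be broadcast at all raises an error.
try:
    x[1:3, 3:5] = np.array([1, 2, 3])
    raised = False
except ValueError:
    raised = True
```

Both of the first two assignments succeed without complaint, which is exactly what a non-broadcasting subclass would have to reject.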


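How can I accomplish this? The path I am trying now is to override __setitem__ in NoBroadcastArray so that the size of the slice is checked against the size of the item being set. This is proving to be tricky, so I am wondering whether anyone has a better idea, perhaps using __array_wrap__ or __array_finalize__.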

This is the implementation I came up with:

import numpy as np

class NoBroadcastArray(np.ndarray):

    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __setitem__(self, args, value):
        value = np.asarray(value, dtype=self.dtype)
        expected_size = self._compute_expected_size(args)
        if expected_size != value.size:
            raise ValueError(("assigned value size {} does not match expected size {} "
                              "in non-broadcasting assignment".format(value.size, expected_size)))
        return super(NoBroadcastArray, self).__setitem__(args, value)

    def _compute_expected_size(self, args):
        if not isinstance(args, tuple):
            args = (args,)
        # Iterate through indexing arguments
        arr_dim = 0
        ellipsis_dim = len(args)
        i_arg = 0
        size = 1
        adv_idx_shapes = []
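        for i_arg, arg in enumerate(args):
            if isinstance(arg, slice):
                size *= self._compute_slice_size(arg, arr_dim)
                arr_dim += 1
            elif arg is Ellipsis:
                ellipsis_dim = arr_dim
                break
            elif arg is np.newaxis:
                pass
            else:
                adv_idx_shapes.append(np.shape(arg))
                arr_dim += 1
        else:
            # No ellipsis: the implicit trailing dimensions start right after
            # the last indexed one (np.newaxis does not consume a dimension)
            ellipsis_dim = arr_dim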
        # Go backwards from end after ellipsis if necessary
        arr_dim = -1
        for arg in args[:i_arg:-1]:
            if isinstance(arg, slice):
                size *= self._compute_slice_size(arg, arr_dim)
                arr_dim -= 1
            elif arg is Ellipsis:
                raise IndexError("an index can only have a single ellipsis ('...')")
            elif arg is np.newaxis:
                pass
            else:
                adv_idx_shapes.append(np.shape(arg))
                arr_dim -= 1
        # Include dimensions under ellipsis
        ellipsis_end_dim = arr_dim + self.ndim + 1
        if ellipsis_dim > ellipsis_end_dim:
            raise IndexError("too many indices for array")
        for i_dim in range(ellipsis_dim, ellipsis_end_dim):
            size *= self.shape[i_dim]
        size *= NoBroadcastArray._advanced_index_size(adv_idx_shapes)
        return size

    def _compute_slice_size(self, slc, axis):
        if axis >= self.ndim or axis < -self.ndim:
            raise IndexError("too many indices for array")
        # slice.indices() resolves negative and out-of-range bounds the same
        # way sequence slicing does (and raises ValueError for a zero step),
        # so the length of the equivalent range is the number of elements
        # selected along this axis.
        return len(range(*slc.indices(self.shape[axis])))

    @staticmethod
    def _advanced_index_size(shapes):
        size = 1
        if not shapes:
            return size
        dims = max(len(s) for s in shapes)
        for dim_sizes in zip(*(s[::-1] + (1,) * (dims - len(s)) for s in shapes)):
            d = 1
            for dim_size in dim_sizes:
                if dim_size != 1:
                    if d != 1 and dim_size != d:
                        raise IndexError("shape mismatch: indexing arrays could not be "
                                         "broadcast together with shapes " + " ".join(map(str, shapes)))
                    d = dim_size
            size *= d
        return size

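This only checks that the size of the given value matches the index; it does not reshape the value, so assignment otherwise still works as usual in NumPy (i.e. a value with extra outer dimensions is still accepted).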
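As a quick check of the slice-size bookkeeping, the number of elements a slice selects along an axis of length n equals len(range(*s.indices(n))). The sketch below compares that count with what NumPy actually selects; the helper name slice_size and the test cases are made up for illustration and are not part of the answer's class.

```python
import numpy as np

def slice_size(s, n):
    """Number of elements slice `s` selects from an axis of length `n`."""
    # slice.indices clamps start/stop to the axis and resolves negative values.
    return len(range(*s.indices(n)))

axis_length = 5  # arbitrary length chosen for the demonstration
data = np.arange(axis_length)
cases = [slice(1, 3), slice(None, None, -1), slice(7, 10),
         slice(-2, None), slice(4, None, -2), slice(None, -9, -1)]
sizes = [slice_size(s, axis_length) for s in cases]

# Cross-check every case against the size NumPy reports for the same slice.
matches = all(slice_size(s, axis_length) == data[s].size for s in cases)
```

Out-of-range and negative-step slices are the cases most easily miscounted by hand-written clamping, which is why they are included here.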

Here is a slightly shorter solution:

import numpy as np

class FixedSizeSetitemArray(np.ndarray):
    def __setitem__(self, index, value):
        value = np.asarray(value)
        current = self[index]
        if value.shape == current.shape:
            # Shapes already agree: plain assignment, no broadcasting involved
            super().__setitem__(index, value)
        elif value.size == current.size:
            # Same number of elements: reshape to fit the target region
            super().__setitem__(index, value.reshape(current.shape))
        else:
            old, new, cls = current.size, value.size, self.__class__.__name__
            raise ValueError(f"{cls} will not broadcast in __setitem__ "
                             f"(expected size {old}, got size {new})")

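While this meets the exact requirements given, it involves arbitrarily reshaping arrays to fit the given region, which may not actually be desirable. For example, it will happily reshape an array of shape (2,2,2) into (8,) and vice versa. To remove that behaviour, just take out the elif clause.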
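A sketch of that stricter variant, with the elif branch removed so that only values whose shape already matches the target are accepted. The class name StrictShapeSetitemArray and the wording of the error message are made up here, and the sketch assumes the first branch is meant to test for equal shapes.

```python
import numpy as np

class StrictShapeSetitemArray(np.ndarray):
    """Hypothetical variant: assignment requires an exact shape match."""

    def __setitem__(self, index, value):
        value = np.asarray(value)
        current = self[index]
        if value.shape == current.shape:
            super().__setitem__(index, value)
        else:
            cls = self.__class__.__name__
            raise ValueError(f"{cls} will not broadcast or reshape in __setitem__ "
                             f"(expected shape {current.shape}, got {value.shape})")

x = np.zeros((2, 2, 2)).view(StrictShapeSetitemArray)
x[0] = [[1, 2], [3, 4]]          # shapes match: accepted

try:
    x[...] = np.arange(8)        # same size, different shape: now rejected
    rejected = False
except ValueError:
    rejected = True
```

The trade-off is that harmless cases such as assigning a (1, 2) array to a length-2 slice are rejected as well.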

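If you only want to drop extraneous dimensions, you can use this branch instead:

elif value.squeeze().shape == current.shape:
    super().__setitem__(index, value.squeeze())

Some other variations on squeeze would allow more extensive removal of extra dimensions, but if you run into that, fixing the indexing you are using is probably a better idea.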
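To illustrate the squeeze-based branch in context, here is a minimal sketch of a class that accepts values differing from the target only by length-1 dimensions, while still rejecting broadcasting and arbitrary reshapes. SqueezeSetitemArray is a hypothetical name, and the example array is arbitrary.

```python
import numpy as np

class SqueezeSetitemArray(np.ndarray):
    """Hypothetical variant: only length-1 dimensions of the value may be dropped."""

    def __setitem__(self, index, value):
        value = np.asarray(value)
        current = self[index]
        if value.shape == current.shape:
            super().__setitem__(index, value)
        elif value.squeeze().shape == current.shape:
            super().__setitem__(index, value.squeeze())
        else:
            cls = self.__class__.__name__
            raise ValueError(f"{cls} will not broadcast in __setitem__ "
                             f"(expected shape {current.shape}, got {value.shape})")

y = np.zeros(4).view(SqueezeSetitemArray)
y[1:3] = np.array([[1, 2]])      # shape (1, 2) squeezes to (2,): accepted

try:
    y[1:3] = 1                   # scalar would need broadcasting: rejected
    scalar_rejected = False
except ValueError:
    scalar_rejected = True
```

Note that this only squeezes the value, so a target region that itself contains length-1 dimensions would need one of the other variations mentioned above.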

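It looks like this may be a good deal more complicated than I originally thought. This solution seems to work, so I will mark it as accepted, but I am curious about other approaches that might be more concise and simple. Thanks for your time and effort!

No problem. If you would still like to wait for other possible answers, feel free not to mark it as accepted; I would find that interesting too. If you want to attract more attention (two days after posting), you could consider offering a bounty on the question.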