为包含深度嵌套numpy数组的Python对象实现_eq__

为包含深度嵌套numpy数组的Python对象实现_eq__,python,numpy,oop,Python,Numpy,Oop,在对象属性的上下文中,numpy数组无法与==(使用np.array_equal的语义)进行比较,这是我遇到的问题 考虑以下示例: >>> import numpy as np >>> class A: ... def __init__(self, a): ... self.a = a ... def __eq__(self, other): ... return self.__dict__ == other._

在对象属性的上下文中,numpy数组无法与
==
(使用
np.array_equal
的语义)进行比较,这是我遇到的问题

考虑以下示例:

>>> import numpy as np
>>> class A:
...     def __init__(self, a):
...         self.a = a
...     def __eq__(self, other):
...         return self.__dict__ == other.__dict__
...
>>> x = A(a=[1, np.array([1, 2])])
>>> y = A(a=[1, np.array([1, 2])])
>>> x == y
Traceback (most recent call last):
  File "<ipython-input-33-9cfbd892cdaa>", line 1, in <module>
    x == y
  File "<ipython-input-30-790950997d4f>", line 5, in __eq__
    return self.__dict__ == other.__dict__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>将numpy作为np导入
>>>A类:
...     定义初始化(self,a):
...         self.a=a
...     定义(自身、其他):
...         返回self.\uu dict.\uuu==其他.\uu dict__
...
>>>x=A(A=[1,np.array([1,2]))
>>>y=A(A=[1,np.array([1,2]))
>>>x==y
回溯(最近一次呼叫最后一次):
文件“”,第1行,在
x==y
文件“”,第5行,在__
返回self.\uu dict.\uuu==其他.\uu dict__
ValueError:包含多个元素的数组的真值不明确。使用a.any()或a.all()
(忽略
\uuuu eq\uuuu
不是完美的,它至少应该检查
其他
的类型,但这是为了简洁起见)

我如何实现一个
\uuuuu eq\uuuu
函数来处理嵌套在对象属性深处的numpy数组(假设其他所有内容,如本例中的列表,与
=
相比都很好)?numpy数组可能出现在列表、元组或dict中任意深度的嵌套级别

我尝试了一个递归
eq
函数的“手动”实现,该函数将
=
应用于所有属性,并在遇到numpy数组时使用
np.array\u equal
,但这比预期的要复杂


是否有人有合适的功能或简单的解决方法?

如果可以选择更改对象
x
y
,则可以根据您的偏好覆盖
np.ndarray
\uu eq\uuuuu
方法

class eqarr(np.ndarray):
    def __eq__(self, other):
        return np.array_equal(self, other)

class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return self.__dict__ == other.__dict__

x = A(a=[1, eqarr([1, 2])])
y = A(a=[1, eqarr([1, 2])])
x == y
此结果为
True

如果这是不可能的,我现在能想到的唯一解决方案就是实际实现一个递归等式检查函数。我的尝试如下:

def eq(a, b):
    if not (hasattr(a, '__iter__') or type(a) == str):
        return a == b

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False


class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return eq(self.__dict__, other.__dict__)
def eq(a, b):
    try:
        return np.all(a == b)
    except ValueError:
        pass

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False
用你的例子和我提出的所有例子,它都奏效了。只要嵌套项目具有
\uuuu iter\uuuu
\uu len\uuuu
属性,该解决方案就应该适用

我希望我解释了所有可能的错误,但您可能需要稍微调整代码,使其绝对故障安全

如果你发现一个反例,请提供它作为评论。我相信代码可以相应地调整

eq
的性能可能不太好,但我不知道这是否是您最关心的问题

如果numpy数组在您的层次结构中非常少见(并且通常接近顶部),那么您可以首先尝试正常比较。这可能如下所示:

def eq(a, b):
    if not (hasattr(a, '__iter__') or type(a) == str):
        return a == b

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False


class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return eq(self.__dict__, other.__dict__)
def eq(a, b):
    try:
        return np.all(a == b)
    except ValueError:
        pass

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False

更改对象不是一个选项,不幸的是,您是否可以提供一个深度嵌套列表的示例(有或没有所需的新对象)?我可能有
a=(1[1,{'key]:np.array([1,2])}[1[np.array([1,2])],2)
a=[1[MyMutableSequence([1,np.array([1,2])),1]
Sure,手动实现递归的
eq
是可能的(我一开始就是这么做的),但是我担心,要得到一个非常通用的、不会出现奇怪边缘情况的bug的东西几乎是不可能的。你对嵌套结构中的对象了解多少?它必须有多普遍?当然,很容易构造使我的解决方案失败的类。但一般来说,你需要它来工作吗?(我知道我的解决方案不是很优雅,你可能以前也提出过类似或更好的解决方案-如果能得到关于你遇到的问题的更具体的信息,那就太好了。)
np.all\u close
似乎是比较numpy数组的一种可靠方法。我知道
np.all\u close
,但是
==
应该始终是完全相等的,这就是我在这里感兴趣的。此外,这并没有解决numpy数组没有实现
==
,因此无法在嵌套比较中正常工作的问题。