为什么这个布尔值在这个贝叶斯分类器中?(Python问题?)

为什么这个布尔值在这个贝叶斯分类器中?(Python问题?),python,boolean,generator,bayesian,generative-adversarial-network,Python,Boolean,Generator,Bayesian,Generative Adversarial Network,我正在学习GANs(我是python的初学者),我在前面的练习中发现了我不理解的这部分代码。具体地说,我不明白为什么要用第9行的布尔值(Xk=X[Y==k]),原因如下所述 class BayesClassifier: def fit(self, X, Y): # assume classes are numbered 0...K-1 self.K = len(set(Y)) self.gaussians = [] self.p_y = np.zeros(s

我正在学习GANs(我是python的初学者),我在前面的练习中发现了我不理解的这部分代码。具体地说,我不明白为什么要用第9行的布尔值(Xk=X[Y==k]),原因如下所述

class BayesClassifier:
  def fit(self, X, Y):
    # assume classes are numbered 0...K-1
    self.K = len(set(Y))

    self.gaussians = []
    self.p_y = np.zeros(self.K)
    for k in range(self.K):
      Xk = X[Y == k]
      self.p_y[k] = len(Xk)
      mean = Xk.mean(axis=0)
      cov = np.cov(Xk.T)
      g = {'m': mean, 'c': cov}
      self.gaussians.append(g)
    # normalize p(y)
    self.p_y /= self.p_y.sum()
  • 该布尔值根据Y的真实性返回0或1== k、 由于这个原因,Xk总是X列表的第一个或第二个值。你不觉得这有什么用
  • 在第10行中,len(Xk)始终为1,为什么它使用该参数而不是单个1
  • 每次仅使用一个值计算下一行的平均值和协方差

  • 我觉得我没有理解一些非常基本的东西。

    你应该考虑到,
    X,Y,k
    是NumPy数组,而不是标量,一些运算符对它们重载。特别是,
    ==
    和基于布尔的索引<代码>=将是按元素比较,而不是整个数组比较

    看看它是如何工作的:

    In [9]: Y = np.array([0,1,2])                                                                                        
    In [10]: k = np.array([0,1,3])                                                                                       
    In [11]: Y==k                                                                                                        
    
    Out[11]: array([ True,  True, False])
    
    因此,
    =
    的结果是一个布尔数组

    In [12]: X=np.array([0,2,4])                                                                                         
    In [13]: X[Y==k]                                                                                                     
    
    Out[13]: array([0, 2])
    
    当条件为
    True


    因此
    len(Xk)
    将是
    X
    k
    之间匹配元素的数量
    X,Y,k
    您应该考虑到
    X,Y,k
    是NumPy数组,而不是标量,并且一些运算符对它们重载。特别是,
    ==
    和基于布尔的索引<代码>=将是按元素比较,而不是整个数组比较

    看看它是如何工作的:

    In [9]: Y = np.array([0,1,2])                                                                                        
    In [10]: k = np.array([0,1,3])                                                                                       
    In [11]: Y==k                                                                                                        
    
    Out[11]: array([ True,  True, False])
    
    因此,
    =
    的结果是一个布尔数组

    In [12]: X=np.array([0,2,4])                                                                                         
    In [13]: X[Y==k]                                                                                                     
    
    Out[13]: array([0, 2])
    
    当条件为
    True

    因此
    len(Xk)
    将是
    X
    k
    之间匹配元素的数量。谢谢,阿尔泰

    你说得对。我通过另一个渠道找到了另一个答案,这里是:

    这是一个Numpy数组-它是Numpy数组的一个特殊特性,称为 布尔索引,用于仅过滤数组中的值 过滤器返回True时:

    将numpy作为np导入

    数组([1,2,3,4,5])过滤器=a>3

    打印(过滤器)

    [假,假,假,真,真]

    打印(一个[过滤器])

    [4,5]

    谢谢你,阿泰姆

    你说得对。我通过另一个渠道找到了另一个答案,这里是:

    这是一个Numpy数组-它是Numpy数组的一个特殊特性,称为 布尔索引,用于仅过滤数组中的值 过滤器返回True时:

    将numpy作为np导入

    数组([1,2,3,4,5])过滤器=a>3

    打印(过滤器)

    [假,假,假,真,真]

    打印(一个[过滤器])

    [4,5]