Python Matplotlib hist()对二维numpy数组输入有什么作用?

Python Matplotlib hist()对二维numpy数组输入有什么作用?,python,numpy,matplotlib,Python,Numpy,Matplotlib,假设我有一个二维Numpy数组。它应该表示PyTorch线性层的学习权重。下面我将创建一个充满高斯随机数的示例数组 将numpy导入为np 作为pd进口熊猫 将matplotlib.pyplot作为plt导入 数据=np.随机.正常(大小=(4768)) 打印(数据.形状)#(4768) 然后,我尝试使用Matplotlib函数创建值的直方图。我使用的是Jupyter笔记本(Google Colab)。当我调用下面这样的函数时(通过传递原始的2-D数组),需要很长时间才能完成,并且视觉输出非常

假设我有一个二维Numpy数组。它应该表示PyTorch线性层的学习权重。下面我将创建一个充满高斯随机数的示例数组

将numpy导入为np
作为pd进口熊猫
将matplotlib.pyplot作为plt导入
数据=np.随机.正常(大小=(4768))
打印(数据.形状)#(4768)
然后,我尝试使用Matplotlib函数创建值的直方图。我使用的是Jupyter笔记本(Google Colab)。当我调用下面这样的函数时(通过传递原始的2-D数组),需要很长时间才能完成,并且视觉输出非常奇怪

%%次
_=plt.hist(数据,存储箱=100)
#结果:
#CPU时间:用户48.5秒,系统:737毫秒,总计:49.2秒
#壁时间:49.2秒

另一方面,当我使用
restrape()
将二维数组重塑为一维数组时,
hist()
函数几乎立即完成,并且可视化具有我所期望的形状,即高斯曲线

data=data.restrape(-1)
打印(数据.形状)#(3072,)
%%次
_=plt.hist(数据,存储箱=100)
#结果:
#CPU时间:用户70.7毫秒,系统2.01毫秒,总计72.7毫秒
#壁时间:70.9毫秒

那么我第一次尝试通过二维阵列时到底发生了什么?为什么要花这么长时间?可视化图形代表什么


谢谢您的帮助。

我很惊讶matplotlib与numpy不同,它没有首先展平输入数组。但是,声明输入
x
可以是
(n,)数组或(n,)数组序列。这就是matplotlib如何将输入768个形状数组(4,)解释为一个图形中的768个直方图,这些数组在输出中显示为768个。你看不到太多,因为这些条相当薄,有76800条要显示——增加图形大小和分辨率可能会改善这一点。
data=np.random.normal(size=(768,4))
的相反情况揭示了这一点,因为现在只需显示400条:

但我们也可以看看matplotlib返回的内容:

hist_count, hist_bins, hist_bars = plt.hist(data, bins=100)
print(hist_count.shape)
>>>(768, 100)
print(hist_bars)
>>><a list of 768 BarContainer objects>
hist_count,hist_bin,hist_bar=plt.hist(数据,bin=100)
打印(历史计数形状)
>>>(768, 100)
打印(历史记录栏)

>>>

您得到768个直方图,其中4个值分布在100个箱子中。感谢您的解释。文档中没有提供太多关于
plt.hist()
如何处理二维数组的信息。我不知道matplotlib的内部工作原理,但这可能(如果不是有意的话)是一个副作用,因为它们使用类似的例程来绘制
hist()
hist2D()
?只是一个猜测,如果有更具洞察力的人发布了答案,请毫不犹豫地接受更好的答案。
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(123)
data = np.random.normal(size=(4, 5))
print(data.shape) #(4, 5)

hist_count, hist_bins, hist_bars = plt.hist(data, bins=6)

print(hist_count.shape) #(5, 6)
print(hist_count)
#[[0. 1. 2. 0. 0. 1.]
# [1. 0. 0. 1. 1. 1.]
# [0. 0. 1. 1. 0. 2.]
# [0. 1. 1. 0. 2. 0.]
# [0. 0. 3. 1. 0. 0.]]
print(hist_bins) #[-2.42667924 -1.65457769 -0.88247613 -0.11037458  0.66172697  1.43382853  2.20593008]
print(hist_bars) #<a list of 5 BarContainer objects>
plt.show()