Python Matplotlib:创建轮廓后,将轮廓绘制到轴

Python Matplotlib:创建轮廓后,将轮廓绘制到轴,python,matplotlib,contour,Python,Matplotlib,Contour,我在轴上添加了等高线图。现在我想在其他轴上绘制相同的轮廓,而无需重新创建轮廓(因为创建轮廓可能非常复杂)。我将函数创建的轮廓保存为CS。有没有办法在所有其他轴上重新绘制相同的轮廓 import numpy as np import matplotlib.mlab as mlab def draw_contour(): delta = 0.025 x = np.arange(-3.0, 3.0, delta) y = np.arange(-2.0, 2.0, delta)

我在轴上添加了等高线图。现在我想在其他轴上绘制相同的轮廓,而无需重新创建轮廓(因为创建轮廓可能非常复杂)。我将函数创建的轮廓保存为
CS
。有没有办法在所有其他轴上重新绘制相同的轮廓

import numpy as np
import matplotlib.mlab as mlab

def draw_contour():
    delta = 0.025
    x = np.arange(-3.0, 3.0, delta)
    y = np.arange(-2.0, 2.0, delta)
    X, Y = np.meshgrid(x, y)
    Z1 = mlab.bivariate_normal(X, Y, 1.0, 1.0, 0.0, 0.0)
    Z2 = mlab.bivariate_normal(X, Y, 1.5, 0.5, 1, 1)
    # difference of Gaussians
    Z = 10.0 * (Z2 - Z1)
    C = plt.contour(X, Y, Z, linewidth=2, colors='k')
    return C

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2)
plt.sca(ax1)
CS = draw_contour()
plt.clabel(CS, inline=1, fontsize=10)

plt.sca(ax2)
# redraw CS

plt.sca(ax3)
# redraw CS

plt.sca(ax3)
# redraw CS


如果可以做到这一点,那么如何对plt.pcolormesh()和plt.scatter()执行相同的操作呢。您可以将带有等高线图的
轴的内容复制到其他

import numpy as np
import matplotlib as mpl
import matplotlib.mlab as mlab

def draw_contour(ax):
    delta = 0.025
    x = np.arange(-3.0, 3.0, delta)
    y = np.arange(-2.0, 2.0, delta)
    X, Y = np.meshgrid(x, y)
    Z1 = mlab.bivariate_normal(X, Y, 1.0, 1.0, 0.0, 0.0)
    Z2 = mlab.bivariate_normal(X, Y, 1.5, 0.5, 1, 1)
    # difference of Gaussians
    Z = 10.0 * (Z2 - Z1)
    C = ax.contour(X, Y, Z, colors='k')
    return C

# dfix is a hack to fix dashing size in copied lines. May need to adjust
def copy_linecollection(x, axdst, dfix=1.5):
    ls = [(ls[0], (ls[1][0]/dfix, ls[1][1]/dfix)) if ls[0] is not None else ls for ls in x.get_linestyles()]

    axdst.add_collection(mpl.collections.LineCollection(
        [p.vertices for p in x.get_paths()],
        linewidths=x.get_linewidths(), 
        colors=x.get_colors(),
        linestyles=ls,
    ))

def copy_text(x, axdst):
    axdst.text(
        *x.get_position(), 
        s=x.get_text(),
        color=x.get_color(), 
        verticalalignment=x.get_verticalalignment(), 
        horizontalalignment=x.get_horizontalalignment(), 
        fontproperties=x.get_fontproperties(), 
        rotation=x.get_rotation(),
        clip_box=axdst.get_position(),
        clip_on=True
    )

def copy_ax(axsrc, axdst):
    for c in axsrc.get_children():
        if isinstance(c, mpl.collections.LineCollection):
            copy_linecollection(c, axdst)

        elif isinstance(c, mpl.text.Text) and c.get_text():
            copy_text(c, axdst)

subplots_kw = {
    'sharex': True, 
    'sharey': True, 
    'figsize': (10,10),
    'gridspec_kw': {
        'hspace': 0,
        'wspace': 0
    }
}
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, **subplots_kw)

CS = draw_contour(ax1)
ax1.clabel(CS, inline=1, fontsize=10)

for ax in (ax2,ax3,ax4):
    copy_ax(ax1, ax)
输出:

这种方法不适用于其他类型的绘图(甚至不一定适用于所有
等高线
绘图),但可以推广使用。例如,您需要运行
ax.scatter(…)
,然后查看
ax.get\u children()
返回的列表内容。然后,您需要为子列表中存在的所有不同类型编写
copy_x
函数,并更新
copy_ax
以合并这些新的复制函数

编辑 提出了一个更简单、更通用的
copy\u ax版本

from copy import copy as shallowcopy

def copy_artist(x, axdst):
    xc = shallowcopy(x)
    xc.axes = None
    xc.figure = None
    xc.set_transform(axdst.transData)
    axdst.add_artist(xc)

def copy_ax(axsrc, *axdsts):
    for axdst in axdsts:
        # don't need the last 10 items (frame, spines, etc) in get_children
        for c in axsrc.get_children()[:-10]:
            copy_artist(c, axdst)
此版本的
copy_ax
可以很好地处理许多不同打印功能的结果(
plt.plot
)。。。但遗憾的是,不是
plt.scatter
plt.pcolormesh
。事实上,我现在不认为这个方法或我以前的方法可以处理这两种方法中的任何一种。作为
散点
的替代方法,您至少可以使用
绘图
制作散点图:

def draw_scatter(ax):
    x = np.random.uniform(-2,2,size=100)
    y = np.random.uniform(-2,2,size=100)
    return ax.plot(x,y,ls='none',marker='.')

subplots_kw = {
    'sharex': True, 
    'sharey': True, 
    'figsize': (10,10),
    'gridspec_kw': {
        'hspace': 0,
        'wspace': 0
    }
}
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, **subplots_kw)

draw_scatter(ax1)
copy_ax(ax1, ax2, ax3, ax4)
输出: