Python: Selecting a region of data in a Jupyter notebook with Bokeh


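I am trying to use Bokeh to interactively select a region of data in a Jupyter notebook. Once the data is selected, it will be manipulated further with Python in subsequent notebook cells.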

The code below produces a plot in a Jupyter notebook. With the LassoSelectTool or another selection tool, the user can select a region of the data.

import numpy as np
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
# Direct output to this notebook
output_notebook()

n = 100
x = np.random.random(size=n) * 100
y = np.random.random(size=n) * 100
source = ColumnDataSource(data=dict(x=x, y=y))

figkwds = dict(plot_width=400, plot_height=300, webgl=True,
           tools="pan,lasso_select,box_select,help", 
           active_drag="lasso_select")

p1 = figure(**figkwds)
p1.scatter('x', 'y', source=source, alpha=0.8)

show(p1)

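How can I access the selected data in subsequent Jupyter cells? The documentation suggests ways to interact with selections, but I have only been able to use them to update other Bokeh plots. I can't figure out how to get the selected data out of the plot for more rigorous manipulation.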

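The answer appears to be to use kernel.execute to get from the Bokeh JavaScript back to the Jupyter kernel. I developed the following code based on an existing example: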

This code is currently available as a Jupyter notebook here:

import numpy as np
from bokeh.plotting import gridplot, figure, show
from bokeh.io import output_notebook, push_notebook
from bokeh.models import ColumnDataSource, CustomJS
# Direct output to this notebook
output_notebook()

# Create some data
n = 100
source = ColumnDataSource(data=dict(
    x=np.random.random(size=n) * 100, 
    y=np.random.random(size=n) * 100)
)
model = ColumnDataSource(data=dict(
    x=[],
    y_obs=[],
    y_pred=[],
))

# Create a callback with a kernel.execute to return to Jupyter
source.callback = CustomJS(code="""
        // Define a callback to capture errors on the Python side
        function callback(msg){
            console.log("Python callback returned unexpected message:", msg)
        }
        var callbacks = {iopub: {output: callback}};

        // Select the data
        var inds = cb_obj.selected['1d'].indices;
        var d1 = cb_obj.data;
        var x = []
        var y = []
        for (var i = 0; i < inds.length; i++) {
            x.push(d1['x'][inds[i]])
            y.push(d1['y'][inds[i]])
        }

        // Generate a command to execute in Python              
        var data = {
            'x': x,
            'y': y,
        }        
        var data_str = JSON.stringify(data)
        var cmd = "saved_selected(" + data_str + ")"

        // Execute the command on the Python kernel
        var kernel = IPython.notebook.kernel;
        kernel.execute(cmd, callbacks, {silent : false});
""")


selected = dict()
def saved_selected(values):
    x = np.array(values['x'])
    y_obs = np.array(values['y'])

    # Sort by increasing x
    sorted_indices = x.argsort()
    x = x[sorted_indices]
    y_obs = y_obs[sorted_indices]

    if len(x) > 2:
        # Do a simple linear model
        A = np.vstack([x, np.ones(len(x))]).T
        # rcond=None opts in to NumPy's current default and avoids a FutureWarning
        m, c = np.linalg.lstsq(A, y_obs, rcond=None)[0]
        y_pred = m * x + c

        data = {'x': x,  'y_obs': y_obs, 'y_pred': y_pred}
        model.data.update(data)
        # Update the selected dict for further manipulation
        selected.update(data)
        # Update the drawing
        push_notebook(handle=handle)

figkwds = dict(plot_width=500, plot_height=300, # webgl=True,
               x_axis_label='X', y_axis_label='Y',
               tools="pan,lasso_select,box_select,reset,help")

p1 = figure(active_drag="lasso_select", **figkwds)
p1.scatter('x', 'y', source=source, alpha=0.8)

p2 = figure(**figkwds, 
            x_axis_type='log', x_range=[1, 100],
            y_axis_type='log', y_range=[1, 100])
p2.scatter('x', 'y', source=source, alpha=0.8)

p3 = figure(plot_width=500, plot_height=300, # webgl=True,
            x_axis_label='X', y_axis_label='Y',
            tools="pan,reset,help")
p3.scatter('x', 'y', source=source, alpha=0.6)
p3.scatter('x', 'y_obs', source=model, alpha=0.8, color='red')
p3.line('x', 'y_pred', source=model)

layout = gridplot([[p1], [p2], [p3]])

handle = show(layout, notebook_handle=True)
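
The round trip above works because the JavaScript side sends a plain string of Python source, and JSON for lists of numbers happens to be a valid Python literal as well. Below is a minimal sketch of that idea that runs without a browser or Bokeh: json.dumps stands in for JSON.stringify, exec stands in for kernel.execute, and saved_selected is a simplified, hypothetical version of the function above (no plot updates). The sample payload is made up for illustration.

```python
import json

import numpy as np

selected = {}


def saved_selected(values):
    """Simplified stand-in for the notebook function: sort by x, then fit a line."""
    x = np.array(values['x'], dtype=float)
    y_obs = np.array(values['y'], dtype=float)

    # Sort by increasing x, as in the original function
    order = x.argsort()
    x, y_obs = x[order], y_obs[order]

    if len(x) > 2:
        # Least-squares fit of y = m * x + c
        A = np.vstack([x, np.ones(len(x))]).T
        m, c = np.linalg.lstsq(A, y_obs, rcond=None)[0]
        selected.update({'x': x, 'y_obs': y_obs, 'y_pred': m * x + c})


# Emulate the command string the CustomJS callback builds (made-up points on y = 2x + 1)
payload = {'x': [3.0, 1.0, 2.0, 4.0], 'y': [7.0, 3.0, 5.0, 9.0]}
cmd = "saved_selected(" + json.dumps(payload) + ")"

# Emulate kernel.execute(cmd): the kernel simply runs the string as Python source
exec(cmd, {'saved_selected': saved_selected})

print(cmd)
print(selected['x'], selected['y_pred'])
```

One caveat of this design: JSON values such as null, true and false are not valid Python names, so a selection containing NaN (which JSON.stringify serializes as null) would presumably raise a NameError on the Python side. Sending the payload as a quoted string and decoding it with json.loads would be one way around that.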