Matplotlib:在单独线程中运行Keras模型的PyQt5 GUI,在再次按下“运行”按钮时停止响应

Matplotlib:在单独线程中运行Keras模型的PyQt5 GUI,在再次按下“运行”按钮时停止响应,matplotlib,keras,pyqt5,qthread,Matplotlib,Keras,Pyqt5,Qthread,我有一个应用程序,它读取包含带标签训练数据的.pickle文件,并用Keras构建一个神经网络。它应该用这些数据进行训练,用matplotlib在画布上实时显示训练/验证误差,并用QProgressBar显示进度。我有一个自定义回调,它在每个epoch结束时向主GUI发送一个pyqtSignal,携带当前epoch编号以及累积的训练和验证误差。主程序中有一个函数接收该信号并触发更新方法。在我点击GUI窗口之前,一切都正常运行——点击之后应用程序停止响应(但网络仍在shell中继续训练)。我猜点击……

我有一个应用程序,它读取包含带标签训练数据的.pickle文件,并用Keras构建一个神经网络。它应该用这些数据进行训练,用matplotlib在画布上实时显示训练/验证误差,并用QProgressBar显示进度。

我有一个自定义回调,它在每个epoch结束时向主GUI发送一个pyqtSignal,携带当前epoch编号以及累积的训练和验证误差。主程序中有一个函数接收该信号并触发更新方法。

在我点击GUI窗口之前,一切都正常运行——点击之后应用程序停止响应(但网络仍在shell中继续训练)。我猜测点击事件触发了某个循环,导致整个程序冻结,但我不知道是哪一个。

我搜索了其他关于PyQt5 GUI在使用线程时卡住的问题,但没有找到答案。

我尝试调用QThread.start()而不是QThread.run(),但这种情况下绘图根本不会更新。

我已经编写了一个完整的示例来演示这个问题(数据文件应为.pickle格式,内容是一个列表[X, y],其中X是以numpy ndarray表示的样本,y是对应的标签,同样为numpy ndarray;示例数据文件可以在以下位置找到):

导入系统 导入操作系统 作为pkl导入pickle 从keras.layers导入输入,密集 从keras.callbacks导入回调 从keras.models导入模型 从PyQt5.QtWidgets导入QApplication、QProgressBar、QWidget、QVBoxLayout、QPushButton、QLineEdit、QFileDialog 从PyQt5.QtCore导入QThread,pyqtSignal 导入matplotlib 从matplotlib.backends.backend_qt5agg导入FigureCanvas qtagg as FigureCanvas 从matplotlib.figure导入图形 从PyQt5.qtwidts导入(QSizePolicy) matplotlib.use('Qt5Agg') 类图(图CAVAS): 定义初始值(self,x_标签,y_标签,父项=None,宽度=5,高度=4,dpi=100): 图=图(图尺寸=(宽度、高度),dpi=dpi) self.axes=fig.add_子批次(111) self.compute_initial_figure() self.axes.set_xlabel(x_标签) self.axes.set_ylabel(y_标签) FigureCanvas.\uuuuu init\uuuuu(self,图) self.setParent(父级) 图Canvas.setSizePolicy(self,QSizePolicy.Expanding,QSizePolicy.Expanding) 图CAVAS.updateGeometry(自我) def计算初始图形(自身): 自轴设置(范围(1100,10)) 类多点图(绘图): 定义初始值(self,parent=None,x_轴名称='x',y_轴名称='y',宽度=5,高度=4,dpi=100): 超级() self.compute_initial_figure() def计算初始图形(自身): 自轴设置(范围(0、100、10)) def plot_multi_数据(自身、x_轴名称='x',y_轴名称='y',plot_标签=无,y_列表=无): 如果y_列表不是无: self.axes.clear() 图形_句柄=[] 标记=['b:','r'] y_指数=0 对于y_列表中的y: x=范围(1,透镜(y)+1) 标签=绘图标签[y\U索引] 新的坐标图,=self.axes.plot(x,y,markers[y\u index],markersize=2,label=label) 图形\u句柄.append(新的\u绘图) y_指数+=1 self.axes.set_xticks(x,int(len(list(x))/10)) self.axes.legend(句柄=图形句柄,loc=0,fontsize=8,shadow=True) self.axes.set_xlabel(x_轴名称) self.axes.set_ylabel(y_轴名称) self.draw() 类TrainPlotCallback(回调): 定义初始化(自身,信号): 回调。初始化(自) self.train_err=[] self.val_err=[] self.signal=信号 _epoch_end上的def(self、epoch、logs={}): self.train\u err.append(1-logs.get('acc')) self.val\u err.append(1-logs.get('val\u acc')) self.signal.emit(历元,[self.train\u err,self.val\u err]) def分类模式(数据输入路径,接通信号): #//测试///////////////////// 如果os.path.存在(数据输入路径): plot\u Loss=列车plot Callback(在信号上) 将打开的(数据输入路径“rb”)作为pickle\u输入: 数据=pkl.load(pickle_in) X=数据[0] y=数据[1] 输入_size=X.shape[1] #模型创建 #//输入层///////////////////// 输入=输入(形状=(输入大小) #//输入层///////////////////// #//隐藏层///////////////////// x=密集(10,激活='relu',内核初始化器='normal')(输入)#第一层 #//隐藏层///////////////////// 
#//输出图层///////////////////// 预测=密集(len(y[0]),激活='softmax')(x)#输出层的长度与被预测的类的长度相同。 #//输出图层///////////////////// #模型创建 #//模型定义///////////////////// 模型=模型(输入=输入,输出=预测) model.compile(优化器='Adam', 损失class='classifical_crossentropy', 指标=['acc']) #//模型定义///////////////////// #//模型培训///////////////////// model.fit(X,y,验证分割=0.2,批量大小=100,时代=100,回调=[plot\u loss]) #//模型培训///////////////////// 类ModelThread(QThread): epoch_end_signal=pyqtSignal(int,list)#以epoch#作为第一个参数的信号,以及包含列车和验证的错误值的列表。 定义初始化(自身、数据输入路径): QThread.\uuuu init\uuuu(self) 自我数据
import sys
import os
import pickle as pkl
from keras.layers import Input, Dense
from keras.callbacks import Callback
from keras.models import Model
from PyQt5.QtWidgets import QApplication, QProgressBar, QWidget, QVBoxLayout, QPushButton, QLineEdit, QFileDialog
from PyQt5.QtCore import QThread, pyqtSignal
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtWidgets import (QSizePolicy)
# Select matplotlib's Qt5Agg backend.
# NOTE(review): this runs after matplotlib.backends.backend_qt5agg has already
# been imported above; matplotlib recommends choosing the backend before
# importing backend modules/pyplot. Because FigureCanvasQTAgg is imported
# explicitly it probably has no practical effect here — confirm.
matplotlib.use('Qt5Agg')


class Plot(FigureCanvas):
    """Qt widget wrapping a matplotlib Figure that holds a single set of axes.

    :param x_label: initial x-axis label.
    :param y_label: initial y-axis label.
    :param parent: optional Qt parent widget.
    :param width: figure width in inches.
    :param height: figure height in inches.
    :param dpi: figure resolution in dots per inch.
    """

    def __init__(self, x_label, y_label, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = figure.add_subplot(111)

        # The initial tick layout is a hook: subclasses override compute_initial_figure.
        self.compute_initial_figure()
        self.axes.set_xlabel(x_label)
        self.axes.set_ylabel(y_label)

        # Only now hand the prepared figure to the Qt canvas base class.
        FigureCanvas.__init__(self, figure)
        self.setParent(parent)

        # Let the canvas stretch in both directions inside its layout.
        expanding = QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, expanding, expanding)
        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        """Put x-axis ticks at 1, 11, 21, ..., 91."""
        initial_ticks = range(1, 100, 10)
        self.axes.set_xticks(initial_ticks)


class MultiPlot(Plot):
    """A Plot that redraws several data series on shared axes.

    Used here to show the per-epoch training and validation error curves.
    """

    def __init__(self, parent=None, x_axis_name='X', y_axis_name='Y', width=5, height=4, dpi=100):
        super().__init__(x_axis_name, y_axis_name, parent, width, height, dpi)
        # Plot.__init__ already invoked this override once; calling it again is
        # redundant but harmless and kept for compatibility.
        self.compute_initial_figure()

    def compute_initial_figure(self):
        """Put x-axis ticks at 0, 10, 20, ..., 90."""
        self.axes.set_xticks(range(0, 100, 10))

    def plot_multi_data(self, x_axis_name='X', y_axis_name='Y', plot_labels=None, y_list=None):
        """Clear the axes, draw one line per series, and repaint the canvas.

        :param x_axis_name: x-axis label to apply.
        :param y_axis_name: y-axis label to apply.
        :param plot_labels: legend labels matched to ``y_list`` by position.
            Series without a (non-empty) label are drawn but left out of the
            legend; when no labels are given, no legend is drawn.
        :param y_list: sequence of y-value sequences; series ``i`` is plotted
            against x = 1..len(series). If None, the current lines are kept and
            only the axis labels are updated before repainting.
        """
        if y_list is not None:
            self.axes.clear()

            markers = ['b:', 'r']
            labels = list(plot_labels) if plot_labels is not None else []
            legend_handles = []
            longest = 0

            for index, y in enumerate(y_list):
                x = range(1, len(y) + 1)
                label = labels[index] if index < len(labels) else None
                # Cycle the two styles so a third series does not raise IndexError.
                style = markers[index % len(markers)]
                new_plot, = self.axes.plot(x, y, style, markersize=2, label=label)
                if label:
                    legend_handles.append(new_plot)
                longest = max(longest, len(y))

            if longest:
                # Roughly ten major ticks however many points are shown. Only tick
                # positions are passed: set_xticks' second positional parameter is
                # `labels` (matplotlib >= 3.5) or `minor` (older), not a step size.
                step = max(1, longest // 10)
                self.axes.set_xticks(range(1, longest + 1, step))
            if legend_handles:
                self.axes.legend(handles=legend_handles, loc=0, fontsize=8, shadow=True)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

        self.draw()


class TrainPlotCallback(Callback):
    """Keras callback that reports accumulated error curves through a Qt signal.

    After every epoch it appends ``1 - acc`` to ``train_err`` and
    ``1 - val_acc`` to ``val_err``, then emits
    ``signal(epoch, [train_errors, validation_errors])``.

    :param signal: a bound ``pyqtSignal(int, list)`` (e.g. a thread's
        ``epoch_end_signal``).
    """

    def __init__(self, signal):
        Callback.__init__(self)
        self.train_err = []
        self.val_err = []
        self.signal = signal

    def on_epoch_end(self, epoch, logs=None):
        # No shared mutable default; Keras passes the epoch's metrics dict.
        logs = logs if logs is not None else {}
        self.train_err.append(1 - logs.get('acc'))
        self.val_err.append(1 - logs.get('val_acc'))

        # Emit snapshots, not the live lists: PyQt passes `list` arguments by
        # reference, so a receiver in another thread (queued connection) would
        # otherwise read lists this callback keeps appending to.
        self.signal.emit(epoch, [list(self.train_err), list(self.val_err)])


def classification_model(data_input_path, on_epoch_end_signal, epochs=100, batch_size=100,
                         validation_split=0.2, hidden_units=10):
    """Train a one-hidden-layer dense classifier on a pickled ``[X, y]`` dataset.

    The pickle must hold a sequence whose item 0 is ``X`` (2-D: samples x
    features) and item 1 is ``y`` (one label vector per sample, as
    categorical_crossentropy expects one-hot labels). Per-epoch error curves
    are reported through ``on_epoch_end_signal`` via TrainPlotCallback.

    :param data_input_path: path of the .pickle file. If it is not an existing
        file (e.g. '' because nothing was chosen yet), the function returns
        without training.
    :param on_epoch_end_signal: bound ``pyqtSignal(int, list)`` to emit on.
    :param epochs: number of training epochs (the GUI progress bar assumes 100).
    :param batch_size: mini-batch size passed to ``model.fit``.
    :param validation_split: fraction of the data held out for validation.
    :param hidden_units: width of the hidden ReLU layer.
    """
    if not os.path.isfile(data_input_path):
        return

    plot_losses = TrainPlotCallback(on_epoch_end_signal)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only open .pickle files from trusted sources.
    with open(data_input_path, 'rb') as pickle_in:
        data = pkl.load(pickle_in)
    X = data[0]
    y = data[1]

    # Model: input sized to the feature count -> hidden ReLU layer -> softmax
    # with one unit per class.
    inputs = Input(shape=(X.shape[1],))
    hidden = Dense(hidden_units, activation='relu', kernel_initializer='normal')(inputs)
    predictions = Dense(len(y[0]), activation='softmax')(hidden)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='Adam',
                  loss='categorical_crossentropy',
                  metrics=['acc'])  # TrainPlotCallback reads the 'acc'/'val_acc' log keys

    model.fit(X, y, validation_split=validation_split, batch_size=batch_size,
              epochs=epochs, callbacks=[plot_losses])
        # ///////////////////// MODEL TRAINING /////////////////////


class ModelThread(QThread):
    """QThread whose run() trains the model via classification_model.

    ``epoch_end_signal(epoch, [train_errors, validation_errors])`` is emitted
    after every training epoch.
    """

    # (epoch number, [train error list, validation error list])
    epoch_end_signal = pyqtSignal(int, list)

    def __init__(self, data_input_path):
        super().__init__()
        self.data_input_path = data_input_path

    def __del__(self):
        # Block until run() has returned so the thread is not destroyed mid-run.
        # NOTE(review): if a started thread's last reference is dropped in the
        # GUI thread, this wait() blocks the GUI until training finishes.
        self.wait()

    def run(self):
        path = self.data_input_path
        classification_model(
            data_input_path=path,
            on_epoch_end_signal=self.epoch_end_signal,
        )


class DashBoard(QWidget):
    """Main window: pick a pickled dataset, train a model on it, and plot its error curves live.

    Training runs in a ModelThread started with start(), so the Qt event loop
    (and therefore the window) stays responsive. Per-epoch results arrive
    through ``ModelThread.epoch_end_signal``.
    """

    def __init__(self):
        super().__init__()
        self.main_v_box = QVBoxLayout(self)

        # Path of the chosen .pickle training file ('' until one is picked).
        self.input_data_path_str = ''
        # The current training thread. It must be kept on `self`: ModelThread.__del__
        # calls wait(), so a thread held only by a local variable would block the
        # GUI thread as soon as that local went out of scope.
        self.model_thread = None

        self.progress_bar = QProgressBar()
        self.run_model_btn = QPushButton('Run')
        self.browse_train_data_file_path_btn = QPushButton('Browse')
        self.in_training_plot = MultiPlot(x_axis_name='Epoch Number', y_axis_name='Error')
        self.train_data_file_path_le = QLineEdit()

        self.init()
        self.pack()
        self.showMaximized()

    def init(self):
        """Hide the progress bar and route both buttons to on_btn_click."""
        self.progress_bar.hide()
        self.browse_train_data_file_path_btn.clicked.connect(self.on_btn_click)
        self.run_model_btn.clicked.connect(self.on_btn_click)

    def pack(self):
        """Add the widgets to the vertical layout, top to bottom."""
        self.main_v_box.addWidget(self.train_data_file_path_le)
        self.main_v_box.addWidget(self.browse_train_data_file_path_btn)
        self.main_v_box.addWidget(self.in_training_plot)
        self.main_v_box.addWidget(self.run_model_btn)
        self.main_v_box.addWidget(self.progress_bar)

    def on_btn_click(self):
        """Dispatch Browse/Run clicks based on which button sent the signal."""
        sender = self.sender()

        if sender == self.browse_train_data_file_path_btn:
            path = QFileDialog.getOpenFileName(self, '.pickle files', os.getenv('HOME'), '*.pickle')[0]
            # An empty path means the dialog was cancelled: keep the previous choice.
            if path:
                self.input_data_path_str = path
                self.train_data_file_path_le.setText(path)
        elif sender == self.run_model_btn:
            self._start_training(self.input_data_path_str)

    def update_ui_on_epoch_end(self, current_epoch_num, error_lists):
        """Advance the progress bar and redraw the error curves after an epoch.

        :param current_epoch_num: 0-based epoch index reported by Keras.
        :param error_lists: ``[train_errors, validation_errors]`` accumulated so far.
        """
        # Show completed epochs (index + 1) so the bar reaches 100 on the last of
        # the 100 epochs; clamp in case of longer runs (default bar range is 0-100).
        self.progress_bar.setValue(min(current_epoch_num + 1, 100))
        self.in_training_plot.plot_multi_data(x_axis_name='Epoch', y_axis_name='Error',
                                              plot_labels=['Train Error', 'Validation Error'],
                                              y_list=[error_lists[0], error_lists[1]])

    def run_model(self):
        """Start training on the currently selected file, if it exists."""
        if os.path.exists(self.input_data_path_str):
            self._start_training(self.input_data_path_str)

    def _start_training(self, data_input_path):
        """Launch a background ModelThread; ignored while a run is still in progress."""
        if self.model_thread is not None and self.model_thread.isRunning():
            return

        self.model_thread = ModelThread(data_input_path=data_input_path)
        # The signal is emitted from the worker thread, so Qt queues delivery to this
        # GUI-thread object and the slot runs in the GUI thread.
        self.model_thread.epoch_end_signal.connect(self.update_ui_on_epoch_end)
        self.model_thread.finished.connect(self._on_training_finished)

        self.progress_bar.setValue(0)
        self.progress_bar.show()
        self.run_model_btn.setEnabled(False)
        # start(), not run(): calling run() directly executes training synchronously
        # in the GUI thread, which is what froze the window.
        self.model_thread.start()

    def _on_training_finished(self):
        """Restore the idle UI once the training thread has finished."""
        self.progress_bar.hide()
        self.run_model_btn.setEnabled(True)


if __name__ == '__main__':
    qt_app = QApplication(sys.argv)
    # Keep the window referenced so it is not garbage-collected while the loop runs.
    dashboard = DashBoard()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
import os
import sys
from functools import partial
import pickle as pkl
from keras.layers import Input, Dense
from keras.callbacks import Callback
from keras.models import Model

from PyQt5 import QtCore, QtWidgets

import matplotlib
# Select matplotlib's Qt5Agg backend; this runs before
# matplotlib.backends.backend_qt5agg is imported below.
matplotlib.use('Qt5Agg')

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

class Plot(FigureCanvas):
    """Qt widget wrapping a matplotlib Figure that holds a single set of axes.

    :param x_label: initial x-axis label.
    :param y_label: initial y-axis label.
    :param parent: optional Qt parent widget.
    :param width: figure width in inches.
    :param height: figure height in inches.
    :param dpi: figure resolution in dots per inch.
    """

    def __init__(self, x_label, y_label, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = figure.add_subplot(111)

        # The initial tick layout is a hook: subclasses override compute_initial_figure.
        self.compute_initial_figure()
        self.axes.set_xlabel(x_label)
        self.axes.set_ylabel(y_label)

        # Only now hand the prepared figure to the Qt canvas base class.
        FigureCanvas.__init__(self, figure)
        self.setParent(parent)

        # Let the canvas stretch in both directions inside its layout.
        expanding = QtWidgets.QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, expanding, expanding)
        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        """Put x-axis ticks at 1, 11, 21, ..., 91."""
        initial_ticks = range(1, 100, 10)
        self.axes.set_xticks(initial_ticks)


class MultiPlot(Plot):
    """A Plot that redraws several data series on shared axes.

    Used here to show the per-epoch training and validation error curves.
    """

    def __init__(self, parent=None, x_axis_name='X', y_axis_name='Y', width=5, height=4, dpi=100):
        super().__init__(x_axis_name, y_axis_name, parent, width, height, dpi)
        # Plot.__init__ already invoked this override once; calling it again is
        # redundant but harmless and kept for compatibility.
        self.compute_initial_figure()

    def compute_initial_figure(self):
        """Put x-axis ticks at 0, 10, 20, ..., 90."""
        self.axes.set_xticks(range(0, 100, 10))

    def plot_multi_data(self, x_axis_name='X', y_axis_name='Y', plot_labels=None, y_list=None):
        """Clear the axes, draw one line per series, and repaint the canvas.

        :param x_axis_name: x-axis label to apply.
        :param y_axis_name: y-axis label to apply.
        :param plot_labels: legend labels matched to ``y_list`` by position.
            Series without a (non-empty) label are drawn but left out of the
            legend; when no labels are given, no legend is drawn.
        :param y_list: sequence of y-value sequences; series ``i`` is plotted
            against x = 1..len(series). If None, the current lines are kept and
            only the axis labels are updated before repainting.
        """
        if y_list is not None:
            self.axes.clear()

            markers = ['b:', 'r']
            labels = list(plot_labels) if plot_labels is not None else []
            legend_handles = []
            longest = 0

            for index, y in enumerate(y_list):
                x = range(1, len(y) + 1)
                label = labels[index] if index < len(labels) else None
                # Cycle the two styles so a third series does not raise IndexError.
                style = markers[index % len(markers)]
                new_plot, = self.axes.plot(x, y, style, markersize=2, label=label)
                if label:
                    legend_handles.append(new_plot)
                longest = max(longest, len(y))

            if longest:
                # Roughly ten major ticks however many points are shown. Only tick
                # positions are passed: set_xticks' second positional parameter is
                # `labels` (matplotlib >= 3.5) or `minor` (older), not a step size.
                step = max(1, longest // 10)
                self.axes.set_xticks(range(1, longest + 1, step))
            if legend_handles:
                self.axes.legend(handles=legend_handles, loc=0, fontsize=8, shadow=True)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

        self.draw()


class TrainPlotCallback(Callback):
    """Keras callback that reports accumulated error curves through a Qt signal.

    After every epoch it appends ``1 - acc`` to ``train_err`` and
    ``1 - val_acc`` to ``val_err``, then emits
    ``signal(epoch, [train_errors, validation_errors])``.

    :param signal: a bound ``pyqtSignal(int, list)`` (e.g. a worker's
        ``epoch_end_signal``).
    """

    def __init__(self, signal):
        Callback.__init__(self)
        self.train_err = []
        self.val_err = []
        self.signal = signal

    def on_epoch_end(self, epoch, logs=None):
        # No shared mutable default; Keras passes the epoch's metrics dict.
        logs = logs if logs is not None else {}
        self.train_err.append(1 - logs.get('acc'))
        self.val_err.append(1 - logs.get('val_acc'))

        # Emit snapshots, not the live lists: PyQt passes `list` arguments by
        # reference, so the GUI-thread receiver of this queued cross-thread signal
        # would otherwise read lists this callback keeps appending to.
        self.signal.emit(epoch, [list(self.train_err), list(self.val_err)])

def classification_model(data_input_path, on_epoch_end_signal, epochs=100, batch_size=100,
                         validation_split=0.2, hidden_units=10):
    """Train a one-hidden-layer dense classifier on a pickled ``[X, y]`` dataset.

    The pickle must hold a sequence whose item 0 is ``X`` (2-D: samples x
    features) and item 1 is ``y`` (one label vector per sample, as
    categorical_crossentropy expects one-hot labels). Per-epoch error curves
    are reported through ``on_epoch_end_signal`` via TrainPlotCallback.

    :param data_input_path: path of the .pickle file. If it is not an existing
        file (e.g. '' because nothing was chosen yet), the function returns
        without training.
    :param on_epoch_end_signal: bound ``pyqtSignal(int, list)`` to emit on.
    :param epochs: number of training epochs (the GUI progress bar assumes 100).
    :param batch_size: mini-batch size passed to ``model.fit``.
    :param validation_split: fraction of the data held out for validation.
    :param hidden_units: width of the hidden ReLU layer.
    """
    if not os.path.isfile(data_input_path):
        return

    plot_losses = TrainPlotCallback(on_epoch_end_signal)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only open .pickle files from trusted sources.
    with open(data_input_path, 'rb') as pickle_in:
        data = pkl.load(pickle_in)
    X = data[0]
    y = data[1]

    # Model: input sized to the feature count -> hidden ReLU layer -> softmax
    # with one unit per class.
    inputs = Input(shape=(X.shape[1],))
    hidden = Dense(hidden_units, activation='relu', kernel_initializer='normal')(inputs)
    predictions = Dense(len(y[0]), activation='softmax')(hidden)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer='Adam',
                  loss='categorical_crossentropy',
                  metrics=['acc'])  # TrainPlotCallback reads the 'acc'/'val_acc' log keys

    model.fit(X, y, validation_split=validation_split, batch_size=batch_size,
              epochs=epochs, callbacks=[plot_losses])


class Worker(QtCore.QObject):
    """Runs classification_model in whichever thread this object was moved to.

    Signals:
        started: emitted just before training begins.
        finished: emitted when start_task ends, whether training succeeded or failed.
        error(str): emitted with a formatted traceback if training raised.
        epoch_end_signal(int, list): epoch index and [train errors, validation errors],
            emitted by the Keras callback after every epoch.
    """
    started = QtCore.pyqtSignal()
    finished = QtCore.pyqtSignal()
    error = QtCore.pyqtSignal(str)
    epoch_end_signal = QtCore.pyqtSignal(int, list)  # signal that has epoch # as the first parameter, and a list that contains the error values for the train and validation.

    @QtCore.pyqtSlot(str)
    def start_task(self, input_path):
        """Train on the pickle at *input_path*; ``finished`` is always emitted afterwards."""
        import traceback

        self.started.emit()
        try:
            classification_model(data_input_path=input_path,
                                 on_epoch_end_signal=self.epoch_end_signal)
        except Exception:
            # Top-level boundary of the worker thread: an exception escaping a PyQt
            # slot aborts the whole application (PyQt >= 5.5). Report it instead.
            message = traceback.format_exc()
            sys.stderr.write(message)
            self.error.emit(message)
        finally:
            # Without this the GUI would keep the Run button disabled forever after a failure.
            self.finished.emit()


class DashBoard(QtWidgets.QWidget):
    """Main window: pick a pickled dataset, train a model on it, and plot its error curves live.

    Training runs in a Worker living in a dedicated QThread, so the GUI event
    loop stays responsive while Keras trains.
    """

    # Emitted with the dataset path to ask the worker to train. The worker lives in
    # another thread, so Qt delivers this as a queued call that runs start_task there.
    run_requested = QtCore.pyqtSignal(str)

    def __init__(self):
        super().__init__()
        self.main_v_box = QtWidgets.QVBoxLayout(self)
        # Path of the chosen .pickle training file ('' until one is picked).
        self.input_data_path_str = ''
        self.progress_bar = QtWidgets.QProgressBar()
        self.run_model_btn = QtWidgets.QPushButton('Run')
        self.browse_train_data_file_path_btn = QtWidgets.QPushButton('Browse')
        self.in_training_plot = MultiPlot(x_axis_name='Epoch Number', y_axis_name='Error')
        self.train_data_file_path_le = QtWidgets.QLineEdit()
        self.init()
        self.pack()
        self.showMaximized()

    def init(self):
        """Create the worker thread and wire up all signals."""
        self.worker = Worker()
        # Parented to self so it stays alive; stopped explicitly in closeEvent.
        self.worker_thread = QtCore.QThread(self)
        self.worker.moveToThread(self.worker_thread)
        self.worker_thread.start()

        self.progress_bar.hide()
        self.browse_train_data_file_path_btn.clicked.connect(self.on_btn_click)
        self.run_model_btn.clicked.connect(self.on_btn_click)

        self.run_requested.connect(self.worker.start_task)
        # Worker signals are emitted from the worker thread; connecting them to
        # methods of this GUI-thread object makes Qt queue them onto the GUI thread.
        self.worker.epoch_end_signal.connect(self.update_ui_on_epoch_end)
        self.worker.started.connect(self._on_training_started)
        self.worker.finished.connect(self._on_training_finished)

    def pack(self):
        """Add the widgets to the vertical layout, top to bottom."""
        self.main_v_box.addWidget(self.train_data_file_path_le)
        self.main_v_box.addWidget(self.browse_train_data_file_path_btn)
        self.main_v_box.addWidget(self.in_training_plot)
        self.main_v_box.addWidget(self.run_model_btn)
        self.main_v_box.addWidget(self.progress_bar)

    def on_btn_click(self):
        """Dispatch Browse/Run clicks based on which button sent the signal."""
        sender = self.sender()

        if sender == self.browse_train_data_file_path_btn:
            path, _ = QtWidgets.QFileDialog.getOpenFileName(self, '.pickle files', os.getenv('HOME'), '*.pickle')
            # An empty path means the dialog was cancelled: keep the previous choice.
            if path:
                self.input_data_path_str = path
                self.train_data_file_path_le.setText(path)
        elif sender == self.run_model_btn:
            # Disable immediately so a quick double click cannot queue two runs.
            self.run_model_btn.setEnabled(False)
            self.run_requested.emit(self.input_data_path_str)

    def update_ui_on_epoch_end(self, current_epoch_num, error_lists):
        """Advance the progress bar and redraw the error curves after an epoch.

        :param current_epoch_num: 0-based epoch index reported by Keras.
        :param error_lists: ``[train_errors, validation_errors]`` accumulated so far.
        """
        # Show completed epochs (index + 1) so the bar reaches 100 on the last of
        # the 100 epochs; clamp in case of longer runs (default bar range is 0-100).
        self.progress_bar.setValue(min(current_epoch_num + 1, 100))
        self.in_training_plot.plot_multi_data(x_axis_name='Epoch', y_axis_name='Error',
                                              plot_labels=['Train Error', 'Validation Error'],
                                              y_list=[error_lists[0], error_lists[1]])

    def _on_training_started(self):
        """Show a reset progress bar and lock the Run button while training runs."""
        self.progress_bar.setValue(0)
        self.progress_bar.show()
        self.run_model_btn.setEnabled(False)

    def _on_training_finished(self):
        """Restore the idle UI once the worker reports completion."""
        self.progress_bar.hide()
        self.run_model_btn.setEnabled(True)

    def closeEvent(self, event):
        """Stop the worker thread before the window (its parent) is destroyed.

        Qt crashes if a running QThread is destroyed, so its event loop is asked to
        quit and waited for. quit() only takes effect after start_task returns, so
        closing mid-training waits for the current run to finish.
        """
        self.worker_thread.quit()
        self.worker_thread.wait()
        super().closeEvent(event)


if __name__ == '__main__':
    qt_app = QtWidgets.QApplication(sys.argv)
    # Keep the window referenced so it is not garbage-collected while the loop runs.
    dashboard = DashBoard()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)