
Python: How to shuffle training data for an autoencoder in Keras


I'm using an autoencoder in Keras. I want to shuffle the training data x_train so that the autoencoder reconstructs each sample as a different sample from the same class. Is that possible?

model_train = autoencoder.fit(x_train, x_train,
          batch_size=32,
          epochs=1000,
          shuffle=True,
          callbacks=[checkpoint, early_stopping],
          validation_data=(x_test, x_test))

I assume shuffle=True is shuffling x_train but still computing the loss against identical input/target pairs, which is not what I want it to do.
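That assumption is right: shuffle=True only permutes the batch order, applying the same permutation to inputs and targets, so each sample is still reconstructed against itself. A minimal NumPy sketch of that unison shuffle (the toy arrays and seed here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 3 "samples" of 2 features each
x = np.arange(6, dtype=float).reshape(3, 2)
y = x.copy()  # autoencoder targets are the inputs themselves

# What shuffle=True effectively does: one permutation applied to both arrays
perm = rng.permutation(len(x))
x_shuf, y_shuf = x[perm], y[perm]

# Every input is still paired with itself; only the batch order changed
print(np.array_equal(x_shuf, y_shuf))  # True
```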

This is possible, but Keras won't do it for you, because it shuffles the data and the targets together. Assuming you have the class labels, I've found this function useful for your purpose:

import numpy as np

def create_pairs(data, labels):
    # data.shape[1:] excludes the batch dimension
    pairs = np.empty((0, 2, *data.shape[1:]))

    for label in np.unique(labels):
        idxs = np.where(labels == label)[0]
        # The number of indexes must be even to form pairs; drop the leftover
        idxs = idxs if len(idxs) % 2 == 0 else idxs[:-1]
        np.random.shuffle(idxs)

        samples = data[idxs].reshape((-1, 2, *data.shape[1:]))
        pairs = np.vstack((pairs, samples))
    return pairs[:, 0], pairs[:, 1]
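To see what create_pairs returns, here is a small self-contained check (the function is repeated so the sketch runs on its own; note that np.empty takes the shape as a single tuple; the toy data, labels, and expected shapes are made up for illustration):

```python
import numpy as np

def create_pairs(data, labels):
    # data.shape[1:] excludes the batch dimension
    pairs = np.empty((0, 2, *data.shape[1:]))
    for label in np.unique(labels):
        idxs = np.where(labels == label)[0]
        # Drop the leftover sample if a class has an odd number of members
        idxs = idxs if len(idxs) % 2 == 0 else idxs[:-1]
        np.random.shuffle(idxs)
        samples = data[idxs].reshape((-1, 2, *data.shape[1:]))
        pairs = np.vstack((pairs, samples))
    return pairs[:, 0], pairs[:, 1]

# Hypothetical toy set: 6 samples of 4 features, classes 0 and 1
data = np.arange(24, dtype=float).reshape(6, 4)
labels = np.array([0, 0, 0, 1, 1, 1])

x, y = create_pairs(data, labels)

# Each odd-sized class loses one sample: 2 remain -> 1 pair per class
print(x.shape, y.shape)  # (2, 4) (2, 4)
```

Note that the pairing is fixed once; if you want fresh same-class pairings each epoch, call create_pairs again between epochs rather than relying on shuffle=True.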
Now that the data has been shuffled and split into pairs, you can train your model:

x_train, y_train = create_pairs(data, labels)
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=1000,
    shuffle=True,
    callbacks=[checkpoint, early_stopping],
    validation_split=0.2)