Python:在 TensorFlow 2 中,我可以对使用所有输出的目标和预测的自定义损失的多输出模型使用 Keras model.fit() 吗?

Python 我可以在使用所有输出的自定义损耗的多输出模型上使用Keras model.fit()吗';Tensorflow 2中的目标和预测?,python,tensorflow,keras,tensorflow2.0,tf.keras,Python,Tensorflow,Keras,Tensorflow2.0,Tf.keras,我尝试过让Keras model.fit()在我的多输出模型上工作,但失败了,该模型使用了TF2中所有输出的目标和预测(特别是两个输出) 当我尝试在使用Keras函数API制作的模型上执行此操作时,我得到了错误:“SymbolicException:eager execution函数的输入不能是Keras符号张量,但可以找到…” 这意味着我不能使用loss函数,因为它会将一个渴望的张量返回给使用符号张量(函数API模型)的Keras DAG。为了解决这个问题,我使用model.add_loss(

我尝试过让Keras model.fit()在我的多输出模型上工作,但失败了,该模型使用了TF2中所有输出的目标和预测(特别是两个输出)

当我尝试在使用Keras函数API制作的模型上执行此操作时,我得到了错误:“SymbolicException:eager execution函数的输入不能是Keras符号张量,但可以找到…” 这意味着我不能使用loss函数,因为它会将一个渴望的张量返回给使用符号张量(函数API模型)的Keras DAG。为了解决这个问题,我使用model.add_loss()而不是将loss函数传递给model.compile(),但我相信这会占用GPU内存并导致OOM错误

我尝试了一些变通方法,将我的函数API模型放在Keras子类模型中,或者创建一个全新的Keras子类模型

解决方案1的代码如下所示,它的运行给了我跨越不同时代的关于各种梯度剪辑的训练的NaN,并给出了0值输出

解决方案2在override call()方法中给了我一个错误,因为输入参数在模型编译时和运行时的形状不同,因为我的模型(以一种奇怪的方式)有3个输入:1是DLNN的实际输入,另外2个是输入样本的目标。这样,我就可以将每个样本中的目标放入损失函数中

from scipy.io import wavfile
import scipy.signal as sg
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, Dense, Lambda, TimeDistributed, Layer, LSTM, Bidirectional, BatchNormalization, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.activations import relu
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import datetime
import numpy as np
import math
import random
import json
import os
import sys



# Loss function
def discriminative_loss(piano_true, noise_true, piano_pred, noise_pred, loss_const):
    """Per-sample discriminative loss for the two-source separation model.

    Each branch is rewarded for matching its own target and penalized
    (weighted by ``loss_const``) for matching the other branch's target.
    Returns one loss value per flattened sample (reduction over the last
    axis only).
    """
    flat = piano_pred.shape[1] * piano_pred.shape[2]

    def _mse(pred, true):
        # Mean squared error over the flattened (time, feature) axes.
        return tf.math.reduce_mean(tf.reshape(pred - true, shape=(-1, flat)) ** 2, axis=-1)

    return (_mse(noise_pred, noise_true)
            - (loss_const * _mse(noise_pred, piano_true))
            + _mse(piano_pred, piano_true)
            - (loss_const * _mse(piano_pred, noise_true)))



def make_model(features, sequences, name='Model'):
    """Build the functional separation network.

    Besides the mixed signal, the two per-sample targets are wired in as
    extra inputs so a loss with access to both targets can be computed in a
    custom train step. ``name`` is accepted for interface compatibility but
    is not applied to the Model (matching the original behavior).
    """
    mixed_in = Input(shape=(sequences, features), dtype='float32',
                     name='piano_noise_mixed')
    target_piano = Input(shape=(sequences, features), dtype='float32',
                         name='piano_true')
    target_noise = Input(shape=(sequences, features), dtype='float32',
                         name='noise_true')

    hidden = SimpleRNN(features // 2,
                       activation='relu',
                       return_sequences=True)(mixed_in)

    # One time-distributed dense head per source.
    piano_out = TimeDistributed(Dense(features), name='piano_hat')(hidden)  # source 1 branch
    noise_out = TimeDistributed(Dense(features), name='noise_hat')(hidden)  # source 2 branch

    return Model(inputs=[mixed_in, target_piano, target_noise],
                 outputs=[piano_out, noise_out])



# Model "wrapper" for many-input loss function
class RestorationModel2(Model):
    """Subclassed wrapper around the functional model.

    Overrides train_step/test_step so a five-argument custom loss (both
    targets, both predictions, and a weighting constant) can be used with
    the ordinary Keras fit() loop.
    """

    def __init__(self, model, loss_const):
        super(RestorationModel2, self).__init__()
        self.model = model            # wrapped functional model
        self.loss_const = loss_const  # gamma weight consumed by the custom loss

    def call(self, inputs):
        # Delegate the forward pass to the wrapped model.
        return self.model(inputs)

    def compile(self, optimizer, loss):
        # NOTE(review): assigning to self.loss shadows keras.Model's own
        # `loss` attribute on newer TF releases — confirm against the TF
        # version this targets.
        super(RestorationModel2, self).compile()
        self.optimizer = optimizer
        self.loss = loss

    def train_step(self, data):
        # Unpack data - what the generator yields: (x, piano_true, noise_true)
        x, piano_true, noise_true = data

        with tf.GradientTape() as tape:
            # Targets are also fed as inputs; the wrapped model's graph
            # ignores them, but the loss call below consumes them.
            piano_pred, noise_pred = self.model((x, piano_true, noise_true), training=True)
            loss = self.loss(piano_true, noise_true, piano_pred, noise_pred, self.loss_const)

        trainable_vars = self.model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        return {'loss': loss}

    def test_step(self, data):
        # Same unpacking and loss as train_step, but no gradient update.
        x, piano_true, noise_true = data

        piano_pred, noise_pred = self.model((x, piano_true, noise_true), training=False)
        loss = self.loss(piano_true, noise_true, piano_pred, noise_pred, self.loss_const)
        
        return {'loss': loss}



def make_imp_model(features, sequences, loss_const=0.05,
                   optimizer=None,
                   name='Restoration Model', epsilon=10 ** (-10)):
    """Build and compile the subclassed ("semi-imperative") model.

    Args:
        features: feature dimension of each time step.
        sequences: number of time steps per sample.
        loss_const: weighting constant passed to the custom loss.
        optimizer: Keras optimizer; defaults to a fresh RMSprop(clipvalue=0.7)
            per call. (The original default was evaluated once at import
            time, so every call silently shared one optimizer instance and
            its slot variables.)
        name, epsilon: accepted for interface compatibility; currently unused.

    Returns:
        A compiled RestorationModel2 ready for fit().
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.7)

    # NEW Semi-imperative model
    model = RestorationModel2(make_model(features, sequences, name='Training Model'),
                              loss_const=loss_const)

    model.compile(optimizer=optimizer, loss=discriminative_loss)

    return model



# MODEL TRAIN & EVAL FUNCTION
def evaluate_source_sep(train_generator, validation_generator,
                        num_train, num_val, n_feat, n_seq, batch_size,
                        loss_const, epochs=20,
                        optimizer=None,
                        patience=10, epsilon=10 ** (-10)):
    """Train the separation model and return the Keras History.

    Fixes vs. original: the optimizer default is now created per call
    (the original default RMSprop was instantiated once at import time and
    shared across calls), and the History object is returned instead of
    being discarded.
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.75)

    print('Making model...')    # IMPERATIVE MODEL - Customize Fit
    model = make_imp_model(n_feat, n_seq, loss_const=loss_const, optimizer=optimizer, epsilon=epsilon)

    print('Going into training now...')
    hist = model.fit(train_generator,
                     steps_per_epoch=math.ceil(num_train / batch_size),
                     epochs=epochs,
                     validation_data=validation_generator,
                     validation_steps=math.ceil(num_val / batch_size),
                     callbacks=[EarlyStopping('val_loss', patience=patience, mode='min')])
    print(model.summary())
    return hist



# NEURAL NETWORK DATA GENERATOR
def my_dummy_generator(num_samples, batch_size, train_seq, train_feat):
    """Endlessly yield (x, y1, y2) placeholder batches of a fixed shape.

    Arrays come from np.empty, so their contents are uninitialized —
    only the shapes matter for this dummy data.
    """
    batch_shape = (batch_size, train_seq, train_feat)
    while True:
        for _ in range(0, num_samples, batch_size):
            yield tuple(np.empty(batch_shape) for _ in range(3))



def main():
    """Configure hyper-parameters, split the dummy data, and run training."""
    epsilon = 10 ** (-10)
    train_batch_size = 5
    loss_const, epochs, val_split = 0.05, 10, 0.25
    optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.9)

    TRAIN_SEQ_LEN, TRAIN_FEAT_LEN = 1847, 2049
    TOTAL_SMPLS = 60

    # Validation & Training Split: first ceil(val_split fraction) of the
    # samples are validation, the rest training.
    num_val = math.ceil(TOTAL_SMPLS * val_split)
    num_train = TOTAL_SMPLS - num_val

    train_seq, train_feat = TRAIN_SEQ_LEN, TRAIN_FEAT_LEN
    print('Train Input Stats:')
    print('N Feat:', train_feat, 'Seq Len:', train_seq, 'Batch Size:', train_batch_size)

    # Create data generators and evaluate the model with them.
    gen_kwargs = dict(batch_size=train_batch_size, train_seq=train_seq,
                      train_feat=train_feat)
    train_generator = my_dummy_generator(num_train, **gen_kwargs)
    validation_generator = my_dummy_generator(num_val, **gen_kwargs)

    evaluate_source_sep(train_generator, validation_generator, num_train, num_val,
                        n_feat=train_feat, n_seq=train_seq,
                        batch_size=train_batch_size,
                        loss_const=loss_const, epochs=epochs,
                        optimizer=optimizer, epsilon=epsilon)


if __name__ == '__main__':
    main()
class TimeFreqMasking(Layer):
    """Soft time-frequency mask: scales the mixed input by this source's
    share of the two estimates' combined magnitude."""

    # Init is for input-independent variables
    def __init__(self, epsilon, **kwargs):
        super(TimeFreqMasking, self).__init__(**kwargs)
        self.epsilon = epsilon  # guards against division by zero in the mask

    # No build method, b/c passing in multiple inputs to layer (no single shape)

    def call(self, inputs):
        # inputs: (this source's estimate, the other source's estimate, mixture)
        y_hat_self, y_hat_other, x_mixed = inputs
        # Ratio mask in [0, 1); applied multiplicatively to the mixture.
        mask = tf.abs(y_hat_self) / (tf.abs(y_hat_self) + tf.abs(y_hat_other) + self.epsilon)
        y_tilde_self = mask * x_mixed
        return y_tilde_self


def discrim_loss(y_true, y_pred):
    """Discriminative loss for the concatenated-output model.

    ``y_true`` packs [piano_true, noise_true] along the last axis.
    ``y_pred`` stacks [piano_pred, noise_pred, gamma_broadcast] along axis 0,
    where the final axis-0 slice is the loss constant replicated over
    (sequences, features) by the model itself.
    """
    piano_true, noise_true = tf.split(y_true, num_or_size_splits=2, axis=-1)
    # Any element of the trailing slice holds the scalar loss constant.
    loss_const = y_pred[-1, :, :][0][0]
    piano_pred, noise_pred = tf.split(y_pred[:-1, :, :], num_or_size_splits=2, axis=0)

    # NOTE(review): unlike discriminative_loss above, these means reduce over
    # all axes (no axis=-1), producing a scalar — confirm that is intended.
    last_dim = piano_pred.shape[1] * piano_pred.shape[2]
    return (
        tf.math.reduce_mean(tf.reshape(noise_pred - noise_true, shape=(-1, last_dim)) ** 2) - 
        (loss_const * tf.math.reduce_mean(tf.reshape(noise_pred - piano_true, shape=(-1, last_dim)) ** 2)) +
        tf.math.reduce_mean(tf.reshape(piano_pred - piano_true, shape=(-1, last_dim)) ** 2) -
        (loss_const * tf.math.reduce_mean(tf.reshape(piano_pred - noise_true, shape=(-1, last_dim)) ** 2))
    )


def make_model(features, sequences, epsilon, loss_const, optimizer=None):
    """Build and compile the masking model whose single output concatenates
    both source predictions plus a broadcast copy of the loss constant.

    Args:
        features: feature dimension per time step.
        sequences: time steps per sample.
        epsilon: small constant for the soft masks.
        loss_const: gamma weight; broadcast into the output tensor so
            discrim_loss can recover it from y_pred.
        optimizer: Keras optimizer; defaults to RMSprop. (Bug fix: the
            original read a global `optimizer` name that is never defined at
            module scope, raising NameError at model.compile.)
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop()

    input_layer = Input(shape=(sequences, features), name='piano_noise_mixed')
    x = SimpleRNN(features // 2,
                  activation='relu',
                  return_sequences=True)(input_layer)
    x = SimpleRNN(features // 2,
                  activation='relu',
                  return_sequences=True)(x)
    piano_hat = TimeDistributed(Dense(features), name='piano_hat')(x)  # source 1 branch
    noise_hat = TimeDistributed(Dense(features), name='noise_hat')(x)  # source 2 branch
    piano_pred = TimeFreqMasking(epsilon=epsilon,
                                 name='piano_pred')((piano_hat, noise_hat, input_layer))
    noise_pred = TimeFreqMasking(epsilon=epsilon,
                                 name='noise_pred')((noise_hat, piano_hat, input_layer))

    # Stack predictions and the gamma tensor along the batch axis so a single
    # output carries everything discrim_loss needs. loss_const is cast to
    # float so the constant's dtype matches the float predictions (the
    # caller passes an int, which would otherwise make an int32 tensor).
    preds_and_gamma = Concatenate(axis=0)([piano_pred,
                                           noise_pred,
                                           #  loss_const_tensor
                                           tf.broadcast_to(tf.constant(float(loss_const)), [1, sequences, features])
                                           ])
    model = Model(inputs=input_layer, outputs=preds_and_gamma)
    model.compile(optimizer=optimizer, loss=discrim_loss)
    return model


def dummy_generator(num_samples, batch_size, num_seq, num_feat):
    """Endlessly yield [x, y] random batches where y is the two targets
    concatenated along the feature axis — the packed y_true layout that
    discrim_loss expects."""
    shape = (batch_size, num_seq, num_feat)
    while True:
        for _ in range(0, num_samples, batch_size):
            x = np.random.rand(*shape)
            targets = np.concatenate((np.random.rand(*shape),
                                      np.random.rand(*shape)), axis=-1)
            yield [x, targets]


# Script driver: build the concatenated-output model and fit it on dummy data.
total_samples = 6
batch_size = 2
time_steps = 3
features = 4
loss_const = 2
epochs = 10
val_split = 0.25
epsilon = 10 ** (-10)

model = make_model(features, time_steps, epsilon, loss_const)
print(model.summary())

# Bug fixes: `actual_samples` and `val_samples` were undefined names
# (NameError); they should be total_samples and num_val respectively.
num_val = math.ceil(total_samples * val_split)
num_train = total_samples - num_val
train_dataset = dummy_generator(num_train, batch_size, time_steps, features)
val_dataset = dummy_generator(num_val, batch_size, time_steps, features)

# Bug fix: the original fit() call was missing its closing parenthesis.
model.fit(train_dataset,
          steps_per_epoch=math.ceil(num_train / batch_size),
          epochs=epochs,
          validation_data=val_dataset,
          validation_steps=math.ceil(num_val / batch_size))


谢谢你的帮助

解决方案:不要将您的损失传递给 model.add_loss()。相反,把各个输出连接(concatenate)在一起,这样就可以把自定义损失传递给 model.compile(),然后在自定义损失函数内部拆分并处理这些输出。

from scipy.io import wavfile
import scipy.signal as sg
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, Dense, Lambda, TimeDistributed, Layer, LSTM, Bidirectional, BatchNormalization, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.activations import relu
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import datetime
import numpy as np
import math
import random
import json
import os
import sys



# Loss function
def discriminative_loss(piano_true, noise_true, piano_pred, noise_pred, loss_const):
    last_dim = piano_pred.shape[1] * piano_pred.shape[2]
    return (
        tf.math.reduce_mean(tf.reshape(noise_pred - noise_true, shape=(-1, last_dim)) ** 2, axis=-1) - 
        (loss_const * tf.math.reduce_mean(tf.reshape(noise_pred - piano_true, shape=(-1, last_dim)) ** 2, axis=-1)) +
        tf.math.reduce_mean(tf.reshape(piano_pred - piano_true, shape=(-1, last_dim)) ** 2, axis=-1) -
        (loss_const * tf.math.reduce_mean(tf.reshape(piano_pred - noise_true, shape=(-1, last_dim)) ** 2, axis=-1))
    )



def make_model(features, sequences, name='Model'):

    input_layer = Input(shape=(sequences, features), dtype='float32', 
                        name='piano_noise_mixed')
    piano_true = Input(shape=(sequences, features), dtype='float32', 
                       name='piano_true')
    noise_true = Input(shape=(sequences, features), dtype='float32', 
                       name='noise_true')

    x = SimpleRNN(features // 2, 
                  activation='relu', 
                  return_sequences=True) (input_layer) 
    piano_pred = TimeDistributed(Dense(features), name='piano_hat') (x)  # source 1 branch
    noise_pred = TimeDistributed(Dense(features), name='noise_hat') (x)  # source 2 branch
  
    model = Model(inputs=[input_layer, piano_true, noise_true],
                  outputs=[piano_pred, noise_pred])

    return model



# Model "wrapper" for many-input loss function
class RestorationModel2(Model):
    def __init__(self, model, loss_const):
        super(RestorationModel2, self).__init__()
        self.model = model
        self.loss_const = loss_const
       
    def call(self, inputs):
        return self.model(inputs)

    def compile(self, optimizer, loss):
        super(RestorationModel2, self).compile()
        self.optimizer = optimizer
        self.loss = loss

    def train_step(self, data):
        # Unpack data - what generator yeilds
        x, piano_true, noise_true = data

        with tf.GradientTape() as tape:
            piano_pred, noise_pred = self.model((x, piano_true, noise_true), training=True)
            loss = self.loss(piano_true, noise_true, piano_pred, noise_pred, self.loss_const)

        trainable_vars = self.model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        return {'loss': loss}

    def test_step(self, data):
        x, piano_true, noise_true = data

        piano_pred, noise_pred = self.model((x, piano_true, noise_true), training=False)
        loss = self.loss(piano_true, noise_true, piano_pred, noise_pred, self.loss_const)
        
        return {'loss': loss}



def make_imp_model(features, sequences, loss_const=0.05,
                   optimizer=None,
                   name='Restoration Model', epsilon=10 ** (-10)):
    """Build and compile the subclassed ("semi-imperative") model.

    Args:
        features: feature dimension of each time step.
        sequences: number of time steps per sample.
        loss_const: weighting constant passed to the custom loss.
        optimizer: Keras optimizer; defaults to a fresh RMSprop(clipvalue=0.7)
            per call. (The original default was evaluated once at import
            time, so every call silently shared one optimizer instance and
            its slot variables.)
        name, epsilon: accepted for interface compatibility; currently unused.

    Returns:
        A compiled RestorationModel2 ready for fit().
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.7)

    # NEW Semi-imperative model
    model = RestorationModel2(make_model(features, sequences, name='Training Model'),
                              loss_const=loss_const)

    model.compile(optimizer=optimizer, loss=discriminative_loss)

    return model



# MODEL TRAIN & EVAL FUNCTION
def evaluate_source_sep(train_generator, validation_generator,
                        num_train, num_val, n_feat, n_seq, batch_size,
                        loss_const, epochs=20,
                        optimizer=None,
                        patience=10, epsilon=10 ** (-10)):
    """Train the separation model and return the Keras History.

    Fixes vs. original: the optimizer default is now created per call
    (the original default RMSprop was instantiated once at import time and
    shared across calls), and the History object is returned instead of
    being discarded.
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.75)

    print('Making model...')    # IMPERATIVE MODEL - Customize Fit
    model = make_imp_model(n_feat, n_seq, loss_const=loss_const, optimizer=optimizer, epsilon=epsilon)

    print('Going into training now...')
    hist = model.fit(train_generator,
                     steps_per_epoch=math.ceil(num_train / batch_size),
                     epochs=epochs,
                     validation_data=validation_generator,
                     validation_steps=math.ceil(num_val / batch_size),
                     callbacks=[EarlyStopping('val_loss', patience=patience, mode='min')])
    print(model.summary())
    return hist



# NEURAL NETWORK DATA GENERATOR
def my_dummy_generator(num_samples, batch_size, train_seq, train_feat):
    """Endlessly yield (x, y1, y2) placeholder batches of a fixed shape.

    Arrays come from np.empty, so their contents are uninitialized —
    only the shapes matter for this dummy data.
    """
    batch_shape = (batch_size, train_seq, train_feat)
    while True:
        for _ in range(0, num_samples, batch_size):
            yield tuple(np.empty(batch_shape) for _ in range(3))



def main():
    epsilon = 10 ** (-10)
    train_batch_size = 5
    loss_const, epochs, val_split = 0.05, 10, 0.25
    optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.9)

    TRAIN_SEQ_LEN, TRAIN_FEAT_LEN = 1847, 2049
    TOTAL_SMPLS = 60 

    # Validation & Training Split
    indices = list(range(TOTAL_SMPLS))
    val_indices = indices[:math.ceil(TOTAL_SMPLS * val_split)]
    num_val = len(val_indices)
    num_train = TOTAL_SMPLS - num_val
   
    train_seq, train_feat = TRAIN_SEQ_LEN, TRAIN_FEAT_LEN
    print('Train Input Stats:')
    print('N Feat:', train_feat, 'Seq Len:', train_seq, 'Batch Size:', train_batch_size)

    # Create data generators and evaluate model with them
    train_generator = my_dummy_generator(num_train,
                        batch_size=train_batch_size, train_seq=train_seq,
                        train_feat=train_feat)
    validation_generator = my_dummy_generator(num_val,
                        batch_size=train_batch_size, train_seq=train_seq,
                        train_feat=train_feat)

    evaluate_source_sep(train_generator, validation_generator, num_train, num_val,
                            n_feat=train_feat, n_seq=train_seq, 
                            batch_size=train_batch_size, 
                            loss_const=loss_const, epochs=epochs,
                            optimizer=optimizer, epsilon=epsilon)

if __name__ == '__main__':
    main()
class TimeFreqMasking(Layer):
    # Init is for input-independent variables
    def __init__(self, epsilon, **kwargs):
        super(TimeFreqMasking, self).__init__(**kwargs)
        self.epsilon = epsilon

    # No build method, b/c passing in multiple inputs to layer (no single shape)

    def call(self, inputs):
        y_hat_self, y_hat_other, x_mixed = inputs
        mask = tf.abs(y_hat_self) / (tf.abs(y_hat_self) + tf.abs(y_hat_other) + self.epsilon)
        y_tilde_self = mask * x_mixed
        return y_tilde_self


def discrim_loss(y_true, y_pred):
    piano_true, noise_true = tf.split(y_true, num_or_size_splits=2, axis=-1)
    loss_const = y_pred[-1, :, :][0][0]
    piano_pred, noise_pred = tf.split(y_pred[:-1, :, :], num_or_size_splits=2, axis=0)

    last_dim = piano_pred.shape[1] * piano_pred.shape[2]
    return (
        tf.math.reduce_mean(tf.reshape(noise_pred - noise_true, shape=(-1, last_dim)) ** 2) - 
        (loss_const * tf.math.reduce_mean(tf.reshape(noise_pred - piano_true, shape=(-1, last_dim)) ** 2)) +
        tf.math.reduce_mean(tf.reshape(piano_pred - piano_true, shape=(-1, last_dim)) ** 2) -
        (loss_const * tf.math.reduce_mean(tf.reshape(piano_pred - noise_true, shape=(-1, last_dim)) ** 2))
    )


def make_model(features, sequences, epsilon, loss_const, optimizer=None):
    """Build and compile the masking model whose single output concatenates
    both source predictions plus a broadcast copy of the loss constant.

    Args:
        features: feature dimension per time step.
        sequences: time steps per sample.
        epsilon: small constant for the soft masks.
        loss_const: gamma weight; broadcast into the output tensor so
            discrim_loss can recover it from y_pred.
        optimizer: Keras optimizer; defaults to RMSprop. (Bug fix: the
            original read a global `optimizer` name that is never defined at
            module scope, raising NameError at model.compile.)
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop()

    input_layer = Input(shape=(sequences, features), name='piano_noise_mixed')
    x = SimpleRNN(features // 2,
                  activation='relu',
                  return_sequences=True)(input_layer)
    x = SimpleRNN(features // 2,
                  activation='relu',
                  return_sequences=True)(x)
    piano_hat = TimeDistributed(Dense(features), name='piano_hat')(x)  # source 1 branch
    noise_hat = TimeDistributed(Dense(features), name='noise_hat')(x)  # source 2 branch
    piano_pred = TimeFreqMasking(epsilon=epsilon,
                                 name='piano_pred')((piano_hat, noise_hat, input_layer))
    noise_pred = TimeFreqMasking(epsilon=epsilon,
                                 name='noise_pred')((noise_hat, piano_hat, input_layer))

    # Stack predictions and the gamma tensor along the batch axis so a single
    # output carries everything discrim_loss needs. loss_const is cast to
    # float so the constant's dtype matches the float predictions (the
    # caller passes an int, which would otherwise make an int32 tensor).
    preds_and_gamma = Concatenate(axis=0)([piano_pred,
                                           noise_pred,
                                           #  loss_const_tensor
                                           tf.broadcast_to(tf.constant(float(loss_const)), [1, sequences, features])
                                           ])
    model = Model(inputs=input_layer, outputs=preds_and_gamma)
    model.compile(optimizer=optimizer, loss=discrim_loss)
    return model


def dummy_generator(num_samples, batch_size, num_seq, num_feat):
    """Endlessly yield [x, y] random batches where y is the two targets
    concatenated along the feature axis — the packed y_true layout that
    discrim_loss expects."""
    shape = (batch_size, num_seq, num_feat)
    while True:
        for _ in range(0, num_samples, batch_size):
            x = np.random.rand(*shape)
            targets = np.concatenate((np.random.rand(*shape),
                                      np.random.rand(*shape)), axis=-1)
            yield [x, targets]


# Script driver: build the concatenated-output model and fit it on dummy data.
total_samples = 6
batch_size = 2
time_steps = 3
features = 4
loss_const = 2
epochs = 10
val_split = 0.25
epsilon = 10 ** (-10)

model = make_model(features, time_steps, epsilon, loss_const)
print(model.summary())

# Bug fixes: `actual_samples` and `val_samples` were undefined names
# (NameError); they should be total_samples and num_val respectively.
num_val = math.ceil(total_samples * val_split)
num_train = total_samples - num_val
train_dataset = dummy_generator(num_train, batch_size, time_steps, features)
val_dataset = dummy_generator(num_val, batch_size, time_steps, features)

# Bug fix: the original fit() call was missing its closing parenthesis.
model.fit(train_dataset,
          steps_per_epoch=math.ceil(num_train / batch_size),
          epochs=epochs,
          validation_data=val_dataset,
          validation_steps=math.ceil(num_val / batch_size))


解决方案:不要将您的损失传递给 model.add_loss()。相反,把各个输出连接(concatenate)在一起,这样就可以把自定义损失传递给 model.compile(),然后在自定义损失函数内部拆分并处理这些输出。

from scipy.io import wavfile
import scipy.signal as sg
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, Dense, Lambda, TimeDistributed, Layer, LSTM, Bidirectional, BatchNormalization, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.activations import relu
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import datetime
import numpy as np
import math
import random
import json
import os
import sys



# Loss function
def discriminative_loss(piano_true, noise_true, piano_pred, noise_pred, loss_const):
    last_dim = piano_pred.shape[1] * piano_pred.shape[2]
    return (
        tf.math.reduce_mean(tf.reshape(noise_pred - noise_true, shape=(-1, last_dim)) ** 2, axis=-1) - 
        (loss_const * tf.math.reduce_mean(tf.reshape(noise_pred - piano_true, shape=(-1, last_dim)) ** 2, axis=-1)) +
        tf.math.reduce_mean(tf.reshape(piano_pred - piano_true, shape=(-1, last_dim)) ** 2, axis=-1) -
        (loss_const * tf.math.reduce_mean(tf.reshape(piano_pred - noise_true, shape=(-1, last_dim)) ** 2, axis=-1))
    )



def make_model(features, sequences, name='Model'):

    input_layer = Input(shape=(sequences, features), dtype='float32', 
                        name='piano_noise_mixed')
    piano_true = Input(shape=(sequences, features), dtype='float32', 
                       name='piano_true')
    noise_true = Input(shape=(sequences, features), dtype='float32', 
                       name='noise_true')

    x = SimpleRNN(features // 2, 
                  activation='relu', 
                  return_sequences=True) (input_layer) 
    piano_pred = TimeDistributed(Dense(features), name='piano_hat') (x)  # source 1 branch
    noise_pred = TimeDistributed(Dense(features), name='noise_hat') (x)  # source 2 branch
  
    model = Model(inputs=[input_layer, piano_true, noise_true],
                  outputs=[piano_pred, noise_pred])

    return model



# Model "wrapper" for many-input loss function
class RestorationModel2(Model):
    def __init__(self, model, loss_const):
        super(RestorationModel2, self).__init__()
        self.model = model
        self.loss_const = loss_const
       
    def call(self, inputs):
        return self.model(inputs)

    def compile(self, optimizer, loss):
        super(RestorationModel2, self).compile()
        self.optimizer = optimizer
        self.loss = loss

    def train_step(self, data):
        # Unpack data - what generator yeilds
        x, piano_true, noise_true = data

        with tf.GradientTape() as tape:
            piano_pred, noise_pred = self.model((x, piano_true, noise_true), training=True)
            loss = self.loss(piano_true, noise_true, piano_pred, noise_pred, self.loss_const)

        trainable_vars = self.model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        return {'loss': loss}

    def test_step(self, data):
        x, piano_true, noise_true = data

        piano_pred, noise_pred = self.model((x, piano_true, noise_true), training=False)
        loss = self.loss(piano_true, noise_true, piano_pred, noise_pred, self.loss_const)
        
        return {'loss': loss}



def make_imp_model(features, sequences, loss_const=0.05,
                   optimizer=None,
                   name='Restoration Model', epsilon=10 ** (-10)):
    """Build and compile the subclassed ("semi-imperative") model.

    Args:
        features: feature dimension of each time step.
        sequences: number of time steps per sample.
        loss_const: weighting constant passed to the custom loss.
        optimizer: Keras optimizer; defaults to a fresh RMSprop(clipvalue=0.7)
            per call. (The original default was evaluated once at import
            time, so every call silently shared one optimizer instance and
            its slot variables.)
        name, epsilon: accepted for interface compatibility; currently unused.

    Returns:
        A compiled RestorationModel2 ready for fit().
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.7)

    # NEW Semi-imperative model
    model = RestorationModel2(make_model(features, sequences, name='Training Model'),
                              loss_const=loss_const)

    model.compile(optimizer=optimizer, loss=discriminative_loss)

    return model



# MODEL TRAIN & EVAL FUNCTION
def evaluate_source_sep(train_generator, validation_generator,
                        num_train, num_val, n_feat, n_seq, batch_size,
                        loss_const, epochs=20,
                        optimizer=None,
                        patience=10, epsilon=10 ** (-10)):
    """Train the separation model and return the Keras History.

    Fixes vs. original: the optimizer default is now created per call
    (the original default RMSprop was instantiated once at import time and
    shared across calls), and the History object is returned instead of
    being discarded.
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.75)

    print('Making model...')    # IMPERATIVE MODEL - Customize Fit
    model = make_imp_model(n_feat, n_seq, loss_const=loss_const, optimizer=optimizer, epsilon=epsilon)

    print('Going into training now...')
    hist = model.fit(train_generator,
                     steps_per_epoch=math.ceil(num_train / batch_size),
                     epochs=epochs,
                     validation_data=validation_generator,
                     validation_steps=math.ceil(num_val / batch_size),
                     callbacks=[EarlyStopping('val_loss', patience=patience, mode='min')])
    print(model.summary())
    return hist



# NEURAL NETWORK DATA GENERATOR
def my_dummy_generator(num_samples, batch_size, train_seq, train_feat):
    """Endlessly yield (x, y1, y2) placeholder batches of a fixed shape.

    Arrays come from np.empty, so their contents are uninitialized —
    only the shapes matter for this dummy data.
    """
    batch_shape = (batch_size, train_seq, train_feat)
    while True:
        for _ in range(0, num_samples, batch_size):
            yield tuple(np.empty(batch_shape) for _ in range(3))



def main():
    epsilon = 10 ** (-10)
    train_batch_size = 5
    loss_const, epochs, val_split = 0.05, 10, 0.25
    optimizer = tf.keras.optimizers.RMSprop(clipvalue=0.9)

    TRAIN_SEQ_LEN, TRAIN_FEAT_LEN = 1847, 2049
    TOTAL_SMPLS = 60 

    # Validation & Training Split
    indices = list(range(TOTAL_SMPLS))
    val_indices = indices[:math.ceil(TOTAL_SMPLS * val_split)]
    num_val = len(val_indices)
    num_train = TOTAL_SMPLS - num_val
   
    train_seq, train_feat = TRAIN_SEQ_LEN, TRAIN_FEAT_LEN
    print('Train Input Stats:')
    print('N Feat:', train_feat, 'Seq Len:', train_seq, 'Batch Size:', train_batch_size)

    # Create data generators and evaluate model with them
    train_generator = my_dummy_generator(num_train,
                        batch_size=train_batch_size, train_seq=train_seq,
                        train_feat=train_feat)
    validation_generator = my_dummy_generator(num_val,
                        batch_size=train_batch_size, train_seq=train_seq,
                        train_feat=train_feat)

    evaluate_source_sep(train_generator, validation_generator, num_train, num_val,
                            n_feat=train_feat, n_seq=train_seq, 
                            batch_size=train_batch_size, 
                            loss_const=loss_const, epochs=epochs,
                            optimizer=optimizer, epsilon=epsilon)

if __name__ == '__main__':
    main()
class TimeFreqMasking(Layer):
    # Init is for input-independent variables
    def __init__(self, epsilon, **kwargs):
        super(TimeFreqMasking, self).__init__(**kwargs)
        self.epsilon = epsilon

    # No build method, b/c passing in multiple inputs to layer (no single shape)

    def call(self, inputs):
        y_hat_self, y_hat_other, x_mixed = inputs
        mask = tf.abs(y_hat_self) / (tf.abs(y_hat_self) + tf.abs(y_hat_other) + self.epsilon)
        y_tilde_self = mask * x_mixed
        return y_tilde_self


def discrim_loss(y_true, y_pred):
    piano_true, noise_true = tf.split(y_true, num_or_size_splits=2, axis=-1)
    loss_const = y_pred[-1, :, :][0][0]
    piano_pred, noise_pred = tf.split(y_pred[:-1, :, :], num_or_size_splits=2, axis=0)

    last_dim = piano_pred.shape[1] * piano_pred.shape[2]
    return (
        tf.math.reduce_mean(tf.reshape(noise_pred - noise_true, shape=(-1, last_dim)) ** 2) - 
        (loss_const * tf.math.reduce_mean(tf.reshape(noise_pred - piano_true, shape=(-1, last_dim)) ** 2)) +
        tf.math.reduce_mean(tf.reshape(piano_pred - piano_true, shape=(-1, last_dim)) ** 2) -
        (loss_const * tf.math.reduce_mean(tf.reshape(piano_pred - noise_true, shape=(-1, last_dim)) ** 2))
    )


def make_model(features, sequences, epsilon, loss_const, optimizer=None):
    """Build and compile the masking model whose single output concatenates
    both source predictions plus a broadcast copy of the loss constant.

    Args:
        features: feature dimension per time step.
        sequences: time steps per sample.
        epsilon: small constant for the soft masks.
        loss_const: gamma weight; broadcast into the output tensor so
            discrim_loss can recover it from y_pred.
        optimizer: Keras optimizer; defaults to RMSprop. (Bug fix: the
            original read a global `optimizer` name that is never defined at
            module scope, raising NameError at model.compile.)
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.RMSprop()

    input_layer = Input(shape=(sequences, features), name='piano_noise_mixed')
    x = SimpleRNN(features // 2,
                  activation='relu',
                  return_sequences=True)(input_layer)
    x = SimpleRNN(features // 2,
                  activation='relu',
                  return_sequences=True)(x)
    piano_hat = TimeDistributed(Dense(features), name='piano_hat')(x)  # source 1 branch
    noise_hat = TimeDistributed(Dense(features), name='noise_hat')(x)  # source 2 branch
    piano_pred = TimeFreqMasking(epsilon=epsilon,
                                 name='piano_pred')((piano_hat, noise_hat, input_layer))
    noise_pred = TimeFreqMasking(epsilon=epsilon,
                                 name='noise_pred')((noise_hat, piano_hat, input_layer))

    # Stack predictions and the gamma tensor along the batch axis so a single
    # output carries everything discrim_loss needs. loss_const is cast to
    # float so the constant's dtype matches the float predictions (the
    # caller passes an int, which would otherwise make an int32 tensor).
    preds_and_gamma = Concatenate(axis=0)([piano_pred,
                                           noise_pred,
                                           #  loss_const_tensor
                                           tf.broadcast_to(tf.constant(float(loss_const)), [1, sequences, features])
                                           ])
    model = Model(inputs=input_layer, outputs=preds_and_gamma)
    model.compile(optimizer=optimizer, loss=discrim_loss)
    return model


def dummy_generator(num_samples, batch_size, num_seq, num_feat):
    """Endlessly yield [x, y] random batches where y is the two targets
    concatenated along the feature axis — the packed y_true layout that
    discrim_loss expects."""
    shape = (batch_size, num_seq, num_feat)
    while True:
        for _ in range(0, num_samples, batch_size):
            x = np.random.rand(*shape)
            targets = np.concatenate((np.random.rand(*shape),
                                      np.random.rand(*shape)), axis=-1)
            yield [x, targets]


# Script driver: build the concatenated-output model and fit it on dummy data.
total_samples = 6
batch_size = 2
time_steps = 3
features = 4
loss_const = 2
epochs = 10
val_split = 0.25
epsilon = 10 ** (-10)

model = make_model(features, time_steps, epsilon, loss_const)
print(model.summary())

# Bug fixes: `actual_samples` and `val_samples` were undefined names
# (NameError); they should be total_samples and num_val respectively.
num_val = math.ceil(total_samples * val_split)
num_train = total_samples - num_val
train_dataset = dummy_generator(num_train, batch_size, time_steps, features)
val_dataset = dummy_generator(num_val, batch_size, time_steps, features)

# Bug fix: the original fit() call was missing its closing parenthesis.
model.fit(train_dataset,
          steps_per_epoch=math.ceil(num_train / batch_size),
          epochs=epochs,
          validation_data=val_dataset,
          validation_steps=math.ceil(num_val / batch_size))