Tensorflow 如何创建具有子模型条件评估的组合tf.keras模型

Tensorflow 如何创建具有子模型条件评估的组合tf.keras模型,tensorflow,keras,tf.keras,Tensorflow,Keras,Tf.keras,我想创建多个tf.keras.Sequential模型的组合,以便在任何给定的时间点只对其中一个子模型进行评估。为了更好地解释,我创建了以下模型(模型的代码在本文末尾): 在此图中,sequential、sequential_1、sequential_2和sequential_3模型是基于LSTM的子模型,label_0是一个简单的第五个子模型。最后一层arbiter根据输入数据中的值(从\u arb中的提取)决定五条并行路径中的哪一条将实际提供网络输出。其他四个值将被丢弃 当然,其他四个并行

我想创建多个
tf.keras.Sequential
模型的组合,以便在任何给定的时间点只对其中一个子模型进行评估。为了更好地解释,我创建了以下模型(模型的代码在本文末尾):

在此图中,
sequential
sequential_1
sequential_2
sequential_3
模型是基于LSTM的子模型,
label_0
是一个简单的第五个子模型。最后一层
arbiter
根据输入数据中的值(从\u arb中的
提取)决定五条并行路径中的哪一条将实际提供网络输出。其他四个值将被丢弃

当然,其他四个并行层(对结果没有贡献)中的计算是浪费的。所以我的问题是:有没有办法在TensorFlow内部解决这个问题,例如,用某种条件图路由而不是并行执行

模型的示例代码:

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)

batch_size = 1

def gen_lstm(base_label, num_features, num_units):
    return tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1, num_features), batch_size=batch_size,
                                   name="input_{}".format(base_label)),
        tf.keras.layers.LSTM(num_units,
                             batch_input_shape=(batch_size, 1, num_features),
                             return_sequences=False, stateful=True,
                             name="lstm_{}".format(base_label)),
        tf.keras.layers.Dense(1, name="dense_{}".format(base_label)), # binary
        tf.keras.layers.Activation('sigmoid', name="activ_{}".format(base_label)), # binary
    ])

models = {}
for l in [1, 3, 4, 5]:
    global m
    m = gen_lstm(l, 130, 88)
    models[l] = m

in_all = tf.keras.layers.InputLayer(input_shape=(1, 132), batch_size=batch_size, name="input_all")
in_lstm = tf.keras.layers.Lambda(lambda x: tf.slice(x, [0, 0, 2], [-1, -1, -1]), name="in_lstm")(in_all.output)

def model_0(x):
    return x[:, :, 1]

out_model_0 = tf.keras.layers.Lambda(model_0, name="label_0")(in_all.output)

out_concat = tf.keras.layers.Concatenate(axis=1, name="concat_infer")([out_model_0] + [m(in_lstm) for m in models.values()])

in_arb = tf.keras.layers.Lambda(lambda x: tf.reshape(tf.slice(x, [0, 0, 0], [-1, -1, 1]), (batch_size, 1)), name="in_arb")(in_all.output)
out_merged = tf.keras.layers.Concatenate(axis=1, name="concat_arb")([in_arb, out_concat])

def arbiter(x):
    return tf.where(tf.equal(x[:, 0], tf.constant(0.0, dtype=tf.float32)), x[:, 1], tf.where(
        tf.equal(x[:, 0], tf.constant(2.0, dtype=tf.float32)), x[:, 2] + tf.constant(2.0), tf.where(
            tf.equal(x[:, 0], tf.constant(3.0, dtype=tf.float32)), x[:, 3] + tf.constant(3.0), tf.where(
                tf.equal(x[:, 0], tf.constant(4.0, dtype=tf.float32)), x[:, 4] + tf.constant(4.0),
                x[:, 5] + tf.constant(5.0)))))

out_merged = tf.keras.layers.Lambda(arbiter, name="arbiter")(out_merged)

lstm_model = tf.keras.Model([in_all.input], out_merged)

print(lstm_model.summary())
tf.keras.utils.plot_model(lstm_model, to_file="./temp.png", show_shapes=True)