Python 3.x 当Tensorflow(域自适应)中有自定义对象作为梯度反转层时,在负载模型中放置什么

Python 3.x 当Tensorflow(域自适应)中有自定义对象作为梯度反转层时,在负载模型中放置什么,python-3.x,tensorflow,keras,deep-learning,transfer-learning,Python 3.x,Tensorflow,Keras,Deep Learning,Transfer Learning,这里是域自适应模型的示例代码,我只想保存模型并加载它 @tf.custom_gradient def grad_reverse(x): y = tf.identity(x) def custom_grad(dy): return -dy return y, custom_grad class GradReverse(tf.keras.layers.Layer): def __init__(self): super().__init

这里是域自适应模型的示例代码,我只想保存模型并加载它

@tf.custom_gradient
def grad_reverse(x):
    y = tf.identity(x)
    def custom_grad(dy):
        return -dy
    return y, custom_grad

class GradReverse(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__(name="grl")

    def call(self, x):
        return grad_reverse(x)


def get_adaptable_network(input_shape=x_source_train.shape[1:]):
    
    inputs = Input(shape=input_shape)
    x = Conv2D(32, 5, padding='same', activation='relu', name='conv2d_1')(inputs)
    x = MaxPool2D(pool_size=2, strides=2, name='max_pooling2d_1')(x)
    x = Conv2D(48, 5, padding='same', activation='relu', name='conv2d_2')(x)
    x = MaxPool2D(pool_size=2, strides=2, name='max_pooling2d_2')(x)
    features = Flatten(name='flatten_1')(x)
    x = Dense(100, activation='relu', name='dense_digits_1')(features)
    x = Dense(100, activation='relu', name='dense_digits_2')(x)
    digits_classifier = Dense(10, activation="softmax", name="digits_classifier")(x)

    domain_branch = Dense(100, activation="relu", name="dense_domain")(GradReverse()(features))
    domain_classifier = Dense(1, activation="sigmoid", name="domain_classifier")(domain_branch)

    return Model(inputs=inputs, outputs=[digits_classifier, domain_classifier])

model = get_adaptable_network()
model.summary()

# download the model in computer for later use
model.save('DA_MNIST_to_MNIST_m.h5')

from tensorflow import keras
model = keras.models.load_model('DA_MNIST_to_MNIST_m.h5',custom_objects={'?':? })
由于tensorflow中有一个用于域自适应的自定义渐变反转层,所以我不确定要在自定义对象部分添加什么。当我加载模型时,它会给出一个错误:

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    294   cls = get_registered_object(class_name, custom_objects, module_objects)
    295   if cls is None:
--> 296     raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    297 
    298   cls_config = config['config']

ValueError: Unknown layer: GradReverse

我正在进行MNIST到MNIST\M域的自适应,任何帮助都会很有用

我明白了,我需要用**kwargs更改GradReverse层的init函数,然后这个对象将接受我没有包含的任何其他关键字参数

class GradReverse(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(name="grl")

    def call(self, x):
        return grad_reverse(x)
在负载模型中,我们可以使用

from tensorflow import keras
model = keras.models.load_model('DA_MNIST_to_MNIST_m.h5',custom_objects={'GradReverse':GradReverse})