Python 列车检查点是否正在恢复?

Python 列车检查点是否正在恢复?,python,tensorflow,machine-learning,keras,deep-learning,Python,Tensorflow,Machine Learning,Keras,Deep Learning,我正在colab上运行tensorflow 2.4。我尝试使用tf.train.Checkpoint()保存模型,因为它包含模型子类化,但在恢复后,我发现它没有恢复模型的任何权重 以下是几个片段: ### From tensorflow tutorial nmt_with_attention class Encoder(tf.keras.Model): def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):

我正在colab上运行tensorflow 2.4。我尝试使用
tf.train.Checkpoint()
保存模型,因为它包含模型子类化,但在恢复后,我发现它没有恢复模型的任何权重

以下是几个片段:

### From tensorflow tutorial nmt_with_attention
class Encoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    ...
    self.gru = tf.keras.layers.GRU(self.enc_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')

.
.
.

class NMT_Train(tf.keras.Model):
  def __init__(self, inp_vocab_size, tar_vocab_size, max_length_inp, max_length_tar, emb_dims, units, batch_size, source_tokenizer, target_tokenizer):
    super(NMT_Train, self).__init__()
    self.encoder = Encoder(inp_vocab_size, emb_dims, units, batch_size)
    ...

.
.
.

model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))
model.fit(dataset, epochs=2)

checkpoint = tf.train.Checkpoint(model = model)
manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)
manager.save()

model.encoder.gru.get_weights() ### get the output
##[array([[-0.0627057 ,  0.05900152,  0.06614069, ...

model.optimizer.get_weights() ### get the output
##[90, array([[ 6.6851695e-05, -4.6736805e-06, -2.3183979e-05, ...
当我后来修复它时,我没有得到任何gru重量:

model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))

checkpoint = tf.train.Checkpoint(model = model)
manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)

manager.restore_or_initialize()

model.encoder.gru.get_weights() ### empty list
## []

model.optimizer.get_weights() ### empty list
## []
我还尝试了
checkpoint.restore(manager.latest\u checkpoint)
,但没有任何更改


我做错什么了吗??或者建议任何其他方法来保存模型,以便我可以对其进行重新训练,以备将来使用。

您正在定义一个keras模型,那么为什么不使用keras模型检查点呢

发件人:

model.compile(损失=…,优化器=。。。,
指标=[‘准确度’])
纪元=10
检查点\文件路径='/tmp/checkpoint'
model\u checkpoint\u callback=tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint\u filepath,
仅保存权重=真,
监视器='val_精度',
mode='max',
保存(仅限最佳值=真)
#模型权重在每个纪元结束时保存,如果它是最好看到的
#到目前为止。
fit(epochs=epochs,callbacks=[model\u checkpoint\u callback])
#将模型权重(被认为是最佳的)加载到模型中。
模型加载权重(检查点文件路径)

嘿,我试过了,但它仍然向我展示了同样的东西:(。这也保存了优化器,因为我想再次训练它,让它进入下一个时代。我的错,一切都很好。可能有一些修复延迟。当我在虚拟句子上调用模型,然后检查权重时,它显示了我的正确内容。谢谢你的回复。