Python 如何在Tensorflow 2培训期间更新SGD动量?

Python 如何在Tensorflow 2培训期间更新SGD动量?,python,tensorflow,tensorflow2.0,tensorflow2.x,Python,Tensorflow,Tensorflow2.0,Tensorflow2.x,在Tensorflow 2中,您可以在培训开始之前为SGD优化器设置动量。我想在训练期间为自定义循环中的每个历元更新它。从这里考虑代码: 我想更新动量,比如#########和##这里#之间的部分 for epoch in range(epochs): print("\nStart of epoch %d" % (epoch,)) # Iterate over the batches of the dataset. for step, (x_batc

在Tensorflow 2中,您可以在培训开始之前为SGD优化器设置动量。我想在训练期间为自定义循环中的每个历元更新它。从这里考虑代码:

我想更新动量,比如
#########
##这里#
之间的部分

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        ### HERE ###
        # here we update momentum directly or through some method
        optimizer.momentum = epoch/epochs 
        ### HERE ###

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))