Tensorflow GAN的全球步骤

Tensorflow GAN的全球步骤,tensorflow,generative-adversarial-network,Tensorflow,Generative Adversarial Network,我正在用分布式张量流训练一个GAN 在典型的GAN中,我们将交替优化G和D,而在分布式训练中,全局步骤将同时传递给G优化器和D优化器,然后全局步骤将在一次迭代中计数两次,但实际上我们只希望计数一次 以下是我代码的一部分: # build model.. model = GAN(FLAGS) # get G and D optimizer g_opt = tf.train.AdamOptimizer(FLAGS.lr, FLAGS.beta1, FLAGS.beta2,

我正在用分布式张量流训练一个GAN

在典型的GAN中,我们将交替优化G和D,而在分布式训练中,全局步骤将同时传递给G优化器和D优化器,然后全局步骤将在一次迭代中计数两次,但实际上我们只希望计数一次

以下是我代码的一部分:

# build model..
model = GAN(FLAGS)
# get G and D optimizer
g_opt = tf.train.AdamOptimizer(FLAGS.lr, FLAGS.beta1, FLAGS.beta2,
                                   name="g_Adam")
d_opt = tf.train.AdamOptimizer(FLAGS.lr, FLAGS.beta1, FLAGS.beta2,
                                   name="d_Adam")
global_step = tf.train.get_or_create_global_step()
# my question will focus on passing global step into optimizer.minimize()
g_train_opt = g_opt.minimize(model.g_loss, global_step, model.g_vars)
d_train_opt = d_opt.minimize(model.d_loss, global_step, model.d_vars)
# get hooks...
hooks = [stop_hook, ckpt_hook, sum_hook]

if is_chief:
    session_creator = tf.train.ChiefSessionCreator(
        master=server.target, config=sess_config)
else:
    session_creator = tf.train.WorkerSessionCreator(
        master=server.target, config=sess_config)
    hooks = None

with tf.train.MonitoredSession(session_creator, hooks=hooks) as mon_sess:

    # training
    local_step = 0
    try: 
        while not mon_sess.should_stop():
            t0 = time.time()
            gstep = mon_sess.run(global_step)
            # G and D will be trained alternately
            mon_sess.run(g_train_opt)
            mon_sess.run(d_train_opt)
            delta = time.time() - t0
            g_loss, d_loss = mon_sess.run([model.g_loss, model_d_loss])
            local_step += 1
            print("delta: {}".format(delta))
            print("global_step: {}, local_step: {}".format(gstep, local_step))
            print("g_loss: {} d_loss: {}".format(g_loss, d_loss))
            print("#" * 50)
    except KeyboardInterrupt:
        print("Interrupted")
因此,全局步长将在G和D上计数两次。

可能相关:。您可以为G和D定义伪全局_步骤变量,并按照代码中的真实全局_步骤进行更新。