How do I turn off automatic mixed precision during validation?

I am trying to run the MAC network () with automatic mixed precision enabled:

def addOptimizerOp(self): 
    with tf.variable_scope("trainAddOptimizer"):
        self.globalStep = tf.Variable(0, dtype = tf.int32, trainable = False, name = "globalStep") # init to 0 every run?
        optimizer = tf.train.AdamOptimizer(learning_rate = self.lr)
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
        if config.subsetOpt:
            self.subsetOptimizer = tf.train.AdamOptimizer(learning_rate = self.lr * config.subsetOptMult)

    return optimizer
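
For reference, as far as I can tell enable_mixed_precision_graph_rewrite flips a process-wide Grappler toggle rather than marking only the training ops, which would match the log below, where the rewriter also runs on the evaluation graph. The only other control I know of is the rewriter option in the session config; here is a minimal sketch of switching the pass off entirely (assuming TF 1.14+, where RewriterConfig.auto_mixed_precision exists), though it is a per-session switch, not a train-vs-validation one:

import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

# Assumption: TF 1.14+, where RewriterConfig carries the auto_mixed_precision
# toggle. OFF disables the Grappler rewrite for every graph run in a session
# created with this config; it cannot distinguish training from validation.
sessionConfig = tf.ConfigProto()
sessionConfig.graph_options.rewrite_options.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.OFF

# sess = tf.Session(config = sessionConfig)
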
Training runs fine for the first epoch. However, when the model runs evaluation on the validation set, I get this error:

Training epoch 1...
2019-08-05 14:51:13.625899: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:13.709959: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1723] Converted 1504/6920 nodes to float16 precision using 150 cast(s) to float16 (excluding Const and Variable casts)
2019-08-05 14:51:16.930248: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
2019-08-05 14:51:17.331687: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudnn.so.7
2019-08-05 14:51:29.378905: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:29.380633: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1241] No whitelist ops found, nothing to do
eb  1, 10000,(160010 / 943000), t = 0.12 (0.00+0.11), lr 0.0003, l = 2.8493, a = 0.4250, avL = 2.5323, avA = 0.4188, g = 3.7617, emL = 2.3097, emA = 0.4119; gqaExperiment
Restoring EMA weights
2019-08-05 14:51:31.132804: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:31.136122: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1241] No whitelist ops found, nothing to do
2019-08-05 14:51:32.322369: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:32.341609: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1723] Converted 661/1848 nodes to float16 precision using 38 cast(s) to float16 (excluding Const and Variable casts)
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: TensorArray dtype is float but Op is trying to write dtype half.
     [[{{node macModel/tower0/encoder/birnnLayer/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3}}]]
     [[macModel/tower0/MACnetwork/MACCell_3/write/inter2attselfAttention/Softmax/_1661]]
  (1) Invalid argument: TensorArray dtype is float but Op is trying to write dtype half.
     [[{{node macModel/tower0/encoder/birnnLayer/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 866, in <module>
    main()
  File "main.py", line 728, in main
    evalRes = runEvaluation(sess, model, data["main"], dataOps, epoch, getPreds = getPreds, prevRes = evalRes)
  File "main.py", line 248, in runEvaluation
    minLoss = prevRes["train"]["minLoss"] if prevRes else float("inf"))
  File "main.py", line 594, in runEpoch
    res = model.runBatch(sess, batch, imagesBatch, train, getPreds, getAtt)
  File "/content/model.py", line 948, in runBatch
    feed_dict = feed)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)

Here is the main function; training and evaluation share the same graph and session:

def main():
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    # set gpus
    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus
    tf.logging.set_verbosity(tf.logging.ERROR)

    # process data
    print(bold("Preprocess data..."))
    start = time.time()
    preprocessor = Preprocesser()
    data, embeddings, answerDict, questionDict = preprocessor.preprocessData()
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    nextElement = None
    dataOps = None

    # build model
    print(bold("Building model..."))
    start = time.time()
    model = MACnet(embeddings, answerDict, questionDict, nextElement)
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # initializer
    init = tf.global_variables_initializer()

    # savers
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]

    # sessionConfig
    sessionConfig = setSession()

    with tf.Session(config = sessionConfig) as sess:

        # ensure no more ops are added after model is built
        sess.graph.finalize()

        # restore / initialize weights, initialize epoch variable
        epoch = loadWeights(sess, saver, init)

        trainRes, evalRes = None, None

        if config.train:
            start0 = time.time()

            bestEpoch = epoch
            bestRes = None
            prevRes = None

            # epoch in [restored + 1, epochs]
            for epoch in range(config.restoreEpoch + 1, config.epochs + 1):
                print(bcolored("Training epoch {}...".format(epoch), "green"))
                start = time.time()

                # train
                # calle = lambda: model.runEpoch(), collectRuntimeStats, writer
                trainingData, alterData = chooseTrainingData(data)
                trainRes = runEpoch(sess, model, trainingData, dataOps, train = True, epoch = epoch, 
                    saver = saver, alterData = alterData, 
                    maxAcc = trainRes["maxAcc"] if trainRes else 0.0,
                    minLoss = trainRes["minLoss"] if trainRes else float("inf"),)

                # save weights
                saver.save(sess, config.weightsFile(epoch))
                if config.saveSubset:
                    subsetSaver.save(sess, config.subsetWeightsFile(epoch))                   

                # load EMA weights 
                if config.useEMA:
                    print(bold("Restoring EMA weights"))
                    emaSaver.restore(sess, config.weightsFile(epoch))

                # evaluation  
                getPreds = config.getPreds or (config.analysisType != "")

                evalRes = runEvaluation(sess, model, data["main"], dataOps, epoch, getPreds = getPreds, prevRes = evalRes)
                extraEvalRes = runEvaluation(sess, model, data["extra"], dataOps, epoch, 
                    evalTrain = not config.extraVal, getPreds = getPreds)

                # restore standard weights
                if config.useEMA:
                    print(bold("Restoring standard weights"))
                    saver.restore(sess, config.weightsFile(epoch))

                print("")

                epochTime = time.time() - start
                print("took {:.2f} seconds".format(epochTime))

                # print results
                printDatasetResults(trainRes, evalRes, extraEvalRes)
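
Is there a way to keep the mixed precision rewrite for the training pass but run validation in plain float32?

A possible direction (sketch only, untested) would be to restore the EMA weights into a second session whose config explicitly sets the rewriter option mentioned above to OFF, and run the evaluation there. I am not sure whether an explicit OFF in the session config actually overrides the global enable_mixed_precision_graph_rewrite() call; setSession, emaSaver and runEvaluation are the helpers already used above.

from tensorflow.core.protobuf import rewriter_config_pb2

# Untested sketch: a separate evaluation session with the Grappler
# auto-mixed-precision pass explicitly switched off, so the training session
# keeps the float16 rewrite while validation stays float32 end to end.
evalSessionConfig = setSession()
evalSessionConfig.graph_options.rewrite_options.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.OFF

with tf.Session(config = evalSessionConfig) as evalSess:
    print(bold("Restoring EMA weights"))
    emaSaver.restore(evalSess, config.weightsFile(epoch))
    evalRes = runEvaluation(evalSess, model, data["main"], dataOps, epoch,
        getPreds = getPreds, prevRes = evalRes)
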