Python: 'Tensor' object has no attribute 'is_initialized' with Keras VGG16 in TensorFlow 1.14 when using mixed precision training


Let me start from the beginning. I have implemented a partial convolution layer for image inpainting in TensorFlow 1.14 (I have tested it and it works on my dataset).

The architecture uses a pretrained (ImageNet) VGG16 to compute some of the loss terms. Unfortunately, the VGG implementations available in TensorFlow did not work for me, so I merged the VGG16 from keras.applications into the TensorFlow 1.14 code.
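
Roughly, the VGG16 is built directly on a graph tensor and used as a frozen feature extractor. The sketch below only shows the idea; it is simplified, not my exact vgg.py, and the pooling layers picked for the loss are just an example:

import tensorflow as tf

def build_vgg_features(img):
    # Sketch: frozen keras.applications VGG16 used as a feature extractor
    # for the loss terms; the chosen layers are illustrative only.
    img = tf.keras.applications.vgg16.preprocess_input(img)
    vgg = tf.keras.applications.VGG16(weights='imagenet', include_top=False,
                                      input_tensor=img)
    vgg.trainable = False  # the loss network is never trained
    layer_names = ['block1_pool', 'block2_pool', 'block3_pool']
    return [vgg.get_layer(name).output for name in layer_names]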

Everything worked fine, but then I incorporated mixed precision training into the code, and the VGG16 part now throws the following error:

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
ERROR:tensorflow:==================================
Object was never used (type <class 'tensorflow.python.framework.ops.Tensor'>):
<tf.Tensor 'VGG16/model/IsVariableInitialized_3:0' shape=() dtype=bool>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
File "main.py", line 131, in <module>
psi_gt, psi_out, psi_comp, I_comp, layers = model.build_vgg(data_gt, unet_pconv, data_mask)
File "/workspace/model.py", line 52, in build_vgg
vgg = vgg16.VGG16(image_shape=gt.shape, input_tensor=gt)
File "/workspace/vgg.py", line 17, in __init__
self._build_graph(input_tensor)
File "/workspace/vgg.py", line 35, in _build_graph
self.vgg16 = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_tensor=img)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/__init__.py", line 70, in wrapper
return base_fun(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/vgg16.py", line 32, in VGG16
return vgg16.VGG16(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras_applications/vgg16.py", line 210, in VGG16
model.load_weights(weights_path)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 162, in load_weights
return super(Model, self).load_weights(filepath, by_name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py", line 1424, in load_weights
saving.load_weights_from_hdf5_group(f, self.layers)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py", line 759, in load_weights_from_hdf5_group
K.batch_set_value(weight_value_tuples)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 3071, in batch_set_value
get_session().run(assign_ops, feed_dict=feed_dict)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 462, in get_session
_initialize_variables(session)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 879, in _initialize_variables
[variables_module.is_variable_initialized(v) for v in candidate_vars])
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 879, in <listcomp>
[variables_module.is_variable_initialized(v) for v in candidate_vars])
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_should_use.py", line 193, in wrapped
return _add_should_use_warning(fn(*args, **kwargs))
==================================
(The same "Object was never used" warning is printed three more times, for VGG16/model/IsVariableInitialized_2, IsVariableInitialized_1 and IsVariableInitialized, each with an identical stack trace.)
Traceback (most recent call last):
File "main.py", line 131, in <module>
psi_gt, psi_out, psi_comp, I_comp, layers = model.build_vgg(data_gt, unet_pconv, data_mask)
File "/workspace/model.py", line 52, in build_vgg
vgg = vgg16.VGG16(image_shape=gt.shape, input_tensor=gt)
File "/workspace/vgg.py", line 17, in __init__
self._build_graph(input_tensor)
File "/workspace/vgg.py", line 35, in _build_graph
self.vgg16 = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_tensor=img)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/__init__.py", line 70, in wrapper
return base_fun(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/vgg16.py", line 32, in VGG16
return vgg16.VGG16(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras_applications/vgg16.py", line 210, in VGG16
model.load_weights(weights_path)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 162, in load_weights
return super(Model, self).load_weights(filepath, by_name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py", line 1424, in load_weights
saving.load_weights_from_hdf5_group(f, self.layers)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py", line 759, in load_weights_from_hdf5_group
K.batch_set_value(weight_value_tuples)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 3071, in batch_set_value
get_session().run(assign_ops, feed_dict=feed_dict)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 462, in get_session
_initialize_variables(session)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 879, in _initialize_variables
[variables_module.is_variable_initialized(v) for v in candidate_vars])
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py", line 879, in <listcomp>
[variables_module.is_variable_initialized(v) for v in candidate_vars])
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_should_use.py", line 193, in wrapped
return _add_should_use_warning(fn(*args, **kwargs))
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py", line 3083, in is_variable_initialized
return state_ops.is_variable_initialized(variable)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py", line 133, in is_variable_initialized
return ref.is_initialized(name=name)
AttributeError: 'Tensor' object has no attribute 'is_initialized'
The main script where all of this is used:

import tensorflow as tf
import PConv
import model
import layers
import math
import os
import data
import utils
import numpy as np
import datetime

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Mixed precision training variable storage
def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True, *args, **kwargs):
    # Keep trainable variables stored in float32 so the optimizer updates
    # full-precision master weights even when a layer requests float16.
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable, *args, **kwargs)
    # Hand back a float16 view of the variable when the layer asked for float16.
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable


# ==============================================================================
#                                   SETTINGS
# ==============================================================================

path_ =''

batch_size = 16
best_val = math.inf
best_val_epoch = 0
patience = 0
stop = 300
epochs = 2000
steps_train = 25
steps_val = 8 

template = '{}, Epoch {}, train_loss: {:.4f} - val_loss: {:.4f}'

path = path_ + 'tmp/'
if not os.path.isdir(path):
   os.mkdir(path)

# ==============================================================================
#                                       DATA
# ==============================================================================

X_train, m_train, y_train = data.get_filenames()
X_val, m_val, y_val = data.get_filenames(train=False)

# ==============================================================================
#                                     DATASET
# ==============================================================================

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, m_train, y_train))#(images, mask, gt))

train_dataset = train_dataset.map(data.load, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

val_dataset = tf.data.Dataset.from_tensor_slices((X_val, m_val, y_val))#(images, mask, gt))

val_dataset = val_dataset.map(data.load, num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_dataset = val_dataset.batch(batch_size)
val_dataset = val_dataset.prefetch(buffer_size=1)

iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                       train_dataset.output_shapes)


data_im, data_mask, data_gt = iterator.get_next()

# create the initialization operations
train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)

# ==============================================================================
#                                     MODEL
# ==============================================================================

data_im = tf.cast(data_im, tf.float16)
data_mask = tf.cast(data_mask, tf.float16)

with tf.variable_scope('fp32_vars', custom_getter=float32_variable_storage_getter):
    unet_pconv = model.pconv_unet(data_im, data_mask)

unet_pconv = tf.cast(unet_pconv, tf.float32)
data_mask = tf.cast(data_mask, tf.float32)

psi_gt, psi_out, psi_comp, I_comp, layers = model.build_vgg(data_gt, unet_pconv, data_mask)

I_comp = tf.cast(I_comp, tf.float32)


# # ==============================================================================
# #                                     LOSS
# # ==============================================================================
loss = utils.get_total_loss(unet_pconv, data_gt, data_mask, psi_gt, psi_out, psi_comp, I_comp, layers)
lr = 0.0002
optimizer = utils.optimize(loss, lr)

saver = tf.train.Saver()

# # ==============================================================================
# #                                  TRAINING
# # ==============================================================================

output_summary = tf.summary.image(name='output', tensor=unet_pconv)
merged = tf.summary.merge_all()

with tf.Session() as sess:

    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    writer = tf.summary.FileWriter('graphs',sess.graph)

    train_loss_, val_loss_ = [], []

    for epoch in range(epochs):
        pred_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        tl, vl = [], []

        #Initialize iterator with training data
        sess.run(train_init_op)
        try:
            for step in range (steps_train):
                _, train_loss, summ = sess.run([optimizer, loss, merged])
                writer.add_summary(summ, epoch)
                tl.append(train_loss)
            mean_train = utils.list_mean(tl)
            train_loss_.append(mean_train)

        except tf.errors.OutOfRangeError:
            pass

        if (epoch+1) % 1 == 0:
            sess.run(val_init_op)
            try:
                for step in range (steps_val):
                    val_loss = sess.run([loss])
                    vl.append(val_loss)
                mean_val = utils.list_mean(vl)
                val_loss_.append(mean_val)

            except tf.errors.OutOfRangeError:
                pass


        print(template.format(pred_time, epoch, mean_train, mean_val))

        # early stopping
        if mean_val < best_val:
           print('Saving on epoch {0}'.format(epoch))
           best_val = mean_val
           patience = 0

           best_val_epoch = epoch
           saver.save(sess, path+'best_model')
        else:
           patience += 1

           if patience == stop:
               print('Early stopping at epoch: {}'.format(best_val_epoch))
               break

# # ==============================================================================
# #                                  SAVE CURVES
# # ==============================================================================

np.save(path_+'loss.npy', train_loss_)
np.save(path_+'val_loss.npy', val_loss_)
I have been trying to solve this for a while now and still do not understand why it stops working as soon as I enable mixed precision training. Feel free to ask for more details.


Any help would be greatly appreciated! Thanks in advance.

I have tried a lot of things, and my final conclusion is that the pretrained Keras model is simply not compatible here. I switched to a TensorFlow VGG16 model instead; it runs more slowly, but at least it works.
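
For reference, the TensorFlow-side VGG16 can be built with TF-Slim. The sketch below only illustrates that approach; the function name, checkpoint path and choice of pooling layers are placeholders rather than my actual code, and it assumes 224x224 inputs and a downloaded vgg_16.ckpt:

import tensorflow as tf
from tensorflow.contrib.slim.nets import vgg

slim = tf.contrib.slim

def build_tf_vgg16(images, ckpt_path='vgg_16.ckpt'):
    # Sketch: TF-Slim VGG16 used as a frozen loss network.
    # images: float32, shape (batch, 224, 224, 3), VGG-style preprocessed.
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _, end_points = vgg.vgg_16(images, is_training=False)
    # Pooling feature maps used for the loss terms (illustrative choice)
    feature_layers = [end_points['vgg_16/pool1'],
                      end_points['vgg_16/pool2'],
                      end_points['vgg_16/pool3']]
    # Restore the pretrained ImageNet weights from the checkpoint
    init_fn = slim.assign_from_checkpoint_fn(
        ckpt_path, slim.get_variables_to_restore(include=['vgg_16']))
    return feature_layers, init_fn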

The optimize function used above (from utils.py):

def optimize(loss, learning_rate=1e-4):
    # Only the U-Net variables are trained; the VGG16 loss network stays frozen.
    U_vars = [var for var in tf.trainable_variables() if 'UNET' in var.name]

    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # Rewrite the graph for mixed precision with a fixed loss scale of 128
    # to keep small float16 gradients from underflowing.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale=128.0)

    train_opt = opt.minimize(loss, var_list=U_vars)

    return train_opt