Python 自定义仅用于培训过程的图层

Python 自定义仅用于培训过程的图层,python,tensorflow2.0,keras-layer,Python,Tensorflow2.0,Keras Layer,环境:Tensorflow2.3.0蟒蛇3.6 我正在尝试为图像增强的训练过程定制一个层。这是我的密码: class RandomLight(layers.Layer): def __init__(self, factor=0.2): super(RandomLight,self).__init__() self.factor = factor def call(self, input, training=None): return tf.cond(training,

环境:Tensorflow2.3.0蟒蛇3.6

我正在尝试为图像增强的训练过程定制一个层。这是我的密码:

class RandomLight(layers.Layer):
def __init__(self, factor=0.2):
    super(RandomLight,self).__init__()
    self.factor = factor

def call(self, input, training=None):
    return tf.cond(training,
                  lambda: tf.clip_by_value(tf.image.random_brightness(input,self.factor),0,1),
                  lambda: input)
当我要把它放到网络中时:

import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.applications import VGG16

inputs = keras.Input(shape=(224,224,3))
vgg16 = VGG16(include_top=False, weights='imagenet',input_shape=(224,224,3))
data_augmentation = keras.Sequential(
[
    layers.experimental.preprocessing.RandomRotation(0.25),
    layers.experimental.preprocessing.RandomFlip(),
    RandomLight()
])
i1 = data_augmentation(inputs)
bn = layers.BatchNormalization()(i1)
x = vgg16(bn)
flat_out = layers.Flatten()(x)
h1 = layers.Dense(1024,activation='relu',name='fc1')(flat_out)
h2 = layers.Dropout(0.5)(h1)
h3 = layers.Dense(32,activation='relu',name='fc2')(h2)
h4 = layers.Dropout(0.5)(h3)
new_out = layers.Dense(1,activation='sigmoid',name='prediction')(h4)
vgg_ft = keras.Model(inputs,new_out)
似乎出现了“培训=无”的错误

ValueError回溯(最近一次调用)
在()
---->1输入=数据扩充(输入)
2个输入=随机光(输入)
3 bn=layers.BatchNormalization()(输入)
4 x=vgg16(bn)
5展平=层。展平()(x)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\keras\engine\base\u layer.py in\uuuuuuu调用(self,*args,**kwargs)
924如果处于功能构建模式(自身、输入、参数、kwargs、输入列表):
925返回自功能构造调用(输入、参数、kwargs、,
-->926输入(U列表)
927
928#维护有关“Layer.call”堆栈的信息。
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\keras\engine\base\u layer.py in\u functional\u construction\u call(self、inputs、args、kwargs、input\u list)
1115尝试:
1116带操作。启用自动转换变量(自计算类型对象):
->1117输出=调用fn(转换输入,*args,**kwargs)
1118
1119除错误外。运算符不允许作为e的Ingrapherror:
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\autograph\impl\api.py包装中(*args,**kwargs)
256例外情况为e:#pylint:disable=broad except
257如果hasattr(即“ag\u错误\u元数据”):
-->258将e.ag\u错误\u元数据提升到\u异常(e)
259其他:
260加薪
ValueError:在用户代码中:
:11电话*
lambda:输入)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\util\dispatch.py:201包装器**
返回目标(*args,**kwargs)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\ops\control\u flow\u ops.py:1396 cond\u for\u tf\u v2
返回条件(pred,true\u fn=true\u fn,false\u fn=false\u fn,strict=true,name=name)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\util\dispatch.py:201包装器
返回目标(*args,**kwargs)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\util\deprecation.py:507 new\u func
返回函数(*args,**kwargs)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\ops\control\u flow\u ops.py:1180 cond
返回cond_v2.cond_v2(pred,true,false,name)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\ops\cond\u v2.py:74 cond\u v2
pred=运算。将_转换为_张量(pred)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\framework\ops.py:1499 convert\u to\u tensor
ret=conversion\u func(值,dtype=dtype,name=name,as\u ref=as\u ref)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\framework\constant\u op.py:338\u constant\u tensor\u conversion\u函数
返回常量(v,dtype=dtype,name=name)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\framework\constant\u op.py:264 constant
允许(广播=真)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\framework\constant\u op.py:282\u constant\u impl
允许广播=允许广播)
F:\Anaconda3\envs\tf\lib\site packages\tensorflow\python\framework\tensor\u util.py:444 make\u tensor\u proto
raise VALUERROR(“不支持无值”)
ValueError:不支持任何值。
我也尝试过
training=False
,但也不起作用


似乎
Sequential()
与我的自定义层配合得很好,但是我如何在我的格式中使用它

您是否从这些层创建了
模型
,或者您只是直接调用它们?不,它出现在我要创建keras.Model之前。您可以包括用于创建模型的代码吗?vgg_ft=keras.Model(输入,新输出)不确定为什么要使用
None
。但如果使用
False
,则需要使用
tf将其包装。将\u转换为\u张量
ValueError                                Traceback (most recent call last)
<ipython-input-290-966a2fabc71b> in <module>()
----> 1 inputs = data_augmentation(inputs)
  2 inputs = randomLight(inputs)
  3 bn = layers.BatchNormalization()(inputs)
  4 x = vgg16(bn)
  5 flat_out = layers.Flatten()(x)

F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    924     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
    925       return self._functional_construction_call(inputs, args, kwargs,
--> 926                                                 input_list)
    927 
    928     # Maintains info about the `Layer.call` stack.

F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1115           try:
   1116             with ops.enable_auto_cast_variables(self._compute_dtype_object):
-> 1117               outputs = call_fn(cast_inputs, *args, **kwargs)
   1118 
   1119           except errors.OperatorNotAllowedInGraphError as e:

F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py in wrapper(*args, **kwargs)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e, 'ag_error_metadata'):
--> 258           raise e.ag_error_metadata.to_exception(e)
    259         else:
    260           raise

ValueError: in user code:
<ipython-input-278-87ec004f05b3>:11 call  *
    lambda: input)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper  **
    return target(*args, **kwargs)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\control_flow_ops.py:1396 cond_for_tf_v2
    return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
    return target(*args, **kwargs)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\util\deprecation.py:507 new_func
    return func(*args, **kwargs)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\control_flow_ops.py:1180 cond
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\cond_v2.py:74 cond_v2
    pred = ops.convert_to_tensor(pred)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\ops.py:1499 convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\constant_op.py:338 _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\constant_op.py:264 constant
    allow_broadcast=True)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\constant_op.py:282 _constant_impl
    allow_broadcast=allow_broadcast))
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\tensor_util.py:444 make_tensor_proto
    raise ValueError("None values not supported.")

ValueError: None values not supported.