Python Tensorflow目标检测急切\u少量\u快照\u od\u训练\u tf2分类头部错误

Python Tensorflow目标检测急切\u少量\u快照\u od\u训练\u tf2分类头部错误,python,tensorflow,object-detection-api,Python,Tensorflow,Object Detection Api,我正试图在一个有三个类的自定义数据集上训练对象检测模型,就像他们在这里所做的那样,但是他们没有恢复模型的分类头,因为他们有一个有一个类的数据集,但是,他们建议取消注释一行,正如您在这个代码片段中看到的那样。但当我取消注释时,会得到一个错误。谢谢你的帮助 # Set up object-based checkpoint restore --- RetinaNet has two prediction # `heads` --- one for classification, the other f

我正试图在一个有三个类的自定义数据集上训练对象检测模型,就像他们在这里所做的那样,但是他们没有恢复模型的分类头,因为他们有一个有一个类的数据集,但是,他们建议取消注释一行,正如您在这个代码片段中看到的那样。但当我取消注释时,会得到一个错误。谢谢你的帮助

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _prediction_heads=detection_model._box_predictor._prediction_heads, #I uncommented this line
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
这是我在运行笔记本并取消注释该行后遇到的错误:

ValueError                                Traceback (most recent call last)
    <ipython-input-7-96e77f9f8468> in <module>
         24 
         25 image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
    ---> 26 prediction_dict = detection_model.predict(image, shapes)
         27 _ = detection_model.postprocess(prediction_dict, shapes)
         28 print('Weights restored!')


C:\Python\lib\site-packages\object_detection\meta_architectures\ssd_meta_arch.py in predict(self, preprocessed_inputs, true_image_shapes)
    589     self._anchors = box_list_ops.concatenate(boxlist_list)
    590     if self._box_predictor.is_keras_model:
--> 591       predictor_results_dict = self._box_predictor(feature_maps)
    592     else:
    593       with slim.arg_scope([slim.batch_norm],

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

C:\Python\lib\site-packages\object_detection\core\box_predictor.py in call(self, image_features, **kwargs)
    200           feature map in the input `image_features` list.
    201     """
--> 202     return self._predict(image_features, **kwargs)
    203 
    204   @abstractmethod

C:\Python\lib\site-packages\object_detection\predictors\convolutional_keras_box_predictor.py in _predict(self, image_features, **kwargs)
    482               self._base_tower_layers_for_heads[head_name][index],
    483               image_feature)
--> 484         prediction = head_obj(head_tower_feature)
    485         predictions[head_name].append(prediction)
    486     return predictions

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

C:\Python\lib\site-packages\object_detection\predictors\heads\head.py in call(self, features)
     67   def call(self, features):
     68     """The Keras model call will delegate to the `_predict` method."""
---> 69     return self._predict(features)
     70 
     71   @abstractmethod

C:\Python\lib\site-packages\object_detection\predictors\heads\keras_class_head.py in _predict(self, features)
    339     for layer in self._class_predictor_layers:
    340       class_predictions_with_background = layer(
--> 341           class_predictions_with_background)
    342     batch_size = features.get_shape().as_list()[0]
    343     if batch_size is None:

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    980       with ops.name_scope_v2(name_scope):
    981         if not self.built:
--> 982           self._maybe_build(inputs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _maybe_build(self, inputs)
   2641         # operations.
   2642         with tf_utils.maybe_init_scope(self):
-> 2643           self.build(input_shapes)  # pylint:disable=not-callable
   2644       # We must set also ensure that the layer is marked as built, and the build
   2645       # shape is stored since user defined build functions may not be calling

C:\Python\lib\site-packages\tensorflow\python\keras\layers\convolutional.py in build(self, input_shape)
    202         constraint=self.kernel_constraint,
    203         trainable=True,
--> 204         dtype=self.dtype)
    205     if self.use_bias:
    206       self.bias = self.add_weight(

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in add_weight(self, name, shape, dtype, initializer, regularizer, trainable, constraint, partitioner, use_resource, synchronization, aggregation, **kwargs)
    612         synchronization=synchronization,
    613         aggregation=aggregation,
--> 614         caching_device=caching_device)
    615     if regularizer is not None:
    616       # TODO(fchollet): in the future, this should be handled at the

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in _add_variable_with_custom_getter(self, name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
    729         # there is nothing to restore.
    730         checkpoint_initializer = self._preload_simple_restoration(
--> 731             name=name, shape=shape)
    732       else:
    733         checkpoint_initializer = None

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in _preload_simple_restoration(self, name, shape)
    796         key=lambda restore: restore.checkpoint.restore_uid)
    797     return CheckpointInitialValue(
--> 798         checkpoint_position=checkpoint_position, shape=shape)
    799 
    800   def _track_trackable(self, trackable, name, overwrite=False):

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in __init__(self, checkpoint_position, shape)
     73       # We need to set the static shape information on the initializer if
     74       # possible so we don't get a variable with an unknown shape.
---> 75       self.wrapped_value.set_shape(shape)
     76     self._checkpoint_position = checkpoint_position
     77 

C:\Python\lib\site-packages\tensorflow\python\framework\ops.py in set_shape(self, shape)
   1207       raise ValueError(
   1208           "Tensor's shape %s is not compatible with supplied shape %s" %
-> 1209           (self.shape, shape))
   1210 
   1211   # Methods not supported / implemented for Eager Tensors.

ValueError: Tensor's shape (3, 3, 256, 546) is not compatible with supplied shape (3, 3, 256, 24)
ValueError回溯(最近一次调用)
在里面
24
25图像,形状=检测模型预处理(tf.zeros([1,640,640,3]))
--->26预测=检测模型。预测(图像、形状)
27=检测模型。后处理(预测、预测、形状)
28打印('已恢复权重!')
C:\Python\lib\site packages\object\u detection\meta\u architecture\ssd\u meta\u arch.py in predict(自我、预处理的\u输入、真实的\u图像\u形状)
589自锚点=框列表操作串联(框列表)
590如果自盒预测器为keras模型:
-->591预测器\u结果\u dict=self.\u框\u预测器(特征图)
592其他:
593具有slim.arg\u范围([slim.batch\u norm],
C:\Python\lib\site packages\tensorflow\Python\keras\engine\base\u layer.py in\uuuuu调用(self,*args,**kwargs)
983
984带操作。启用自动转换变量(自计算类型对象):
-->985输出=呼叫(输入,*args,**kwargs)
986
987如果自活动正则化器:
调用中的C:\Python\lib\site packages\object\u detection\core\box\u predictor.py(self、image\u功能,**kwargs)
200输入'image_features'列表中的特征映射。
201     """
-->202返回自我预测(图像特征,**kwargs)
203
204@abstractmethod
C:\Python\lib\site packages\object\u detection\predictors\voluminal\u keras\u box\u predictor.py in\u predict(self,image\u features,**kwargs)
482自底楼层楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼楼,
483图像(U功能)
-->484预测=首塔obj(首塔特征)
485预测[head_name]。追加(预测)
486返回预测
C:\Python\lib\site packages\tensorflow\Python\keras\engine\base\u layer.py in\uuuuu调用(self,*args,**kwargs)
983
984带操作。启用自动转换变量(自计算类型对象):
-->985输出=呼叫(输入,*args,**kwargs)
986
987如果自活动正则化器:
C:\Python\lib\site packages\object\u detection\predictors\heads\head.py in call(self,features)
67 def呼叫(自身、功能):
68“Keras模型调用将委托给`\u predict`方法。”“”
--->69返回自我预测(特征)
70
71@abstractmethod
C:\Python\lib\site packages\object\u detection\predictors\heads\keras\u class\u head.py in\u predict(self,features)
339对于自组中的层_类_预测器_层:
340类预测,背景=层(
-->341类(含背景)
342批次大小=特征。获取形状()。作为列表()[0]
343如果批次大小为无:
C:\Python\lib\site packages\tensorflow\Python\keras\engine\base\u layer.py in\uuuuu调用(self,*args,**kwargs)
980,带有操作名称\u范围\u v2(名称\u范围):
981如果不是自建的:
-->982自组装(输入)
983
984带操作。启用自动转换变量(自计算类型对象):
C:\Python\lib\site packages\tensorflow\Python\keras\engine\base\u layer.py in\u maybe\u build(self,inputs)
2641#操作。
2642带有tf_utils。可能是_init_作用域(self):
->2643 self.build(输入形状)#pylint:disable=不可调用
2644#我们还必须设置并确保层标记为已构建,并且构建
2645#由于用户定义的生成函数可能未调用,因此存储了shape
C:\Python\lib\site packages\tensorflow\Python\keras\layers\convolutional.py内置(self,input\u shape)
202约束=self.kernel\u约束,
203可培训=正确,
-->204 dtype=self.dtype)
205如果自我使用偏差:
206 self.bias=self.add\u权重(
C:\Python\lib\site packages\tensorflow\Python\keras\engine\base\u layer.py in add\u weight(self、name、shape、dtype、initializer、regularizer、trainable、constraint、partitioner、use\u resource、synchronization、aggregation、**kwargs)
612同步=同步,
613聚合=聚合,
-->614缓存\u设备=缓存\u设备)
615如果正则化器不是无:
616#TODO(fchollet):在未来,这应该在
C:\Python\lib\site packages\tensorflow\Python\training\tracking\base.py in\u add\u variable\u with\u custom\u getter(self、name、shape、dtype、initializer、getter、overwrite、**kwargs\u for\u getter)
没有什么可恢复的。
730检查点\初始值设定项=自身。\预加载\简单\恢复(
-->731名称=名称,形状=形状)
732其他:
733检查点\初始值设定项=无
C:\Python\lib\site packages\tensorflow\Python\training\tracking\base.py in\u preload\u simple\u restoration(self、name、shape)
796 key=lambda restore:restore.checkpoint.restore\u uid)
797返回检查点初始值(
-->798检查点位置=检查点位置,形状=形状)
799
800 def_track_trackable(self,trackable,name,overwrite=False):
C:\Python\lib\site packages\tensorflow\Python\training\tracking\base.py in\uuuuuuu init\uuuuuuu(self,checkpoint\u position,shape)
73#如果需要,我们需要在初始值设定项上设置静态形状信息
74#可能,因此我们不会得到形状未知的变量。
--->75自我包装的形状(形状)
76自我检查点位置=检查点位置
77
C