Tensorflow: how to initialize sample weights for multi-class segmentation?

I am using Keras and a U-Net for multi-class segmentation.

My network outputs 12 classes through a softmax activation. The shape of my output is (N, 288, 288, 12).

To train the model, I use sparse categorical crossentropy as the loss.
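Because the output has shape (N, 288, 288, 12) and the loss is the sparse variant, the labels are presumably integer masks of shape (N, 288, 288), and the loss is computed per pixel: each pixel contributes -log p, where p is the softmax probability predicted for its true class. The NumPy sketch below illustrates where a per-class weight would enter that computation. It is not Keras's own code, the function name weighted_sparse_ce is hypothetical, and averaging the weighted values over all pixels is assumed here to match Keras's default loss reduction.

```python
import numpy as np

def weighted_sparse_ce(labels, probs, class_weights):
    """Mean per-pixel sparse categorical crossentropy, weighted by class.

    labels:        int array, shape (N, H, W), values in [0, C)
    probs:         float array, shape (N, H, W, C), softmax output
    class_weights: float array, shape (C,), one weight per class
    """
    # Probability assigned to the true class of each pixel -> (N, H, W)
    p_true = np.take_along_axis(probs, labels[..., np.newaxis], axis=-1)[..., 0]
    # Clip to avoid log(0), then take the per-pixel crossentropy
    pixel_loss = -np.log(np.clip(p_true, 1e-7, 1.0))
    # One weight per pixel, looked up from that pixel's label -> (N, H, W)
    pixel_weight = class_weights[labels]
    # Average the weighted losses over every pixel in the batch
    return float(np.mean(pixel_weight * pixel_loss))

# Tiny example: one 2x2 image, 3 classes, uniform predictions
labels = np.array([[[0, 1], [2, 2]]])
probs = np.full((1, 2, 2, 3), 1.0 / 3)
uniform = weighted_sparse_ce(labels, probs, np.ones(3))
weighted = weighted_sparse_ce(labels, probs, np.array([1.0, 1.0, 5.0]))
print(uniform, weighted)
```

With uniform predictions every pixel has the same loss, so the weighted result is simply that loss scaled by the mean pixel weight; up-weighting class 2 therefore makes mistakes on its pixels count more.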

I want to initialize (class) weights for my model to deal with my imbalanced dataset.
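The weights used further down (7, 10, 2, ...) are picked by hand. One common heuristic for an imbalanced dataset, offered here only as an alternative, is to make each class weight inversely proportional to the number of pixels carrying that label; the formula below is the same one scikit-learn uses for its "balanced" class weights, applied to pixel counts. The function name inverse_frequency_weights is hypothetical.

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes):
    """One weight per class, proportional to 1 / (pixel count of that class).

    labels: int array of any shape (e.g. (N, H, W) masks), values in [0, num_classes)
    """
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(np.float64)
    # Guard against classes that never appear, to avoid division by zero
    counts = np.maximum(counts, 1.0)
    # total_pixels / (num_classes * count_per_class)
    return counts.sum() / (num_classes * counts)

# 4 pixels of class 0, 1 pixel of class 1, 1 pixel of class 2
masks = np.array([[[0, 0, 0], [0, 1, 2]]])
w = inverse_frequency_weights(masks, 3)
print(w)  # the two rare classes get 4x the weight of class 0
```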

I found this useful and tried to implement it. Since class_weight in Keras does not work for more than 2 classes, I used sample weights instead.

Here is my code:

```python
import numpy as np
import tensorflow as tf

# Input size as reported in the model summary below (288 x 288 RGB images)
IMG_WIDHT, IMG_HEIGHT, IMG_CHANNELS = 288, 288, 3

inputs = tf.keras.layers.Input((IMG_WIDHT, IMG_HEIGHT, IMG_CHANNELS))
smooth = 1.

s = tf.keras.layers.Lambda(lambda x: x / 255)(inputs)  # scale pixel values to [0, 1]

# Encoder; kernel_initializer='he_normal' sets the initial values of the conv kernels
c1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
c1 = tf.keras.layers.Dropout(0.1)(c1)
c1 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
p1 = tf.keras.layers.MaxPool2D((2, 2))(c1)

c2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
c2 = tf.keras.layers.Dropout(0.1)(c2)
c2 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
p2 = tf.keras.layers.MaxPool2D((2, 2))(c2)

c3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
c3 = tf.keras.layers.Dropout(0.1)(c3)
c3 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
p3 = tf.keras.layers.MaxPool2D((2, 2))(c3)

c4 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
c4 = tf.keras.layers.Dropout(0.1)(c4)
c4 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
p4 = tf.keras.layers.MaxPool2D((2, 2))(c4)

# Bottleneck
c5 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
c5 = tf.keras.layers.Dropout(0.1)(c5)
c5 = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)

# Decoder with skip connections
u6 = tf.keras.layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
u6 = tf.keras.layers.concatenate([u6, c4])
c6 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
c6 = tf.keras.layers.Dropout(0.2)(c6)
c6 = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)

u7 = tf.keras.layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
u7 = tf.keras.layers.concatenate([u7, c3])
c7 = tf.keras.layers.Conv2D(64, (2, 2), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
c7 = tf.keras.layers.Dropout(0.2)(c7)
c7 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)

u8 = tf.keras.layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
u8 = tf.keras.layers.concatenate([u8, c2])
c8 = tf.keras.layers.Conv2D(32, (2, 2), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
c8 = tf.keras.layers.Dropout(0.1)(c8)
c8 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

u9 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
u9 = tf.keras.layers.concatenate([u9, c1], axis=3)
c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
c9 = tf.keras.layers.Dropout(0.1)(c9)
c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)
outputs = tf.keras.layers.Conv2D(12, (1, 1), activation='softmax')(c9)  # (None, 288, 288, 12)
# Flatten added while trying to fix the sample_weight error (not included in the summary below)
outputs = tf.keras.layers.Flatten()(outputs)
model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
cc = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, amsgrad=False)
model.compile(optimizer=cc, loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'],  # or: metrics=[dice_coeff]
              sample_weight_mode="temporal")
model.summary()
checkpointer = tf.keras.callbacks.ModelCheckpoint('chek12class3.h5', verbose=1, save_best_only=True)

print('############## Initial weights ############## : ', model.get_weights())
# callbacks = [
#     tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
#     tf.keras.callbacks.TensorBoard(log_dir='logs')]
# history = model.fit(train_generator, validation_split=0.1, batch_size=4, epochs=100, callbacks=callbacks)

# One row per pixel (288 * 288 = 82944), one column per class
class_weights = np.zeros((82944, 12))
class_weights[:, 0] += 7
class_weights[:, 1] += 10
class_weights[:, 2] += 2
class_weights[:, 3] += 3
class_weights[:, 4] += 4
class_weights[:, 5] += 5
class_weights[:, 6] += 6
class_weights[:, 7] += 50
class_weights[:, 8] += 8
class_weights[:, 9] += 9
class_weights[:, 10] += 50
class_weights[:, 11] += 11

history = model.fit(X_train, Y_train, validation_split=0.18, batch_size=1, epochs=60, sample_weight=class_weights)
```
82944 is 288*288, the height times the width of my samples, and 12 is the number of classes.
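As a quick check of the shapes involved: 288 * 288 = 82944, and the error message below reports targets of shape (481, 288, 288), which are assumed here to be Y_train. The sketch uses NumPy's broadcasting rules (np.broadcast_shapes, available from NumPy 1.20) purely as an analogy for Keras's compatibility check; Keras does not call this function, and its own rules differ in detail. The helper can_broadcast is hypothetical.

```python
import numpy as np

H, W, NUM_CLASSES = 288, 288, 12
NUM_SAMPLES = 481  # from the "(481, 288, 288)" shape in the error message

print(H * W)  # 82944

label_shape = (NUM_SAMPLES, H, W)    # assumed shape of Y_train: one class id per pixel
weight_shape = (H * W, NUM_CLASSES)  # shape of the class_weights array passed to fit

def can_broadcast(a, b):
    """True if NumPy's broadcasting rules accept shapes a and b together."""
    try:
        np.broadcast_shapes(a, b)
        return True
    except ValueError:
        return False

print(can_broadcast(weight_shape, label_shape))         # False: trailing axes 12 vs 288
print(can_broadcast((NUM_SAMPLES, H, W), label_shape))  # True: one weight per pixel
```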

I got these errors:

ValueError: Found a sample_weight array with shape (82944, 12) for an input with shape (481, 288, 288). sample_weight cannot be broadcast.
ValueError: Found a sample_weight array with shape (481,). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.
According to this link, sample_weight should have the shape (number of training samples, shape of the training data).

I then added a Flatten layer before the output, but that did not help either.
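A Flatten layer merges the 12-class axis together with the pixel axes, so the model no longer outputs one probability vector per pixel, which sparse categorical crossentropy needs. The second error message asks for a 2D sample_weight, i.e. (samples, timesteps) under sample_weight_mode="temporal". One arrangement that, to my understanding, older tf.keras versions accepted in that mode is to reshape the output to (288 * 288, 12) instead of flattening it (for example with tf.keras.layers.Reshape((288 * 288, 12))) and to reshape labels and weights to (samples, 288 * 288); this is an assumption not verified against the asker's TensorFlow version. The NumPy sketch below only tracks the shapes, using small stand-in sizes and the hand-picked weights from the code above.

```python
import numpy as np

N, H, W, C = 2, 4, 4, 12  # small stand-ins for (481, 288, 288, 12)
probs = np.random.rand(N, H, W, C)             # model output before any Flatten
labels = np.random.randint(C, size=(N, H, W))  # integer mask per image

# What Flatten does: pixels AND classes merged into a single axis
flattened = probs.reshape(N, -1)
print(flattened.shape)  # (2, 192): no separate class axis left for sparse CE

# Keeping the class axis: one row of C probabilities per pixel ("timestep")
per_pixel = probs.reshape(N, H * W, C)  # (2, 16, 12)
labels_2d = labels.reshape(N, H * W)    # (2, 16), same pixel order as per_pixel
class_weights = np.array([7, 10, 2, 3, 4, 5, 6, 50, 8, 9, 50, 11], dtype=float)
weights_2d = class_weights[labels_2d]   # (2, 16): the 2D layout "temporal" mode expects
print(per_pixel.shape, labels_2d.shape, weights_2d.shape)
```

Row-major reshaping keeps pixel (h, w) at index h * W + w in both the output and the labels, so each weight stays aligned with its pixel.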

The architecture of my model:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 288, 288, 3) 0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 288, 288, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 288, 288, 16) 448         lambda[0][0]                     
__________________________________________________________________________________________________
dropout (Dropout)               (None, 288, 288, 16) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 288, 288, 16) 2320        dropout[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 144, 144, 16) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 144, 144, 32) 4640        max_pooling2d[0][0]              
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 144, 144, 32) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 144, 144, 32) 9248        dropout_1[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 72, 72, 32)   0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 72, 72, 64)   18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 72, 72, 64)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 72, 72, 64)   36928       dropout_2[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 36, 36, 64)   0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 36, 36, 128)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 36, 36, 128)  0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 36, 36, 128)  147584      dropout_3[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 18, 18, 128)  0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 18, 18, 256)  295168      max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 18, 18, 256)  0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 18, 18, 256)  590080      dropout_4[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 36, 36, 128)  131200      conv2d_9[0][0]                   
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 36, 36, 256)  0           conv2d_transpose[0][0]           
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 36, 36, 128)  295040      concatenate[0][0]                
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 36, 36, 128)  0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 36, 36, 128)  147584      dropout_5[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 72, 72, 64)   32832       conv2d_11[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 72, 72, 128)  0           conv2d_transpose_1[0][0]         
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 72, 72, 64)   32832       concatenate_1[0][0]              
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 72, 72, 64)   0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 72, 72, 64)   36928       dropout_6[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 144, 144, 32) 8224        conv2d_13[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 144, 144, 64) 0           conv2d_transpose_2[0][0]         
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 144, 144, 32) 8224        concatenate_2[0][0]              
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 144, 144, 32) 0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 144, 144, 32) 9248        dropout_7[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 288, 288, 16) 2064        conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 288, 288, 32) 0           conv2d_transpose_3[0][0]         
                                                                 conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 288, 288, 16) 4624        concatenate_3[0][0]              
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 288, 288, 16) 0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 288, 288, 16) 2320        dropout_8[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 288, 288, 12) 204         conv2d_17[0][0]                  
==================================================================================================
I thought this solution might work:

```python
sample_weights = np.zeros(len(Y_train))
# put your own weight for each class here:
sample_weights[Y_train[Y_train==0]] = 7
sample_weights[Y_train[Y_train==1]] = 10
sample_weights[Y_train[Y_train==2]] = 2
sample_weights[Y_train[Y_train==3]] = 3
sample_weights[Y_train[Y_train==4]] = 4
sample_weights[Y_train[Y_train==5]] = 5
sample_weights[Y_train[Y_train==6]] = 6
sample_weights[Y_train[Y_train==7]] = 50
sample_weights[Y_train[Y_train==8]] = 8
sample_weights[Y_train[Y_train==9]] = 9
sample_weights[Y_train[Y_train==10]] = 50
sample_weights[Y_train[Y_train==11]] = 11
```
I got these errors:

ValueError: Found a sample_weight array with shape (82944, 12) for an input with shape (481, 288, 288). sample_weight cannot be broadcast.
ValueError: Found a sample_weight array with shape (481,). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.

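You are misusing sample_weight. As its name implies, it assigns one weight to each sample; so, although you have only 481 samples, you pass something of length 82944 (and 2-dimensional, to boot), hence the expected error: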

ValueError: Found a sample_weight array with shape (82944, 12) for an input with shape (481, 288, 288). sample_weight cannot be broadcast.
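So, what you actually need is a 1D sample_weight array with length equal to the number of your training samples, where each element is the weight of the corresponding sample; in turn, this weight should be the same for every sample of a given class, as you show.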

Here is how to do this with some dummy data y of 12 classes, using only 30 samples:

```python
import numpy as np

y = np.random.randint(12, size=30)  # dummy data, 12 classes (values differ on every run)
y
# array([ 8,  0,  6,  8,  9,  9,  7, 11,  6,  4,  6,  3, 10,  8,  7,  7, 11,
#        2,  5,  8,  8,  1,  7,  2,  7,  9,  5,  2,  0,  0])

sample_weights = np.zeros(len(y))
# put your own weight for each class here:
sample_weights[y==0] = 7
sample_weights[y==1] = 10
sample_weights[y==2] = 2
sample_weights[y==3] = 3
sample_weights[y==4] = 4
sample_weights[y==5] = 5
sample_weights[y==6] = 6
sample_weights[y==7] = 50
sample_weights[y==8] = 8
sample_weights[y==9] = 9
sample_weights[y==10] = 50
sample_weights[y==11] = 11

sample_weights
# result:
# array([ 8.,  7.,  6.,  8.,  9.,  9., 50., 11.,  6.,  4.,  6.,  3., 50.,
#         8., 50., 50., 11.,  2.,  5.,  8.,  8., 10., 50.,  2., 50.,  9.,
#         5.,  2.,  7.,  7.])
```
Let's put them in a nice dataframe for a better view:

```python
import pandas as pd
d = {'y': y, 'weight': sample_weights}
df = pd.DataFrame(d)
print(df.to_string(index=False))

# result:
#   y  weight
#   8     8.0
#   0     7.0
#   6     6.0
#   8     8.0
#   9     9.0
#   9     9.0
#   7    50.0
#  11    11.0
#   6     6.0
#   4     4.0
#   6     6.0
#   3     3.0
#  10    50.0
#   8     8.0
#   7    50.0
#   7    50.0
#  11    11.0
#   2     2.0
#   5     5.0
#   8     8.0
#   8     8.0
#   1    10.0
#   7    50.0
#   2     2.0
#   7    50.0
#   9     9.0
#   5     5.0
#   2     2.0
#   0     7.0
#   0     7.0
```

Of course, you should then replace sample_weight=class_weights in your model.fit with sample_weight=sample_weights.
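When Y_train holds segmentation masks rather than one label per sample (likely here, given the (481, 288, 288) shape in the error), the answer's boolean-mask recipe (sample_weights[y==k] = ...) has one index axis too many and fails with the IndexError mentioned in the comments below. Indexing a per-class weight table with the masks instead gives one weight per pixel, with the same shape as the labels. Whether model.fit accepts a sample_weight of that shape directly depends on the TensorFlow version and on the model output keeping its per-pixel layout; treat that as an assumption to verify. The names class_weight_table and Y_masks are illustrative.

```python
import numpy as np

class_weight_table = np.array([7, 10, 2, 3, 4, 5, 6, 50, 8, 9, 50, 11], dtype=float)

# Stand-in for segmentation masks: (samples, height, width), one class id per pixel
Y_masks = np.random.randint(12, size=(3, 4, 4))

# The 1D recipe breaks: a (3,)-array cannot be indexed by a (3, 4, 4) boolean mask
per_sample = np.zeros(len(Y_masks))
try:
    per_sample[Y_masks == 0] = 7
    failed = False
except IndexError:  # "too many indices for array"
    failed = True
print(failed)  # True

# Indexing the table with the masks yields one weight per pixel
pixel_weights = class_weight_table[Y_masks]
print(pixel_weights.shape)  # (3, 4, 4), same shape as Y_masks
```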
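@desertnaut I have just posted my question, please take a look. Thanks again :)
Thank you very much for taking the time to write this answer. I tried your suggestion, replacing y with Y_train (the labels), and got an IndexError: too many indices for array. I have posted all my code; I think this is due to the shape of my samples (2D).
@SamBn04 One error at a time! :) Open a new question with the new issue (together with a sample of your Y_train).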