Python ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type numpy.ndarray) when training a CNN model


Whenever I try to train my CNN model with a 4D NumPy array, I get the error above. The model builder function looks like this:

from tensorflow import keras


def build_model(input_shape, LR=.001, phone_count=43):
  
  #build the network
  model= keras.Sequential()
  
  #conv layer 1
  #model.add(keras.layers.Conv2D(64,(3,3), activation='relu', input_shape=input_shape, kernel_regularizer=keras.regularizers.l2(0.001), data_format="channels_first"))
  model.add(keras.layers.Conv2D(64,(3,3), activation='relu', input_shape=input_shape, kernel_regularizer=keras.regularizers.l2(0.001)))
  model.add(keras.layers.BatchNormalization())
  model.add(keras.layers.MaxPooling2D((3,3), strides=(2,2), padding='same'))
  
  #conv layer 2
  model.add(keras.layers.Conv2D(32,(3,3), activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)))
  model.add(keras.layers.BatchNormalization())
  model.add(keras.layers.MaxPooling2D((3,3), strides=(2,2), padding='same'))
  
  #conv layer 3
  model.add(keras.layers.Conv2D(32,(2,2), activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)))
  model.add(keras.layers.BatchNormalization())
  model.add(keras.layers.MaxPooling2D((2,2), strides=(2,2), padding='same'))
  
  #Flatten the output
  model.add(keras.layers.Flatten())
  model.add(keras.layers.Dense(64, activation='relu'))
  model.add(keras.layers.Dropout(0.3))

  #softmax layer
  model.add(keras.layers.Dense(phone_count, activation='softmax'))

  #compile the model
  the_optimizer= keras.optimizers.Adamax(learning_rate=LR)
  model.compile(optimizer=the_optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

  model.summary()

  return model
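For reference, with input_shape=(12, 16, 1) this model expects its training input as a single numeric 4-D array of shape (batch, 12, 16, 1), not as an array of separate matrices. A minimal sketch using the build_model function above and made-up dummy data (the random arrays are purely illustrative, not the real MFCC data):

import numpy as np

# Hypothetical dummy batch: 8 samples of shape (12, 16, 1), stored as one float32 array
dummy_x = np.random.rand(8, 12, 16, 1).astype(np.float32)
dummy_y = np.random.randint(0, 42, size=8)  # integer labels for sparse_categorical_crossentropy

model = build_model((12, 16, 1), phone_count=42)
model.fit(dummy_x, dummy_y, epochs=1, batch_size=4)  # accepted, because dummy_x is one contiguous float array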
This function pulls the data out of .npy files into DataFrames (to make it easier to drop unnecessary data/rows) and then converts them back to NumPy arrays for training and testing the model:

import os

import numpy as np
import pandas as pd
import sklearn.model_selection
import sklearn.utils


# rootdir and data_filter are defined elsewhere in the original script
def prep_data(shape_num=3, train_min=50):

  #read data from npy file
  data1=np.load(os.path.join(rootdir+"/Train(Small Win).npy"), allow_pickle=True)
  data2=np.load(os.path.join(rootdir+"/Test(Small Win).npy"), allow_pickle=True)


  train= pd.DataFrame(data1, columns=['Phone', 'Signal'])
  test= pd.DataFrame(data2, columns=['Phone', 'Signal'])
  train['Phone']=train['Phone'].astype(int)
  test['Phone']=test['Phone'].astype(int)


  #shuffle data before splitting
  train= sklearn.utils.shuffle(train)
  train= train.reset_index(drop=True)
  test= sklearn.utils.shuffle(test)
  test= test.reset_index(drop=True)

  #filter unnecessary data based on criteria
  train, test= data_filter(train, test, shape_num, train_min)
  

  #split data 
  test, validation= sklearn.model_selection.train_test_split(test, test_size=0.4, shuffle=False)
  validation=validation.reset_index(drop=True)
  x_train, y_train, x_test, y_test= train['Signal'], train['Phone'], test['Signal'], test['Phone']  
  x_validation, y_validation= validation['Signal'], validation['Phone']

  
  x_train=x_train.to_numpy()
  y_train=y_train.to_numpy()
  x_test=x_test.to_numpy()
  y_test=y_test.to_numpy()
  x_validation=x_validation.to_numpy()
  y_validation=y_validation.to_numpy()

  #convert from 2D -> 3D
  # x_test=x_test[..., np.newaxis]
  # x_train=x_train[..., np.newaxis]
  # x_validation=x_validation[..., np.newaxis]

  return x_train, y_train, x_test, y_test, x_validation, y_validation
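One thing worth noting about this function (and, per the comments at the end, the likely source of the error): when every cell of the 'Signal' column holds a whole (16, 12) matrix, Series.to_numpy() returns a 1-D array of dtype object whose elements are references to those matrices, not a single (N, 16, 12) numeric array. A minimal sketch of that behaviour with made-up data standing in for the real .npy files:

import numpy as np
import pandas as pd

# Hypothetical stand-in for the real data: each row's 'Signal' is a (16, 12) matrix
df = pd.DataFrame({
    "Phone": [0, 1, 2],
    "Signal": [np.random.rand(16, 12) for _ in range(3)],
})

x = df["Signal"].to_numpy()
print(x.shape, x.dtype)        # (3,) object   <- 1-D array of references
print(x[0].shape, x[0].dtype)  # (16, 12) float64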
This is the main function I use to train the model, and where I get the error:

def main():

  x_train, y_train, x_test, y_test, x_validation, y_validation= prep_data(shape_num=12, train_min=5)
  
  #build the model
  model= build_model((12, 16, 1), phone_count=42)

  #train the model
  model.fit(x_train, y_train, epochs=40, batch_size=32, 
            validation_data=(x_validation, y_validation))  
  
  #evaluate the model
  error, accuracy= model.evaluate(x_test, y_test)
  print(f"Test error: {error}, Test accuracy: {accuracy}")

  #save model
  model.save("model.h5")

I tried using tf.convert_to_tensor(x_train) on all of my training and validation data, but it just returns the same error. Below I verified that all elements have the same shape (16, 12) and type (numpy.ndarray), and printed the first element as an example:

print(x_train.shape)
print(x_train[0].shape)
print(type(x_train[0]))
print(x_validation[0].shape)
print(type(x_validation[0]))
print(len(x_train))
print(len(y_train))
print(x_train[0])


for i in range(len(x_train)):
  if x_train[i].shape != x_train[0].shape or type(x_train[i]) != type(x_train[0]):
    print("Index ", i, " is: ", x_train[i].shape)
    print("type is: ", type(x_train[i]))

for i in range(len(x_validation)):
  if x_validation[i].shape != x_validation[0].shape or type(x_validation[i]) != type(x_validation[0]):
    print("Index ", i, " is: ", x_validation[i].shape)
    print("type is: ", type(x_validation[i]))
Output:

(9592,)
(16, 12)
<class 'numpy.ndarray'>
(16, 12)
<class 'numpy.ndarray'>
9592
9592
[[-6.21559034e+02 -6.03861092e+02 -6.44275070e+02 -6.21108087e+02
  -6.33902502e+02 -6.10202319e+02 -6.38066649e+02 -6.14831453e+02
  -6.08786659e+02 -5.54751120e+02 -5.59108425e+02 -5.90645072e+02]
 [ 1.62488404e+02  1.97702161e+02  1.93338819e+02  2.00768320e+02
   1.96443196e+02  2.06493442e+02  1.95534267e+02  2.05817857e+02
   1.97520058e+02  1.80149207e+02  1.51495948e+02  1.36431422e+02]
 [-1.68000882e+01 -1.71383988e+01 -2.19247949e+01 -3.50331972e+01
  -3.22168674e+01 -3.66609409e+01 -2.59634154e+01 -3.14709250e+01
  -4.25033816e+01 -4.43741521e+01 -3.70234923e+01 -2.95822967e+01]
 [ 4.10432947e+01  4.87035804e+01  5.52045579e+01  5.74210124e+01
   5.85131273e+01  5.99729371e+01  5.79334290e+01  5.10461554e+01
   4.97165940e+01  5.25213143e+01  5.17593922e+01  5.76097801e+01]
 [-3.43969005e+01 -5.62415603e+01 -6.69549272e+01 -7.22399067e+01
  -7.08110925e+01 -7.29823164e+01 -7.44342335e+01 -7.85082031e+01
  -7.90021868e+01 -9.42318038e+01 -1.07264419e+02 -7.69748910e+01]
 [-3.27785843e+01 -4.15082566e+01 -3.81622595e+01 -4.26821583e+01
  -4.28947818e+01 -4.35927895e+01 -4.20522233e+01 -4.07253579e+01
  -3.54880234e+01 -1.24468099e+01  7.66788474e+00  2.93352370e+01]
 [ 8.13909911e+00  8.86097919e+00  4.62275236e+00  6.80030771e+00
   7.73078526e+00  3.24659302e+00  4.32970141e+00  1.21840123e+00
  -5.97645909e-01 -1.59230216e+01 -1.83660643e+01 -2.36423881e+01]
 [ 4.72338015e+00  1.05095395e+01  1.83643214e+01  1.62532773e+01
   1.63903496e+01  1.48676056e+01  1.80896153e+01  1.67417707e+01
   1.57838347e+01  2.75592116e+01  3.30267593e+01  1.33186030e+01]
 [-1.70412263e+01 -2.89931507e+01 -3.17646436e+01 -3.13428307e+01
  -3.21056855e+01 -3.13025217e+01 -2.92779319e+01 -3.28383410e+01
  -2.86348024e+01 -3.56308366e+01 -3.72640420e+01 -2.96905489e+01]
 [ 1.23463742e+01  2.39805333e+00 -3.08857470e+00 -3.72663302e+00
  -5.67306760e+00 -6.75569345e+00 -6.68106808e+00 -8.97175994e+00
  -1.01855553e+01 -1.85199448e+01 -1.87053295e+01 -1.61857339e+00]
 [ 4.64371993e+00 -6.73981799e+00 -1.21180531e+01 -6.47924891e+00
  -4.09369632e+00 -7.61701294e-01  8.39204718e-01  4.33606029e+00
   2.50502410e+00  5.49674950e+00  1.75774383e+01  3.36114716e+01]
 [-1.34966338e+01 -2.31340623e+01 -2.69256468e+01 -2.52562186e+01
  -2.49909368e+01 -3.01723141e+01 -2.65044181e+01 -2.28321862e+01
  -1.76064424e+01 -1.96508533e+01 -4.32244206e+00  1.07479438e+01]
 [ 2.00379681e+00  1.15979206e+00  5.48741698e+00  9.24981351e+00
   1.23602788e+01  7.61237738e+00  7.21907492e+00  6.76921775e+00
   1.54061820e+01  1.40574085e+01  2.71901434e+01  1.73292818e+01]
 [-3.95202463e-01 -4.64221159e+00 -2.56725829e+00 -9.75469429e+00
  -1.53314060e+00 -7.13854658e+00  6.17673619e-01  1.11667664e-01
   5.55916296e+00  1.19341530e+01  2.47728353e+01  4.94114903e+00]
 [ 1.13522109e+01  1.48745532e+01  1.94529171e+01  1.26263470e+01
   1.71353642e+01  1.04325893e+01  1.76135506e+01  1.50917866e+01
   2.00581224e+01  1.27428903e+01  1.49328414e+01 -4.96430014e+00]
 [ 1.08379527e+01  1.16791172e+01  9.86733739e+00  8.62182114e+00
   1.09428703e+01  1.16916648e+01  1.07482436e+01  9.04880045e+00
   8.05048866e+00 -1.34804509e+00 -5.49922438e+00 -5.12267269e+00]]
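One check missing from the prints above, and the one that actually surfaces the problem (see the comments at the end of the post), is the dtype of the outer array itself rather than of its elements:

print(x_train.dtype)     # object  -> a 1-D array of references to (16, 12) arrays
print(x_train[0].dtype)  # float64 -> each individual element looks fine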
If you know why I'm getting this error, please let me know.


(For anyone interested in what the data is used for: these are MFCCs (16) extracted from speech data, for classification.)

Comments:

What is the dtype of xtrain? @hpaulj

x_train.dtype is object overall, but the elements are all float64, so it is an array holding object references that are themselves arrays. TF can't handle that. @hpaulj

As far as I know it should be a 3D array (an array of matrices). I hadn't checked the dtype before you asked, but I then checked the dtype of all the matrices in the array x_train (i.e. x_train[0:].dtype) and they are all float64. But when I check the dtype of x_train itself, it returns object, and I don't know why. I'm very new to Keras and NumPy arrays, so I'm sorry if my question seems strange.

What I meant is that a 3D array is not an "array of 3D arrays".
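Based on the diagnosis in these comments (an object-dtype array whose elements are (16, 12) float64 matrices), a minimal sketch of one possible fix is to stack the per-sample matrices into one contiguous float array and add the channel axis; this is an assumption drawn from the comments, not code from the original post:

import numpy as np

# Stack the N separate (16, 12) matrices into one (N, 16, 12) float32 array,
# then append the channel axis Conv2D expects -> (N, 16, 12, 1)
x_train = np.stack(x_train).astype(np.float32)[..., np.newaxis]
x_validation = np.stack(x_validation).astype(np.float32)[..., np.newaxis]
x_test = np.stack(x_test).astype(np.float32)[..., np.newaxis]

# Note: main() builds the model with input_shape=(12, 16, 1) while each sample
# here is (16, 12); either transpose each sample or use input_shape=(16, 12, 1)
# so the two shapes agree.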