Python 含Keras的序列稠密网络的输入维数误差

Python 含Keras的序列稠密网络的输入维数误差,python,tensorflow,keras,Python,Tensorflow,Keras,这是一个很长的问题,因为我试图尽可能多地解释我的问题,因为这对我来说是一个反复出现的问题,我真的不明白,所以感谢您花时间阅读我的文章 我想创建一个顺序密集模型,该模型以如下维度作为输入列表: ModelDense = Sequential() ModelDense.add(Dense(380, input_shape=(None,185), activation='elu', kernel_initializer='glorot_normal')) ModelDense.add(Dense(3

这是一个很长的问题,因为我试图尽可能多地解释我的问题,因为这对我来说是一个反复出现的问题,我真的不明白,所以感谢您花时间阅读我的文章

我想创建一个顺序密集模型,该模型以如下维度作为输入列表:

ModelDense = Sequential()

ModelDense.add(Dense(380, input_shape=(None,185), activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(7, activation='elu', kernel_initializer='glorot_normal'))
optimizer = tf.keras.optimizers.Adam(lr=0.00025)

ModelDense.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
[批次大小、数据维度]

因此,我将我的网络定义为:

ModelDense = Sequential()

ModelDense.add(Dense(380, input_shape=(None,185), activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(7, activation='elu', kernel_initializer='glorot_normal'))
optimizer = tf.keras.optimizers.Adam(lr=0.00025)

ModelDense.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
但当我使用这个网络时,输入的形状如下:(1185)我得到了一个错误:

检查输入时出错:预期密集_输入有3维,但得到了形状为(185,1)的数组

不要问我为什么说我的向量形状是(1185),我们在错误消息中看到(185,1),因为在将其作为网络输入之前检查数组形状时,显示的形状是(1185)

好的,我检查了一些主题,然后我发现其中解释了:

密集层需要输入(批次大小、输入大小) 或(批量大小,可选,…,可选,输入大小)

这就是我做的,不是吗? 但我也看到:

Keras中的形状:

因此,即使您使用了input_shape=(50,50,3),当keras向您发送消息或打印模型摘要时,它也会显示(无,50,50,3)

因此,在定义输入形状时,忽略批次大小: 输入_形状=(50,50,3)

好的!让我们试试我现在定义的输入层如下:

ModelDense.add(Dense(380, input_shape=(185,), activation='elu', kernel_initializer='glorot_normal'))
当我执行model.summary()时:

_________________________________________________________________图层(类型)输出形状参数# ============================================================================密集(密集)(无,380)70680 _________________________________________________________________致密(致密)(无,380)144780 _________________________________________________________________致密(致密)(无,380)144780 _________________________________________________________________密集型_3(密集型)(无,7)2667 =====================================================================总参数:362907个可培训参数:362907个不可培训参数: 0


好的,我想这是我想要的,但是当我给相同的数组作为输入时,我现在得到了错误:

ValueError:检查输入时出错:应具有密集的\u输入 形状(185,)但获得了形状为(1,)的数组

我很困惑,我误解了什么

\uuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu:

预测功能:

def predict(dense_model, state, action_size, epsilon):

    alea = np.random.rand()

    # DEBUG
    print(state)
    print(np.array(state).shape)

    output = dense_model.predict(state)

    if (epsilon > alea):
        action = random.randint(1, action_size) - 1
        flag_alea = True

    else:
        action = np.argmax(output)
        flag_alea = False

    return output, action, flag_alea
Qs, action, flag_alea = predict(Dense_model, [state], ACTION_SIZE, Epsilon)
我使用函数的行:

def predict(dense_model, state, action_size, epsilon):

    alea = np.random.rand()

    # DEBUG
    print(state)
    print(np.array(state).shape)

    output = dense_model.predict(state)

    if (epsilon > alea):
        action = random.randint(1, action_size) - 1
        flag_alea = True

    else:
        action = np.argmax(output)
        flag_alea = False

    return output, action, flag_alea
Qs, action, flag_alea = predict(Dense_model, [state], ACTION_SIZE, Epsilon)
“调试”打印的确切结果:

[[0.0,0.0,0.0,0.1241002730206064,0.0,0.0,0.0,0.0,0.0,0.0,0.0, 0.18851780241253108, 0.0, 0.0, 0.2863141820958198, 0.0, 0.07328154770628756, 0.418848167539267, 0.07328154770628756, 0.2094240837696335, 0.42857142857142855, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.263306220774655, 0.14740566037735847, 0.40346984062941293, 0.675310642895732, 0.0, 0.0, 0.0, 0.0, 0.07328154770628756, 0.0, 0.4396892862377253, 0.0, 0.42857142857142855, 0.0, 0.12410027302060064, 0.08759635599159075, 0.0, 0.1401927621025243, 0.6755559204272007, 0.0, 0.0, 0.11564568886156315, 0.4051863857374392, 0.0, 0.0, 0.19087612139721322, 0.0, 0.07328154770628756, 0.6282722513089005, 0.14656309541257512, 0.10471204188481675, 0.42857142857142855, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.0974621385076755, 0.0, 0.0, 0.675310642895732, 0.0, 0.0, 0.0, 0.09543806069860661, 0.07328154770628756, 0.10471204188481675, 0.5129708339440129, 0.5233396901920598, 0.42857142857142855, 0.0, 0.0, 0.0, 0.0, 0.5528187746700128, 0.6755564266434103, 0.0, 0.0, 0.10086746015735323, 0.1350621285791464, 0.0, 0.0, 0.0, 0.0, 0.14891426591693724, 0.5166404112353377, 0.14656309541257512, 0.10471204188481675, 0.42857142857142855, 0.00846344605088234, 0.012550643645226955, 0.0, 0.0, 0.004527776502072811, 0.0, 0.001294999849051237, 0.019391579553484917, 0.02999694086611271, 0.0026073455810546875, 0.0, 0.0, 0.016546493396162987, 0.024497902020812035, 0.00018889713101089, 0.0, 0.0055684475228190420.0、6.7332292938383826226256565、0.0、0.01323232326282828285858565-05、0.00120920990905252525252525252525458586、0.0 0 0 0.0、0.0、0.0、0.0 0、0.0 0 0、0.0 0、0.0 0、0.0 0 0、0.013332326262626262767676767673737373737878787878787878787878787878787878787878787878787878787878787878787878787878787878787878787878787878787878787878780 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21625, 0.000697580398991704, 0.00213554291985929, 0.0, 0.0021772112231701612, 0.012761476449668407, 0.015171871520578861, 0.001512336079031229, 0.0, 0.0, 0.008273545652627945, 0.01777557097375393, 0.006600575987249613, 0.0, 0.007174563594162464, 0.0, 0.004660750739276409, 0.009024208411574364, 0.0, 0.0014235835988074541, 0.0, 0.0, 0.0, 0.008785379119217396, 0.010602384805679321, 0.0024691042490303516, 0.0, 0.0, 0.003091508522629738, 0.0120345214381814, 0.003123666625469923, 0.0, 0.005664713680744171, 0.0, 0.004825159907341003, 0.0034197410568594933, 0.0030767947901040316, 0.004110954236239195, 0.0, 0.0, 0.001896441332064569, 0.002400417113676667, 0.0012791997287422419, 0.0, 0.0, 0.0, 0.002102752914652284,0.006922871805727482,0.004868669901043177,0.0,7.310241926461458e-05,0.0]]

(1185)

\uuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu:

错误回溯:

文件“!Qltrain.py”,第360行,在 Qs,action,flag_alea=predict(密集模型,[状态],action_大小,Epsilon)文件“\Lib\Core.py”,predict中第336行 output=densite\u model.predict(state)文件“C:\Users\Odeven\AppData\Local\Programs\Python37\lib\site packages\tensorflow\Python\keras\engine\training.py”, 第1096行,在predict中 x、 检查步骤=True,步骤\u name='steps',steps=steps)文件“C:\Users\Odeven\AppData\Local\Programs\Python\Python37\li