TensorFlow OCR with CRNN: how to get the prediction score
I have a CRNN model with a CTC loss on its output.
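I have the predictions, and I decode them with `keras.backend.ctc_decode`. As described in the documentation (), the function returns a tuple containing the decoded result and a tensor with the log probability of each prediction.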
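To see why that second element is not bounded by 1, here is a minimal pure-Python sketch of greedy (best-path) CTC decoding for one sample. It assumes what the `tf.nn.ctc_greedy_decoder` documentation describes: the blank is the last class by default, and the returned score is the *negative* sum of the largest log-probability at each time step (`keras.backend.ctc_decode` passes `log(y_pred + epsilon)` to that op). The function name `greedy_ctc_decode` and the toy probabilities are hypothetical.

```python
import math


def greedy_ctc_decode(probs, blank=None):
    """Best-path CTC decoding for a single sample.

    probs: list of time steps, each a list of per-class probabilities
           (assumed to be softmax outputs).
    blank: index of the CTC blank; defaults to the last class, the
           convention TensorFlow's CTC ops use.
    Returns (labels, score) with score = -sum(log(max prob per step)).
    """
    if blank is None:
        blank = len(probs[0]) - 1
    labels = []
    score = 0.0
    prev = None
    for step in probs:
        # Most likely class at this time step.
        best = max(range(len(step)), key=lambda k: step[k])
        # Accumulate the negative log-probability of the chosen class.
        score -= math.log(step[best])
        # Merge repeated labels, then drop blanks.
        if best != prev and best != blank:
            labels.append(best)
        prev = best
    return labels, score


probs = [
    [0.6, 0.3, 0.1],  # class 0
    [0.7, 0.2, 0.1],  # class 0 again (repeat, merged)
    [0.1, 0.2, 0.7],  # blank (class 2)
    [0.2, 0.5, 0.3],  # class 1
]
labels, score = greedy_ctc_decode(probs)
print(labels)             # [0, 1]
print(score)              # about 1.917, already above 1
print(math.exp(-score))   # probability of the best path: 0.6 * 0.7 * 0.7 * 0.5
```

Because the score adds one positive term per time step, longer inputs tend to produce larger values even when the decoded text is correct.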
`keras.backend.ctc_decode` can decode several predictions at once, but I need to pass them one at a time.
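If a whole batch is decoded at once, `keras.backend.ctc_decode` returns the labels as a dense array with one row per sample, padded with `-1`. A minimal sketch of splitting such a batch back into per-sample label lists, using plain Python lists in place of the tensor (converting with `.numpy().tolist()` first is an assumption about your setup). The helper name `strip_ctc_padding` is hypothetical. Unlike the `np.argmax(decoded[0][0] == -1)` slice in the code that follows, it also handles a row with no `-1` at all, where `argmax` returns 0 and the slice would be empty.

```python
def strip_ctc_padding(dense_rows, pad=-1):
    """Return each row truncated at its first padding value."""
    out = []
    for row in dense_rows:
        # Fall back to the full row when it contains no padding.
        end = row.index(pad) if pad in row else len(row)
        out.append(row[:end])
    return out


batch = [
    [3, 1, 4, -1, -1],
    [2, 7, -1, -1, -1],
    [5, 9, 2, 6, 5],  # no padding: the whole row is the label sequence
]
print(strip_ctc_padding(batch))  # [[3, 1, 4], [2, 7], [5, 9, 2, 6, 5]]
```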
Here is the code:
```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def decode_single_prediction(pred, num_to_char):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    decoded = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)
    # decoded[0] is supposed to be the decoded result
    # decoded[1] is supposed to be its log probability
    accuracy = float(decoded[1][0][0])
    # take the resulting encoded chars up to the first -1 (padding)
    result = decoded[0][0][:, :np.argmax(decoded[0][0] == -1)]
    output_text = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
    return (output_text, accuracy)


for image in images:
    pred = prediction_model.predict(image)
    # num_to_char is the mapping from number to char
    pred_texts, acc = decode_single_prediction(pred, num_to_char)
    print("True value: " + <true_result> + " prediction: " + pred_texts + " acc: " + str(acc))
```
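The predictions are always correct. However, the values I take to be the probabilities are not what I expect: they look like completely random numbers, and some are even greater than 1 or 2! What am I doing wrong?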
Thanks in advance!
```
True value: test0, prediction: test0, acc: 1.841524362564087
True value: test1, prediction: test1, acc: 0.9661365151405334
True value: test2, prediction: test2, acc: 1.0634151697158813
True value: test3, prediction: test3, acc: 2.471940755844116
True value: test4, prediction: test4, acc: 1.4866207838058472
True value: test5, prediction: test5, acc: 0.7630811333656311
True value: test6, prediction: test6, acc: 0.35642576217651367
True value: test7, prediction: test7, acc: 1.5693446397781372
True value: test8, prediction: test8, acc: 0.9700028896331787
True value: test9, prediction: test9, acc: 1.4783780574798584
```
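If these values are the negative summed log-probabilities returned by the greedy decoder (an assumption based on the `tf.nn.ctc_greedy_decoder` documentation, not something the output itself confirms), then `exp(-value)` maps each one back to the probability of the best path. A minimal sketch applied to three of the logged values:

```python
import math

# Scores copied from the output above (assumed to be greedy-decoder values).
scores = {
    "test0": 1.841524362564087,
    "test1": 0.9661365151405334,
    "test6": 0.35642576217651367,
}

for label, score in scores.items():
    # exp(-score) turns a non-negative negative-log-probability into a value in (0, 1].
    confidence = math.exp(-score)
    print(f"{label}: score={score:.4f} confidence={confidence:.4f}")
```

Under that assumption, a lower printed value means a more confident prediction, so `test6` would be the most confident of these three.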