Python 尝试将图像数据输入DNNReressor时出现值错误

Python 尝试将图像数据输入DNNReressor时出现值错误,python,tensorflow,Python,Tensorflow,让DNNRegressor接受图像数据时遇到一些问题。我在运行代码时遇到此错误: ValueError: Cannot reshape a tensor with 147456 elements to shape [384,442368] (169869312 elements) for 'dnn/input_from_feature_columns/input_layer/image/Reshape' (op: 'Reshape') with input shapes: [384,384,1]

DNNRegressor
接受图像数据时遇到一些问题。我在运行代码时遇到此错误:

ValueError: Cannot reshape a tensor with 147456 elements to shape [384,442368] (169869312 elements) for 'dnn/input_from_feature_columns/input_layer/image/Reshape' (op: 'Reshape') with input shapes: [384,384,1], [2] and with input tensors computed as partial shapes: input[1] = [384,442368].
这是一个有问题代码的简化版本

import os
import os.path

import tensorflow as tf

SPLIT_PERCENTAGE = 0.8

# snip snip

# ids is a List of strings
# filenames is a List of filenames of image files on the disk
# labels is a List of int scores

estimator = tf.estimator.DNNRegressor(
    feature_columns=[
        tf.feature_column.numeric_column('image', shape=(384, 384, 3)),
    ],
    hidden_units=[1024, 512, 256],
    model_dir=output_dir,
)

estimator.train(input_fn=lambda: input_fn(False, ids, filenames, labels))

def input_fn(is_training, ids, filenames, labels):
    id_tensor = tf.constant(ids, dtype=tf.string)
    filenames_tensor = tf.constant(filenames, dtype=tf.string)
    labels_tensor = tf.constant(labels, dtype=tf.float32)

    ds = tf.data.Dataset.from_tensor_slices(((id_tensor, filenames_tensor), labels_tensor))
    print(ds)
    ds = ds.take(int(len(labels) * SPLIT_PERCENTAGE)) if is_training else ds.skip(int(len(labels) * SPLIT_PERCENTAGE))
    ds = ds.map(load_image)

    iterator = ds.make_one_shot_iterator()
    features, labels = iterator.get_next()

    return features, labels

def load_image(id_file, score):
    _, filename = id_file
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=1)
    image_converted = tf.image.convert_image_dtype(image_decoded, tf.float16)
    image_resized = tf.image.resize_image_with_crop_or_pad(image_converted, 384, 384)

    return {'image': image_resized}, [tf.log(score)]

我怀疑这与我如何声明我的功能栏有关,但做的事情几乎完全一样,而且也有效。我错过了什么?

我去散步,买了一个漂亮的新枕头,喝了一杯啤酒,意识到
.batch()
调用
数据集中的
.batch(1)
数据是消费的必要条件