Python 如何在tensorflow数据集中加载numpy数组

Python 如何在tensorflow数据集中加载numpy数组,python,tensorflow,tensorflow-datasets,Python,Tensorflow,Tensorflow Datasets,我试图在tensorflow 1.14中创建一个Dataset对象(我有一些无法为这个特定项目更改的遗留代码),从numpy数组开始,但每次尝试都会将所有内容复制到图形上,因此,当我创建一个事件日志文件时,它非常大(在本例中为719 MB) 最初我尝试使用这个函数“tf.data.Dataset.from_tensor_slices()”,但它不起作用,然后我了解到这是一个常见问题,有人建议我尝试使用生成器,因此我尝试使用以下代码,但再次得到一个巨大的事件文件(719 MB) def fetch

我试图在tensorflow 1.14中创建一个Dataset对象(我有一些无法为这个特定项目更改的遗留代码),从numpy数组开始,但每次尝试都会将所有内容复制到图形上,因此,当我创建一个事件日志文件时,它非常大(在本例中为719 MB)

最初我尝试使用这个函数“tf.data.Dataset.from_tensor_slices()”,但它不起作用,然后我了解到这是一个常见问题,有人建议我尝试使用生成器,因此我尝试使用以下代码,但再次得到一个巨大的事件文件(719 MB)

def fetch_批(x,y,批):
i=0
而我:
收益率(x[i,:,:,:],y[i])
i+=1
train,test=tf.keras.dataset.fashion\u mnist.load\u data()
图像、标签=列车
图像=图像/255
training_dataset=tf.data.dataset.from_生成器(获取_批,
args=[图像,np.int32(标签),批处理大小],输出类型=(tf.float32,tf.int32),
输出形状=(tf.TensorShape(特征形状),tf.TensorShape(标签形状)))
file\u writer=tf.summary.FileWriter(“/content”,graph=tf.get\u default\u graph())
我知道在这种情况下,我可以使用tensorflow_数据集API,这会更容易,但这是一个更一般的问题,它涉及如何创建一般的数据集,而不仅仅是使用mnist。
你能解释一下我做错了什么吗?谢谢你

我想这是因为你在
中使用了
参数
,从\u生成器
。这肯定会将提供的
args
放在图中

您可以定义一个函数,该函数将返回一个生成器,该生成器将遍历您的集合,例如(尚未测试):

def数据_生成器(图像、标签):
def fetch_示例():
i=0
尽管如此:
示例=(图像[i],标签[i])
i+=1
i%=len(标签)
产量示例
返回示例
这将在您的示例中给出:

train,test=tf.keras.dataset.fashion\mnist.load\u data()
图像、标签=列车
图像=图像/255
training_dataset=tf.data.dataset。从_生成器(数据_生成器(图像、标签),输出_类型=(tf.float32,tf.int32),
输出形状=(tf.TensorShape(特征形状),tf.TensorShape(标签形状))。批次(批次大小)
file\u writer=tf.summary.FileWriter(“/content”,graph=tf.get\u default\u graph())

请注意,我将
fetch\u batch
更改为
fetch\u examples
,因为您可能希望使用数据集实用程序(
.batch
)进行批处理。

您能否更详细地解释一下是什么导致事件文件如此大?它是在创建重复的子图吗?你能解释一下从张量切片中看什么不起作用吗?是的,我想你是对的。我马上就要发布我是如何解决这个问题的(事实上我在github找到了一个建议这个的人),它对我起了作用谢谢!酷,如果这个解决方案有效,那么请接受它。另外,下次回答您的问题时,请尝试提供一个如果复制粘贴并包含版本号(通常对于tensorflow,API在1.14和2.0之间变化很大)就可以工作的代码。完成了,但还有一件事:“您将无法使用多处理”是什么意思?还有什么更有效的方法吗?实际上忘记我说过的,在这个阶段你不需要多重处理(只是获取数据),所以我完全困惑了。
def fetch_batch(x, y, batch):
    i = 0
    while i < batch:
        yield (x[i,:,:,:], y[i])
        i +=1

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train  
images = images/255

training_dataset = tf.data.Dataset.from_generator(fetch_batch, 
    args=[images, np.int32(labels), batch_size], output_types=(tf.float32, tf.int32), 
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape)))

file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())