批量输入到tensorflow中的某一层

批量输入到tensorflow中的某一层,tensorflow,Tensorflow,我正在一个基于inception-v3的网络上工作。我成功地训练了这个网络,现在我想将一批opencv图像提供给我的网络,并获得一些输出。 网络的原始占位符接受一个字符串并将其解码为jpg,但我使用opencv读取视频帧,并在nparray列表中转换它们: for cnt in range(batch_size): frameBuffer = [] if (currentPosition >= nFrames): break

我正在一个基于inception-v3的网络上工作。我成功地训练了这个网络,现在我想将一批opencv图像提供给我的网络,并获得一些输出。 网络的原始占位符接受一个字符串并将其解码为jpg,但我使用opencv读取视频帧,并在
nparray
列表中转换它们:

  for cnt in range(batch_size):
        frameBuffer = []
        if (currentPosition >= nFrames):
            break
        ret, frame = vidFile.read()
        img_data = np.asarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frameBuffer.append(img_data)
        currentPosition += multiplier
如果我想处理单个图像,因为我直接从opencv读取帧,我将它们转换为np数组,然后将其馈送到inception网络的“Cast:0”层:

pred = sess.run([predictions], {'Cast:0': img_data})
到目前为止,结果还可以。但我想输入一批帧:我尝试以当前方式使用
feed\u dict

images = tf.placeholder(tf.float32, [batch_size,width,height, 3])
image_batch = tf.stack(frameBuffer)

feed_dict = {images: image_batch}
avgRepresentation, pred = sess.run([pool_avg, predictions],{'Cast:0': feed_dict})
但是我犯了错误;我知道我给这批货喂料时出错了。你有什么建议我如何将一批图像传送到网络的某一层吗?

你的传送目录(至少)有一个问题:传送目录通常是一个字典,以张量或字符串(表示张量名称)为键,值(按常规类型、np数组等给出)

这里使用的是
{'Cast:0':feed_dict}
,因此字典的值本身就是一个字典,这对tensorflow没有任何意义。您需要将值放在那里,即图像的串联(解码、转换等)。另外,如果我遗漏了什么,我也很抱歉,但是我想
frameBuffer
应该包含批处理的所有图像,因此它应该在
for
循环之外初始化

此代码应适用于:

frameBuffer = []
for cnt in range(batch_size):
        if (currentPosition >= nFrames):
            break
        ret, frame = vidFile.read()
        img_data = np.asarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frameBuffer.append(img_data)
        currentPosition += multiplier
avgRepresentation, pred = sess.run([pool_avg, predictions],{'Cast:0': np.asarray(frameBuffer)})

你能告诉我们你得到的错误吗?pred=sess.run([predictions],{'Cast:0':feed_dict})文件“/usr/local/lib/python2.7/dist packages/tensorflow/python/client/session.py”,第778行,在run\u metadata\u ptr)文件“/usr/local/lib/python2.7/dist packages/tensorflow/python/client/session.py”,第954行,在运行np\u val=np.asarray中(subfeed\u val,dtype=subfeed\u dtype)文件“/usr/local/lib/python2.7/dist packages/numpy/core/numeric.py”,asarray返回数组(a,dtype,copy=False,order=order)类型错误:float()参数必须是字符串或数字