Python 计算torch.utils.data.DataLoader中数据对应的光流
我建立了一个CNN模型,用于PyTorch视频中的动作识别。我正在使用torch dataloader模块加载训练数据Python 计算torch.utils.data.DataLoader中数据对应的光流,python,pytorch,opencv-python,Python,Pytorch,Opencv Python,我建立了一个CNN模型,用于PyTorch视频中的动作识别。我正在使用torch dataloader模块加载训练数据 train_loader = torch.utils.data.DataLoader( training_data, batch_size=8, shuffle=True, num_workers=4, pin_memory=True) 然后通过火车装载机对
train_loader = torch.utils.data.DataLoader(
training_data,
batch_size=8,
shuffle=True,
num_workers=4,
pin_memory=True)
然后通过火车装载机
对模型进行训练
train_epoch(i, train_loader, action_detect_model, criterion, optimizer, opt,
train_logger, train_batch_logger)
现在我想添加一个额外的路径,它将采用视频帧的相应光流。为了计算光流,我使用了
cv2.calcOpticalFlowFarneback
。但问题是,我不确定如何获得与火车数据加载器张量中的数据相对应的图像,因为它们将被洗牌。我不想预先计算光流,因为存储需求将是巨大的(每帧需要600 kBs)。您必须使用自己的数据加载器类来动态计算光流。其思想是该类获得一个文件名元组列表(curr image,next image),其中包含视频序列的当前和下一帧文件名,而不是简单的文件名列表。这允许在填充文件名列表后获得正确的图像对。
以下代码为您提供了一个非常简单的示例实现:
from torch.utils.data import Dataset
import cv2
import numpy as np
class FlowDataLoader(Dataset):
def __init__(self,
filename_tuples):
random.shuffle(filename_tuples)
self.lines = filename_tuples
def __getitem__(self, index):
img_filenames = self.lines[index]
curr_img = cv2.cvtColor(cv2.imread(img_filenames[0]), cv2.BGR2GRAY)
next_img = cv2.cvtColor(cv2.imread(img_filenames[1]), cv2.BGR2GRAY)
flow = cv2.calcOpticalFlowFarneback(curr_img, next_img, ... [parameter])
# code for loading the class label
# label = ...
#
# this is a very simple data normalization
curr_img= curr_img.astype(np.float32) / 255
next_img = next_img .astype(np.float32) / 255
# you can return the image and flow seperatly
return curr_img, flow, label
# or stacked as follows
# return np.dstack((curr_img,flow)), label
# at this place you need a function that create a list of training sample filenames
# that look like this
training_filelist = [(img000.png, img001.png),
(img001.png, img002.png),
(img002.png, img003.png)]
training_data = FlowDataLoader(training_filelist)
train_loader = torch.utils.data.DataLoader(
training_data,
batch_size=8,
shuffle=True,
num_workers=4,
pin_memory=True)
这只是FlowDataLoader的一个简单示例。理想情况下,这应该扩展,以便当前图像输出包含标准化的rgb值,并且光流也被标准化和剪裁