pytorch-data处理的方法

pytorch对数据处理的一些基本用法总结:

为了更好地分配数据,通常在训练中会用到pytorch的几个库,torchvision.transformstorchvision.dataset.Imagefoldertorch.util.data.Dataloader

用代码解释这三个函数的作用:

transforms:

transforms的作用一句话概括就是使得数据集里的数据统一化,比如对于图像数据,可能很多图像的尺寸不一样,需要对图像的大小进行裁剪和缩放,并对图像的大小尺寸进行统一。

tutorial中的几个参数:

  • Rescale: to scale the image

  • RandomCrop: to crop from image randomly. This is data augmentation.

  • ToTensor: to convert the numpy images to torch images (we need to swap axes).

    note: numpy中的图像数据是 H × W × C ; 而torch.tensor的数据是C × H × W

class Rescale(obejct)
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size

def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']

h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size

new_h, new_w = int(new_h), int(new_w)

img = transform.resize(image, (new_h, new_w))

# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]

return {'image': img, 'landmarks': landmarks}

_call_函数的作用是可以直接调用rescale这个类,不需要每次调用时都需要传递参数。 We will write them as callable classes instead of simple functions so that parameters of the transform need not be passed everytime it’s called.

transformer的用法:
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))

FaceLandmarksDataset是之前定义的一个类,transformed_dataset是它的实例化,等于对目录下的data经过了处理,保存到了transformed_dataset这个实例化的类中。

for i in range(len(transformed_dataset)):
sample = transformed_dataset[i]

print(i, sample['image'].size(), sample['landmarks'].size())

然后通过for 循环,用sample一个一个取出这个类中的数据。

输出如下:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

dataloader:

由此,这不是深度学习中常用的方式,

1.我们对数据集通常会进行打乱,不是顺序读取的。shffle data

2.我们会选择一个batch作为数据,而不是每次都train一个图像。batch data

3.用上述的代码不适合在多线程运行。muilt-processing

此题背景是有69张人脸图像+其面部轮廓的框图。我们使用上面的方法,那么就是先标准化处理每张图像,并且将框图和人脸融合,作为一个深度学习标准的数据集,这些是transform做的事情。接下来载入数据,我们想通过数据集载入到图像数据中

dataloader=torch.util.data.Dataloader{
transforms_dataset,batch_size=4,shuffle=true, num_workers=4
}

通过这样的方式,就解决了上述的三个问题。num_workers是读取数据的线程数目。

note: 补充一下这个函数的实现方式,有一个collate_fn这个函数,是核心,如果后面想用到其他数据里面,重写collate_fn函数,在dataloader=torch.util.data.Dataloader{collate_fn=collate_fn}来得到自己想要的加载数据的结果:

collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。

ImageFolder

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)

从torchvision中导入这两个函数,先写transforms定义了随机裁图片,变成tensor,进行标准化处理。

使用ImageFolder读取在train下面的图片,其输出如下: {'cat': 0, 'dog': 1} 这样的形式,将label和input做了分离 。

最后使用dataloader导入数据。


   转载规则


《pytorch-data处理的方法》 胡哲 采用 知识共享署名 4.0 国际许可协议 进行许可。