高性能 PyTorch 训练 (2):Dataset

HowBoring 2020-11-16 02:10:04 13301

PyTorch 数据封装

PyTorch 为我们提供了两个类型 DatasetDataLoader,前者负责创建可被 PyTorch 使用的数据集,而后者负责向训练过程传递数据。

如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。

Dataset

Dataset 是一个抽象类,其完整调用路径是 torch.utils.data.Dataset。自定义的 Dataset 需要继承它,并实现两个成员魔术方法:

  1. __getitem__()
  2. __len()__()

而其中 __getitem__ 更需要根据情况灵活地进行编写,例如

from PIL import Image

def __getitem__(self, index):
    img_path, label = self.data[index].img_path, self.data[index].label
    img = Image.open(img_path)

    return img, label

只要以标准形式返回一个包含图像和对应标签的元组就可以了。

另一个 __len__ 返回数据集包含的数据量:

def __len__(self):
    return len(self.data)

另外,PyTorch 也提供了一些实用的 transformer,包含在 torchvision.transforms 中。常用的有 ResizeRandomCropNormalizeToTensor 等等。TorchVision 是 PyTorch 的额外组件,提供了 CV 方面的一些工具包。

DatasetDataLoader 实例化的一个参数。例如,CIFAR10 是图像分类、目标检测任务中的一个常用数据集,也是 CV 领域常见的标准 benchmark。我们经常能够在开源的模型代码中见到:

import torchvision.datasets as datasets

train_set = datasets.CIFAR10("data", transform=train_transform, train=True, download=True)

torchvision.datasets 中包含了常用的数据集。datasets.CIFAR10Dataset 的一个子类。

如果需要使用自己的数据作为数据集,除了继承 Dataset,也可以使用 ImageFolder 来构建:

my_dataset = datasets.ImageFolder('path/to/data', trasform=data_transform)

DataLoader

DataLoader 的初始化参数列表如下:

  1. dataset:要从中加载数据的数据集。
  2. batch_size:每个批要装载多少样本数据。
  3. shuffle:设置为 True 可以在每个 epoch 重新洗牌数据。
  4. sampler:定义从数据集中提取样本的策略。
  5. batch_sampler:与 sampler 功能类似,但一次返回一批索引。
  6. num_worker:要使用多少子进程装载数据。“0”表示数据将在主进程中加载。
  7. collate_fn:将一组样本合并成一个小批张量。在从字典样式的数据集进行批加载时使用。
  8. pin_memory:如果为TrueDataLoader 将把 Tensor 复制到CUDA固定内存中,然后返回它们。
  9. drop_last:如果数据集大小不能被批大小整除,则设置为 True 以删除最后一个不完整的批。如果 False 和数据集的大小不能被批大小整除,那么最后的批会更小。

可以看到,主要的参数就是 dataset 以及 batch_size

Sampler

这里带来了另一个新的概念,就是 Sampler。Dataset、DataLoader 以及 Sampler 的关系大概可以用以下的图表示:

可以参考 DataLoader.__next__ 的源码来方便我们理解整个的工作流程:

class DataLoader(object):
    ...

    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
        return batch

假设我们的数据是一组图像,每一张图像对应一个 index,那么如果我们要读取数据就只需要对应的 index 即可,即上面代码中的 indices,而选取 index 的方式有多种,有按顺序的,也有乱序的,所以这个工作需要 Sampler 完成。在拿到 index 之后,就可以依此在 Dataset 中读取相应的数据和标签。

在上文中 DataLoader 的初始化参数中可以看到里有两种 sampler:samplerbatch_sampler,都默认为None。前者的作用是生成一系列的 index,而 batch_sampler 则是将 sampler 生成的 indices 打包分组,得到一个又一个 batch 的 index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分组。

>>>  in: list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>> out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch 中已经实现的 Sampler 有如下几种:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

所有采样器其实都继承自同一个父类,即Sampler。只要定义好 __iter__ 函数即可实现自定义的 sampler。

另外 BatchSampler 与其他 sampler 的主要区别是它需要将 Sampler 作为参数进行打包,进而每次迭代返回以 batch size 为大小的 index 列表。也就是说在后面的读取数据过程中使用的都是 batch sampler。

声明:本文内容由易百纳平台入驻作者撰写,文章观点仅代表作者本人,不代表易百纳立场。如有内容侵权或者其他问题,请联系本站进行删除。
红包 2109 54 评论 打赏
评论
0个
内容存在敏感词
手气红包
    易百纳技术社区暂无数据
相关专栏
置顶时间设置
结束时间
删除原因
  • 广告/SPAM
  • 恶意灌水
  • 违规内容
  • 文不对题
  • 重复发帖
打赏作者
易百纳技术社区
HowBoring
您的支持将鼓励我继续创作!
打赏金额:
¥1易百纳技术社区
¥5易百纳技术社区
¥10易百纳技术社区
¥50易百纳技术社区
¥100易百纳技术社区
支付方式:
微信支付
支付宝支付
易百纳技术社区微信支付
易百纳技术社区
打赏成功!

感谢您的打赏,如若您也想被打赏,可前往 发表专栏 哦~

举报反馈

举报类型

  • 内容涉黄/赌/毒
  • 内容侵权/抄袭
  • 政治相关
  • 涉嫌广告
  • 侮辱谩骂
  • 其他

详细说明

审核成功

发布时间设置
发布时间:
是否关联周任务-专栏模块

审核失败

失败原因
备注
拼手气红包 红包规则
祝福语
恭喜发财,大吉大利!
红包金额
红包最小金额不能低于5元
红包数量
红包数量范围10~50个
余额支付
当前余额:
可前往问答、专栏板块获取收益 去获取
取 消 确 定

小包子的红包

恭喜发财,大吉大利

已领取20/40,共1.6元 红包规则

    易百纳技术社区