筛选丹方上炉开炼的过程是少不了dataset和dataloader的,只不过我在早期炼丹的时候,由于那时候需要更加关心怎么出效果,怎么赶紧出活,没怎么在乎这个功能是怎么实现的。
曾经有无数个日与夜,我都是这样:
1 | import torch |
我好像从来没有在乎过,__getitem__
的输入idx是怎么来的,以及这个DataLoader是怎么凑够batchsize个样本的。只知道自己继承dataset写一个类,然后覆写它的几个魔法方法,然后再用torch里的dataloader封装一下,我就可以遍历数据集了。
这在大部分简单的任务里都行得通,但为了以后能掌控更加复杂的任务,以及看懂别人的代码,我们最好理解一下dataloader的实现。
实际上,在torch中,sampler和dataset是dataloader的两个部分。sampler用来生成索引,dataset是根据索引去读数据(图片,标签,序列等)。同时,我们知道,如果Python中的一个对象能被for…in…遍历,那它需要有__iter__
,__next__
方法,无论是其本身直接实现__next__
,还是内部间接返回(使用一些高级语法__getitem__
,yield
,或者返回一个有__next__
的类)。
所以我们现在从for…in…来观察一个整个过程,for…in…操作其实会被Python解释为:
1 | A = iter(A) |
所以如果我们打开torch中dataloader.py的实现,我们会找到其__iter__
方法:
1 | def __iter__(self) -> '_BaseDataLoaderIter': |
我们可以看到,函数注解->提示我们这个方法会返回一个_BaseDataLoaderIter
,我们跟着self._get_iterator()
往下找:
1 | def _get_iterator(self) -> '_BaseDataLoaderIter': |
我们可以发现这里是来确定单进程还是多进程的,我们这里就只分析单进程了。所以我们转入_SingleProcessDataLoaderIter(self)
。
注意,此时我们还都在DataLoader的类定义下,_SingleProcessDataLoaderIter(self)
是将DataLoader所有的对象作为参数传入了_SingleProcessDataLoaderIter()
的构造函数__init__
。而这个类其实是继承自_BaseDataLoaderIter()
的,所以这就是为什么函数注解告诉我们返回的是_BaseDataLoaderIter(object)
。
1 | class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): |
这个类也正是for…in…执行后,第一行的A=iter(A)
得到的返回值,下面我们会关心_SingleProcessDataLoaderIter
是否有__next__
方法,其本身没有定义__next__
,所以我们去找一下它的父类_BaseDataLoaderIter
。在其定义中我们可以看到:
1 | def __next__(self) -> Any: |
所以_SingleProcessDataLoaderIter
继承自_BaseDataLoaderIter
的__next__
,同时覆写了其中用到的_next_data()
,覆写的_next_data()
复用了父类的_next_index()
,其定义很简短:
1 | def _next_index(self): |
所以此时,我们大概勾勒出了运行时的模样,当我们for…in…一个dataloader时,首先会先给出一个迭代器_SingleProcessDataLoaderIter
。然后这个迭代器不断的调用next,调用时,_next_data
内部的_next_index
负责给出一个索引,然后会去抓取这个索引对应的数据。
我们已经有了不少的进展了,下面要继续看下去。
首先我们要关注_next_index
里的self._sampler_iter
,这个对象是在父类_BaseDataLoaderIter
初始化时产生的,与其相关的两行代码为:
1 | self._index_sampler = loader._index_sampler |
所以我们要进一步寻找这个self._index_sampler
的__iter__
方法,来搞清楚这个(至少是可迭代的)对象是怎么来的。我们不难发现,其定义是在DataLoader中的:
1 |
|
而至于self.batch_sampler
和self.sampler
,其在DataLoader的初始化中就设定了:
1 | if sampler is None: # give default samplers |
同时,上面这段代码也是为什么我们一般无需手动在给一个dataloader时,输入sampler的原因。我们一般不会给出sampler,更不会给出batch_sampler。我们一般都是给出batchsize和shuffle,根据上面的代码我们可以知道如果我们只给出batchsize和shuffle,会自动的将一个RandomSampler赋给self.sampler,然后将一个BatchSampler赋给self.batch_sampler。
例如在最简单的情形下的SequentailSampler
:
1 | class SequentialSampler(Sampler[int]): |
此时,iter(self._index_sampler)
就会返回iter(range(len(self.data_source)))
这个迭代器,所以next()
的结果就是0,1,2,…。
在Python中,range()本身并不是迭代器,它仅仅是一个可迭代对象,但iter(range)会自动返回一个迭代器。这个设计使得range()听起来很奇怪,但其实背后是有更深层的原因的。
现在我们明确了,当实例化一个dataloader时,会根据输入参数得到类内成员sampler。然后sampler会向下穿越到_SingleProcessDataLoaderIter
,来完成“得到索引”的使命。所以到这里,dataloader的两部分,sampler和dataset,我们已经解决其中之一了。
现在距离谜团彻底解开,还差_SingleProcessDataLoaderIter
里的那一行:
1 | self._dataset_fetcher = _DatasetKind.create_fetcher( |
这其实就是dataset的那一部分,我们跳转到_DatasetKind
:
1 | class _DatasetKind(object): |
我们发现它好像约定了两种dataset,Map和Iterable。令人震惊的是,大多数人(包括我)可能炼丹到现在都不一定用过,甚至不知道Iterable这种形式的数据集。(但其实这两种数据集也没那么大差异罢了)
如果你像文章最开头那样,每次都用torch.utils.data.Dataset来覆写的话,那你使用的就是标准的Map式数据集。这种方式是使用__getitem__
和__len__
来达到对数据集的随机访问(需要__len__
来保证访问是否越界),一般而言,这种设计方式鼓励的是在__init__
时将数据集整个都存入内存中,这样就可以快速的进行访问,减少IO时间。
但其实也并不用严格的将整个数据集都load进内存里,完全可以
_init
时候只给定路径,然后每__getitem__
的时候现读。这又何尝不是一种Iterable?
而Iterable的数据集,可能是在更加实际的场景下用的。它继承的是torch.utils.data.IterableDataset,每次只需要重写__iter__
就好啦:
1 | from torch.utils.data import IterableDataset |
由于其__iter__
会返回一个生成器,所以需要先将数据集iter()一下返回一个迭代器,这里由于生成器机制yield可以next(),所以如果读数据就疯狂next()就好了。
所以说回来,我们现在只需要看一下_utils.fetch
是怎么实现的就好了:
1 | class _BaseDatasetFetcher(object): |
可以看到,和我们之前的分析是一致的,如果是Iterable的数据集,就会self.dataset_iter = iter(dataset)
以后然后next,如果是Map的数据集,那就会直接用[]来访问。
最后的self.collate_fn
是一个用于“整理”的函数,如果去查阅其源码,其意思基本就是把各种形式的数据集强制转换为tensor。我们一般都不会用到这个功能,或许在自然语言处理里,用于处理长短不一的输入时,会重写一个自己的collate_fn
。
所以,我们的一个batch的数据是怎么来的呢?现在我们可以有了圆满的答案,注意在dataloader的类定义中,有一个函数:
1 |
|
而这个方法第一次被调用,是在dataloader的__init__
中:
1 | self.batch_size = batch_size |
如果我们设置了batchsize,那么此时的self.batch_sampler
就不为空,self._auto_collation
就为真。所以在后面的_indexIsampler
里,就会自动使用self.batch_sampler
,但其实BatchSampler只是对sampler的又一层封装:
1 | class BatchSampler(Sampler[List[int]]): |
我们可以看到,BatchSampler里的__iter__
,忠实的完成了凑够一个batchsize的数据的职责。
所以,最后的递归调用可以总结成:
多进程时的情况这里就不作分析了,涉及到一系列进程守护,队列之类的机制。毕竟写这篇blog只是为了理清另一个代码库里的几行代码罢了,总之,这就是作为dataloader,传奇的运行周期,它尽职尽责的将数据load进来,为纷至沓来的炼丹师编织一个又一个美好的梦境。
End
沉淀!