时间:2022-01-27 08:48:36 | 栏目:Python代码 | 点击:次
数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况。
torchvision.datasets中的数据集
torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法。

Dataset源码如上,可以看到其中包含了两个没有实现的子方法,之后所有的Dataet类都继承该类,并根据数据情况定制这两个子方法的具体实现。
因此当我们需要加载自己的数据集的时候也可以借鉴这种方法,只需要继承torch.utils.data.Dataset类并重写 init ,len,以及getitem这三个方法即可。这样组着的类可以直接作为参数传入到torch.util.data.DataLoader中去。
以CIFAR10为例 源码:
class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) ?C Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True. train (bool, optional) ?C If True, creates dataset from training set, otherwise creates from test set. transform (callable, optional) ?C A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop target_transform (callable, optional) ?C A function/transform that takes in the target and transforms it. download (bool, optional) ?C If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
加载自己的数据集
对于torchvision.datasets中有两个不同的类,分别为DatasetFolder和ImageFolder,ImageFolder是继承自DatasetFolder。
下面我们通过源码来看一看folder文件中DatasetFolder和ImageFolder分别做了些什么
import torch.utils.data as data
from PIL import Image
import os
import os.path
def has_file_allowed_extension(filename, extensions): //检查输入是否是规定的扩展名
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions)
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] //获取root目录下所有的文件夹名称
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))} //生成类别名称与类别id的对应Dictionary
return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions):
images = []
dir = os.path.expanduser(dir)// 将~和~user转化为用户目录,对参数中出现~进行处理
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)): //os.work包含三个部分,root代表该目录路径 _代表该路径下的文件夹名称集合,fnames代表该路径下的文件名称集合
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item) //生成(训练样本图像目录,训练样本所属类别)的元组
return images //返回上述元组的列表
class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (list[string]): A list of allowed extensions.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
"""
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
"""
根据index获取sample 返回值为(sample,target)元组,同时如果该类输入参数中有transform和target_transform,torchvision.transforms类型的参数时,将获取的元组分别执行transform和target_transform中的数据转换方法。
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self): //定义输出对象格式 其中和__str__的区别是__repr__无论是print输出还是直接输出对象自身 都是以定义的格式进行输出,而__str__ 只有在print输出的时候会是以定义的格式进行输出
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples
如果自己所要加载的数据组织形式如下
root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
即不同类别的训练数据分别存储在不同的文件夹中,这些文件夹都在root(即形如 D:/animals 或者 /usr/animals )路径下
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)
参数如下:
root (string) ?C Root directory path. transform (callable, optional) ?C A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop target_transform (callable, optional) ?C A function/transform that takes in the target and transforms it. loader ?C A function to load an image given its path. 就是上述源码中 __getitem__(index) Parameters: index (int) ?C Index Returns: (sample, target) where target is class_index of the target class. Return type: tuple
可以通过torchvision.datasets.ImageFolder进行加载
img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
)
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
对于所有的训练样本都在一个文件夹中 同时有一个对应的txt文件每一行分别是对应图像的路径以及其所属的类别,可以参照上述class写出对应的加载类
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
DataLoader解析
位于torch.util.data.DataLoader中 源代码
该接口的主要目的是将pytorch中已有的数据接口如torchvision.datasets.ImageFolder,或者自定义的数据读取接口转化按照
batch_size的大小封装为Tensor,即相当于在内置数据接口或者自定义数据接口的基础上增加一维,大小为batch_size的大小,
得到的数据在之后可以通过封装为Variable,作为模型的输出
_ _ init _ _中所需的参数如下
1. dataset torch.utils.data.Dataset类的子类,可以是torchvision.datasets.ImageFolder等内置类,也可是继承了torch.utils.data.Dataset的自定义类 2. batch_size 每一个batch中包含的样本个数,默认是1 3. shuffle 一般在训练集中采用,默认是false,设置为true则每一个epoch都会将训练样本打乱 4. sampler 训练样本选取策略,和shuffle是互斥的 如果 shuffle为true,该参数一定要为None 5. batch_sampler BatchSampler 一次产生一个 batch 的 indices,和sampler以及shuffle互斥,一般使用默认的即可 上述Sampler的源代码地址如下[源代码](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py) 6. num_workers 用于数据加载的线程数量 默认为0 即只有主线程用来加载数据 7. collate_fn 用来聚合数据生成mini_batch
使用的时候一般为如下使用方法:
train_data=torch.utils.data.DataLoader(...) for i, (input, target) in enumerate(train_data): ...
循环取DataLoader中的数据会触发类中_ _ iter __方法,查看源代码可知 其中调用的方法为 return _DataLoaderIter(self),因此需要查看 DataLoaderIter 这一内部类
class DataLoaderIter(object):
"Iterates once over the DataLoader's dataset, as specified by the sampler"
def __init__(self, loader):
self.dataset = loader.dataset
self.collate_fn = loader.collate_fn
self.batch_sampler = loader.batch_sampler
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
self.timeout = loader.timeout
self.done_event = threading.Event()
self.sample_iter = iter(self.batch_sampler)
if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn
self.index_queue = multiprocessing.SimpleQueue()
self.worker_result_queue = multiprocessing.SimpleQueue()
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
self.send_idx = 0
self.rcvd_idx = 0
self.reorder_dict = {}
base_seed = torch.LongTensor(1).random_()[0]
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
base_seed + i, self.worker_init_fn, i))
for i in range(self.num_workers)]
if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
torch.cuda.current_device()))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
else:
self.data_queue = self.worker_result_queue
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()