网页的网站导航怎么做,龙岩市城乡规划建设局网站,页面设计所遵循的原则有哪些,做防腐木花架的网站本文不做太多原理介绍#xff0c;直讲使用流畅。想看更多底层实现-〉传送门。DataLoader简介torch.utils.data.DataLoader是PyTorch中数据读取的一个重要接口#xff0c;该接口定义在dataloader.py脚本中#xff0c;只要是用PyTorch来训练模型基本都会用到该接口。本文介绍t…本文不做太多原理介绍直讲使用流畅。想看更多底层实现-〉传送门。DataLoader简介torch.utils.data.DataLoader是PyTorch中数据读取的一个重要接口该接口定义在dataloader.py脚本中只要是用PyTorch来训练模型基本都会用到该接口。本文介绍torch.utils.data.DataLoader与torch.utils.data.Dataset结合使用的方法。torch.utils.data.DataLoader接收torch.utils.data.Dataset作为输入得到DataLoader它是一个迭代器方便我们去多线程地读取数据并且可以实现batch以及shuffle的读取等。torch.utils.data.Dataset这是一个抽象类所以我们需要对其进行派生从而使用其派生类来创建数据集。最主要的两个函数实现为__Len__和__getitem__。__init__可以在这里设置加载的data和label。__Len__获取数据集大小__getitem__根据索引获取一条训练的数据和标签。dataLoader的基本使用输入数据格式在使用torch.utils.data.DataLoader与torch.utils.data.Dataset前需要对自己的数据读取或者做一些处理比如我已经将我的文本数据读取到Dict里了格式如下(就放了两个例子list里面存储多个Dict一个Dict存的数据是我的一条样例)train_pair
[{id: bb6b40-en, question: paddy:rice, choices: [walnut:walnut crisp, cotton:cotton seed, watermelon:melon seeds, peanut:peanut butter], text: [question: paddy:rice. option: walnut:walnut crisp, question: paddy:rice. option: cotton:cotton seed, question: paddy:rice. option: watermelon:melon seeds, question: paddy:rice. option: peanut:peanut butter], label: 1}, {id: 1af9fe-en, question: principal:teacher, choices: [police:thief, manager:staff, teacher:student, doctor:nurse], text: [question: principal:teacher. option: police:thief, question: principal:teacher. option: manager:staff, question: principal:teacher. option: teacher:student, question: principal:teacher. option: doctor:nurse], label: 1
}]构造DataSet这里的功能是主要是设置加载的data和label获取数据集大小并根据索引获取一条训练的数据和标签。是为使用DataLoader作准备。class MyDataset(Dataset):def __init__(self, data_pairs): super().__init__()self.data data_pairsdef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]train_data MyDataset(train_pair)使用collate_fn在DataLoader基础上自定义自己的输出这里实现的东西很简单因为DataLoader会自动把多个Dict中的数据合并。但是如果我上面的text字段存储了ListList就会被合并成元组但是我在使用数据的时候希望所有的句子一起输入模型。这样我就可以自己定义一个函数去控制合并的操作。def my_collate(batch_line):batch_line deepcopy(batch_line)text []label []for line in batch_line:text.extend(line[text]) #我只使用这两个字段其他的可以不处理不输出label.append(line[label])batch {text:text,label:label,}return batchtrain_data_loader DataLoader(train_data, batch_sizeargs.batch_size, shuffleTrue, collate_fnmy_collate)这是自己控制合并操作的结果{text: [question: white pollution:biodegradation. option: industrial electricity:solar energy, question: white pollution:biodegradation. option: domestic water:reclaimed water recycling, question: white pollution:biodegradation. option: chinese herbal prescriptions:medical research,question: stone wall:earth wall. option: legal:illegal, question: stone wall:earth wall. option: riverway:waterway, question: stone wall:earth wall. option: new house:wedding room, question: kiln:ceramics. option: school:student, question: kiln:ceramics. option: oven:bread,], label: [3, 0]
}如果不自己处理直接使用train_data_loader DataLoader(train_data, batch_sizeargs.batch_size, shuffleTrue)输出结果就会变成{ id: [bb6b40-en,1af9fe-en], question: [paddy:rice, principal:teacher], choices:......(省略)text: (question: white pollution:biodegradation. option: industrial electricity:solar energy, question: white pollution:biodegradation. option: domestic water:reclaimed water recycling, question: white pollution:biodegradation. option: chinese herbal prescriptions:medical research,question: stone wall:earth wall. option: legal:illegal), (question: stone wall:earth wall. option: riverway:waterway, question: stone wall:earth wall. option: new house:wedding room, question: kiln:ceramics. option: school:student, question: kiln:ceramics. option: oven:bread), label: [3, 0]
}DataLoader的参数dataset (Dataset) – 加载数据的数据集。batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).sampler (Sampler, optional) – 定义从数据集中提取样本的策略即生成index的方式可以顺序也可以乱序num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。pin_memory (bool, optional) –设置pin_memoryTrue则意味着生成的Tensor数据最开始是属于内存中的锁页内存这样将内存的Tensor转义到GPU的显存就会更快一些。drop_last (bool, optional) – 如果数据集大小不能被batch size整除则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除则最后一个batch将更小。(默认: False)timeout是用来设置数据读取的超时时间的但超过这个时间还没读取到数据的话就会报错。参考材料pytorch中的数据导入之DataLoader和Dataset的使用介绍PyTorch源码解读之torch.utils.data.DataLoader