DataLoader基础用法

作者 : admin 本文共2783个字,预计阅读时间需要7分钟 发布时间: 2024-06-10 共2人阅读

DataLoader 是 PyTorch 中一个非常有用的工具,用于将数据集进行批处理,并提供一个迭代器来简化模型训练和评估过程。以下是 DataLoader 的常见用法和功能介绍:

基本用法

  1. 创建数据集
    首先,需要一个数据集。数据集可以是 PyTorch 提供的内置数据集,也可以是自定义的数据集。数据集需要继承 torch.utils.data.Dataset 并实现 __len____getitem__ 方法。

    import torch
    import torch.utils.data as Data
    
    class MyDataSet(Data.Dataset):
        def __init__(self, enc_inputs, dec_inputs, dec_outputs):
            self.enc_inputs = enc_inputs
            self.dec_inputs = dec_inputs
            self.dec_outputs = dec_outputs
    
        def __len__(self):
            return len(self.enc_inputs)
    
        def __getitem__(self, idx):
            return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
    
  2. 创建 DataLoader
    DataLoader 用于将数据集封装成批次,并提供一个迭代器来进行数据的加载。常见的参数包括数据集、批量大小、是否打乱数据、使用的进程数等。

    enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    
    dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    
  3. 迭代数据
    使用 DataLoader 的迭代器来访问批次数据。

    for batch in loader:
        enc_batch, dec_batch, output_batch = batch
        print(enc_batch)
        print(dec_batch)
        print(output_batch)
    

常见参数

  1. dataset

    • 数据集对象,必须继承 torch.utils.data.Dataset 类。
  2. batch_size

    • 每个批次的大小,默认为 1。
  3. shuffle

    • 是否在每个 epoch 开始时打乱数据,默认为 False
  4. num_workers

    • 使用多少个子进程来加载数据。0 表示数据将在主进程中加载。对于大型数据集,增加 num_workers 可以加快数据加载速度。
  5. drop_last

    • 如果设置为 True,则丢弃不能整除 batch_size 的最后一个不完整的批次。
  6. pin_memory

    • 如果设置为 True,DataLoader 将在返回前将张量复制到 CUDA 固定内存中。这对 GPU 训练有所帮助。

进阶用法

  1. 自定义 collate_fn

    • collate_fn 用于指定如何将多个样本合并成一个批次。默认情况下,DataLoader 将使用 default_collate,它会将相同类型的数据合并在一起。例如,所有张量数据将合并成一个张量。
    def my_collate_fn(batch):
        enc_inputs, dec_inputs, dec_outputs = zip(*batch)
        return torch.stack(enc_inputs, 0), torch.stack(dec_inputs, 0), torch.stack(dec_outputs, 0)
    
    loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
    
  2. 使用 Sampler

    • Sampler 用于指定如何抽样数据。PyTorch 提供了一些内置的采样器,如 RandomSamplerSequentialSampler
    from torch.utils.data.sampler import RandomSampler
    
    sampler = RandomSampler(dataset)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, sampler=sampler)
    

完整示例

import torch
import torch.utils.data as Data
class MyDataSet(Data.Dataset):
def __init__(self, enc_inputs, dec_inputs, dec_outputs):
self.enc_inputs = enc_inputs
self.dec_inputs = dec_inputs
self.dec_outputs = dec_outputs
def __len__(self):
return len(self.enc_inputs)
def __getitem__(self, idx):
return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
for batch in loader:
enc_batch, dec_batch, output_batch = batch
print("Encoder batch:", enc_batch)
print("Decoder batch:", dec_batch)
print("Output batch:", output_batch)

通过使用 DataLoader,我们可以轻松地处理和批量化我们的数据,这对于大型数据集和深度学习模型的训练是非常重要的。

本站无任何商业行为
个人在线分享 » DataLoader基础用法
E-->