반응형

딥러닝을 위해 데이터를 불러오거나 전처리하는 방법을 매번 작성하는 것은 비효율적이고 반복적인 작업이 될 수 있습니다. PyTorch에서는 torch.utils.data를 통해 다양한 클래스를 제공하고 있으며, 제공된 클래스를 적절히 활용하면 효율적이고 유연한 데이터 파이프라인을 구축할 수 있습니다. 

Dataset 

torch.utils.data.Dataset은 키 -> 데이터 샘플로 매핑되는 모든 데이터셋을 표현하기 위해 Dataset을 상속받아 사용합니다. 데이터를 초기화하는 __init__ 메서드, 데이터 크기를 반환하는 __len__ 메서드, 특정 인덱스의 데이터 샘플을 반환할 수 있도록 하는 __getitems__ 메서드를 구현할 수 있습니다.  

 

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self): # 데이터셋의 총 샘플 수를 반환
        return len(self.data)

    def __getitem__(self, idx): # 주어진 인덱스에서 데이터와 레이블을 반환
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

 

위 예시에 대해 아래와 같이 생성해보고, 결과를 출력해보겠습니다. 먼저 예제로 사용할 데이터셋을 생성하고 가장 첫번째에 있을 데이터를 출력해보겠습니다.

 

data = torch.randn(100, 3) 
labels = torch.randint(0, 2, (100,))

 

 

그렇다면 만들어본 클래스가 잘 작동되는지 출력해보겠습니다. __len__는 len()함수를 통해 확인하고, __getitem__은 [index]를 통해 접근해보겠습니다. 출력 결과를 보니, 문제가 없이 잘 구현된 것을 확인했습니다.  

 

dataset = MyDataset(data, labels)
print(len(dataset))  # output: 100

sample, label = dataset[0]
print(sample, label)

 

DataLoader

torch.utils.data.DataLoader는 데이터를 배치(batch: 데이터를 건마다 처리하는 것이 아닌, 한 번에 처리되는 데이터의 묶음) 단위로 처리하고 병렬작업을 지원해 속도를 높입니다. 앞에서 만든 Dataset의 인스턴스를 감싸 배치 크기에 맞춰 나눠주고, 특정 순서에 의존하지 않도록 섞어주는 기능(shuffle : 훈련에서는 True / 테스트는 False)도 제공합니다. 또한, 필요할 경우 여러 스레드를 사용해 데이터를 병렬로 로드할 수 있도록 기능(num_workers)도 제공하고 있습니다. 외에도 다양하게 사용할 수 있으니 아래 링크를 통해 학습하면 좋습니다.

 

https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

 

torch.utils.data — PyTorch 2.4 documentation

torch.utils.data At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for These options are configured by the constructor arguments of a DataLoader, which has si

pytorch.org

'

DataLoader는 별도 클래스를 구성할 것 없이 만들어진 것을 활용하면 됩니다. 

 

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch_data, batch_labels in dataloader:
    print(batch_data.shape, batch_labels.shape)

 

반응형