딥러닝을 위해 데이터를 불러오거나 전처리하는 방법을 매번 작성하는 것은 비효율적이고 반복적인 작업이 될 수 있습니다. PyTorch에서는 torch.utils.data를 통해 다양한 클래스를 제공하고 있으며, 제공된 클래스를 적절히 활용하면 효율적이고 유연한 데이터 파이프라인을 구축할 수 있습니다.
Dataset
torch.utils.data.Dataset은 키 -> 데이터 샘플로 매핑되는 모든 데이터셋을 표현하기 위해 Dataset을 상속받아 사용합니다. 데이터를 초기화하는 __init__ 메서드, 데이터 크기를 반환하는 __len__ 메서드, 특정 인덱스의 데이터 샘플을 반환할 수 있도록 하는 __getitems__ 메서드를 구현할 수 있습니다.
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self): # 데이터셋의 총 샘플 수를 반환
return len(self.data)
def __getitem__(self, idx): # 주어진 인덱스에서 데이터와 레이블을 반환
sample = self.data[idx]
label = self.labels[idx]
return sample, label
위 예시에 대해 아래와 같이 생성해보고, 결과를 출력해보겠습니다. 먼저 예제로 사용할 데이터셋을 생성하고 가장 첫번째에 있을 데이터를 출력해보겠습니다.
data = torch.randn(100, 3)
labels = torch.randint(0, 2, (100,))
그렇다면 만들어본 클래스가 잘 작동되는지 출력해보겠습니다. __len__는 len()함수를 통해 확인하고, __getitem__은 [index]를 통해 접근해보겠습니다. 출력 결과를 보니, 문제가 없이 잘 구현된 것을 확인했습니다.
dataset = MyDataset(data, labels)
print(len(dataset)) # output: 100
sample, label = dataset[0]
print(sample, label)
DataLoader
torch.utils.data.DataLoader는 데이터를 배치(batch: 데이터를 건마다 처리하는 것이 아닌, 한 번에 처리되는 데이터의 묶음) 단위로 처리하고 병렬작업을 지원해 속도를 높입니다. 앞에서 만든 Dataset의 인스턴스를 감싸 배치 크기에 맞춰 나눠주고, 특정 순서에 의존하지 않도록 섞어주는 기능(shuffle : 훈련에서는 True / 테스트는 False)도 제공합니다. 또한, 필요할 경우 여러 스레드를 사용해 데이터를 병렬로 로드할 수 있도록 기능(num_workers)도 제공하고 있습니다. 외에도 다양하게 사용할 수 있으니 아래 링크를 통해 학습하면 좋습니다.
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
torch.utils.data — PyTorch 2.4 documentation
torch.utils.data At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for These options are configured by the constructor arguments of a DataLoader, which has si
pytorch.org
'
DataLoader는 별도 클래스를 구성할 것 없이 만들어진 것을 활용하면 됩니다.
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch_data, batch_labels in dataloader:
print(batch_data.shape, batch_labels.shape)
'Python > Data Prep' 카테고리의 다른 글
[OD] 객체 탐지(Object Detection) 대표 데이터 포맷 공부 | COCO, Pascal VOC, YOLO (0) | 2024.10.04 |
---|---|
DataLoader에서 오류가 난다면 누락 데이터가 있는지 확인 필요 | DataLoader는 이터레이터 (2) | 2024.09.26 |
Colab에서 Kaggle 데이터셋 가져오기 | Kaggle, API, Colab (0) | 2024.08.10 |
Selenium 을 활용한 Element 찾기 (find_element, By) | Python, Web Scraping, Web Crawling, 자동화 (0) | 2022.08.09 |
Selenium을 활용한 지자체 선거 당선인 데이터 가져오기 | Web Scraping (0) | 2021.10.16 |