반응형

분산 데이터 병렬화(Distributed Data Parallel, DDP)은 효율적인 분산 및 병렬 처리를 위한 데이터 병렬화 방법론입니다. DDP는 PyTorch의 강력한 모듈로 여러 GPU에서 모델이 구동될 수 있도록 도와줍니다. DDP는 각 프로세스에서 모델의 복사본을 유지하며, 데이터를 병렬로 처리하고 각 GPU의 계산 결과를 자동으로 동기화하여 효율성을 극대화 합니다.


DDP 프로세스

환경 설정

DDP를 실행하기 위해 분산 학습을 위한 기본 환경 변수를 설정합니다. dist.init_process_group로 DDP 모듈을 초기화하는데, 이는 서로 다른 GPU 간 소통 및 동기화를 위해 필요합니다. 초기화에서 사용하는 값들의 주요의미는 다음과 같습니다. PyTorch에서 지원하는 분산처리를 위한 백엔드 종류는 여기에서 확인 가능합니다. 여기서는 gloo를 사용했습니다.

 

  • MASTER_ADDR: 마스터 프로세스의 IP 주소.
  • MASTER_PORT: 마스터 프로세스의 통신 포트.
  • RANK: 현재 프로세스의 순위(고유 ID).
  • WORLD_SIZE: 총 프로세스 수.
import os
import torch
import torch.distributed as dist



def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

모델 및 데이터 준비

모델을 준비한 뒤 DDP로 감싸줍니다. 또한 DistributedSampler를 사용해 데이터셋을 분산 학습에 맞게 나눕니다.

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler


model = MyModel().to(device)
model = DDP(model, device_ids=[rank])

data_sampler = DistributedSampler(dataset)
data_loader = DataLoader(dataset, sampler=data_sampler, batch_size=batch_size)

학습 루프 작성

DDP를 사용할 때는 일반적인 학습 루프와 유사하지만, 각 프로세스가 독립적으로 실행되며 DistributedSampler를 재설정해야 합니다:

for epoch in range(num_epochs):
    data_sampler.set_epoch(epoch)  # 각 epoch마다 샘플링 초기화
    for data, target in data_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

프로세스 종료

모든 학습이 끝난 후에는 분산 프로세스를 정리합니다.

dist.destroy_process_group()

최종 코드

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# 모델 정의
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# 분산 학습 함수
def main(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank
    )

    # 데이터 준비
    dataset = torch.utils.data.TensorDataset(torch.randn(100, 10), torch.randn(100, 1))
    sampler = DistributedSampler(dataset)
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    # 모델 준비
    model = MyModel().to(rank)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.MSELoss()

    # 학습 루프
    for epoch in range(10):
        sampler.set_epoch(epoch)
        for data, target in data_loader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2  # GPU 개수
    torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size)

참고자료

  1. PyTorch Distributed Data Parallel Tutorial
반응형