반응형

1. Pruning 준비하기

Pruning(가지치기)는 모델 경량화 방법 중 하나로 파라미터를 줄여 모델의 경량화를 목표로 한다. 이전 을 통해 이론적 배경에 대해 살펴본 바 있었는데, 이번에는 PyTorch에서 제공하고 있는 도구를 활용해 Pruning을 적용하는 것을 연습해보고자 한다.



우선, model이라는 변수에 딥러닝 모델 인스턴스가 있음을 가정하고, 필요한 가지치기 클래스를 불러온다. 그리고 module이라는 이름으로 적용 대상인 하나의 층을 가져온다.

import torch.nn.utils.prune as prune

model = Model()
module = model.conv1

2. Pruning 적용하기

지정한 module에 대해서 torch.nn.utils.prune에서 제공하는 기법 중 하나를 선택하고, 모듈과 파라미터를 지정한다. 아래 예제는 module의 가중치 50%를 랜덤으로 가지치기 하는 방법을 선택했다.

prune.random_unstructured(module, name="weight", amount=0.5)

만약 가중치가 아니라 편향값에 Pruning을 수행하려면 아래와 같이 수행할 수 있다. 여기서는 L1 Norm 값이 가장 작은 편향값 3개를 가지치기를 시도하였다.

prune.l1_unstructured(module, name="bias", amount=3)

잘 적용되었는지 확인하려면, 아래 코드 중 하나로 결과값을 출력할 수 있다.

print(module.weight)
print(list(module.named_parameters()))

3.정리

Pruning 기법이 영구적으로 적용되게 하고 싶을 때 prune을 적용하면서 생성된 것을 제거하는 과정이 필요하다. 이를 위해서는 prune에 있는 remove를 사용한다.

prune.remove(module, 'weight')

참고자료

  1. PyTorch Pruning Tutorial
반응형