적용 방법
1. 모듈 추가
새로운 지표 클래스(CustomMetric)를 만들어 해당 클래스를 mmseg/evaluation/metrics/custom_metric.py에 두었다고 가정하겠습니다. 새로운 모듈을 추가하기 위해서는 registry에 정의되어 있는 클래스의 register_module을 통해 등하게 됩니다. (일반적으로 데코레이터를 통해 등록을 진행합니다)
from typing import List, Sequence
from mmengine.evaluator import BaseMetric
from mmseg.registry import METRICS
@METRICS.register_module()
class CustomMetric(BaseMetric):
def __init__(self, arg1, arg2):
pass
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
pass
def compute_metrics(self, results: list) -> dict:
pass
def evaluate(self, size: int) -> dict:
pass
위 클래스는 BaseMetric이란 것을 상속받아서 만들어지게 되는데, 3가지 메서드(process, compute_metrics, evaluate)를 포함해야 하는데, 여기서는 이에 대한 설명은 생략하도록 하겠습니다. 자세히 알고 싶은 경우 참고문헌의 [1]에서 확인하시기 바랍니다.
2. 사용자 지표 가져오기
새롭게 정의된 지표를 mmseg/evaluation/metrics/__init__.py에 아래와 같이 추가해줍니다.
from .custom_metric import CustomMetric
__all__ = ['CustomMetric', ...]
3. config 파일에 추가
OpenMMLab에서 만든 라이브러리를 사용하면 .py로 쓰여진 Configuration 파일을 만들어야 합니다. (예를 들어, config.py 처럼 말이죠) 해당 파일에 새롭게 정의된 사용자 지표를 아래와 같이 작성해서 config.py에 추가해줍니다.
val_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
test_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
또 다른 방법
위 접근방식은 MMSegmentation에서 제공하고 있는 BaseMetric 등 코드를 활용하는 방법입니다. 만약 다른 방식으로 지표를 정의하고 싶다면 아래와 같은 방식으로 만들 수도 있습니다. 먼저 custom_metric.py를 만듭니다. 여기서 경로는 임의로 'path/custom_metric.py'라고 하겠습니다.
from typing import List, Sequence
from mmseg.registry import METRICS
@METRICS.register_module()
class CustomMetric:
def __init__(self, arg1, arg2):
pass
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
pass
def compute_metrics(self, results: list) -> dict:
pass
def evaluate(self, size: int) -> dict:
pass
정의된 custom_metric.py를 코드나 config 파일에 불러옵니다. 만약 코드에 직접불러온다면, 아래와 같이 import하면 됩니다.
from path import CustomMetric
또는 config 파일에 쓰려면 아래와 같이 써주면 됩니다.
custom_imports = dict(imports=['/Path/to/metrics'], allow_failed_imports=False)
val_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
test_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
참고자료
[1] https://mmsegmentation.readthedocs.io/en/latest/advanced_guides/add_metrics.html
'Python > Data Analysis' 카테고리의 다른 글
MMSeg에서 새로운 증강 기법 만들기 | Custom Transforms (0) | 2024.11.27 |
---|---|
MMSeg에서 새로운 데이터셋 정의하기 | Custom Dataset (0) | 2024.11.26 |
앙상블 기법을 딥러닝에서 활용할 수 있게 pytorch로 구현하기 | Voting, Bagging, Stacking (0) | 2024.11.11 |
PIL을 활용한 이미지 특징(Image attributes) 추출하기 | 이미지 데이터 EDA (0) | 2024.09.25 |
인공신경망에서 그래디언트 손실 및 폭주 문제 해결 | 활성화 함수, 가중치 초기화 (2) | 2024.08.31 |