반응형

적용 방법

1. 모듈 추가

새로운 지표 클래스(CustomMetric)를 만들어 해당 클래스를 mmseg/evaluation/metrics/custom_metric.py에 두었다고 가정하겠습니다. 새로운 모듈을 추가하기 위해서는 registry에 정의되어 있는 클래스의 register_module을 통해 등하게 됩니다. (일반적으로 데코레이터를 통해 등록을 진행합니다)

 

from typing import List, Sequence
from mmengine.evaluator import BaseMetric
from mmseg.registry import METRICS


@METRICS.register_module()
class CustomMetric(BaseMetric):

    def __init__(self, arg1, arg2):
    	pass

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        pass

    def compute_metrics(self, results: list) -> dict:
        pass

    def evaluate(self, size: int) -> dict:
        pass

 

위 클래스는 BaseMetric이란 것을 상속받아서 만들어지게 되는데, 3가지 메서드(process, compute_metrics, evaluate)를 포함해야 하는데, 여기서는 이에 대한 설명은 생략하도록 하겠습니다. 자세히 알고 싶은 경우 참고문헌의 [1]에서 확인하시기 바랍니다. 

2. 사용자 지표 가져오기

새롭게 정의된 지표를 mmseg/evaluation/metrics/__init__.py에 아래와 같이 추가해줍니다.

 

from .custom_metric import CustomMetric
__all__ = ['CustomMetric', ...]

3. config 파일에 추가

OpenMMLab에서 만든 라이브러리를 사용하면 .py로 쓰여진 Configuration 파일을 만들어야 합니다. (예를 들어, config.py 처럼 말이죠) 해당 파일에 새롭게 정의된 사용자 지표를 아래와 같이 작성해서 config.py에 추가해줍니다. 

 

val_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
test_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)

 

또 다른 방법

위 접근방식은 MMSegmentation에서 제공하고 있는 BaseMetric 등 코드를 활용하는 방법입니다. 만약 다른 방식으로 지표를 정의하고 싶다면 아래와 같은 방식으로 만들 수도 있습니다. 먼저 custom_metric.py를 만듭니다. 여기서 경로는 임의로 'path/custom_metric.py'라고 하겠습니다.

 

from typing import List, Sequence
from mmseg.registry import METRICS


@METRICS.register_module()
class CustomMetric:

    def __init__(self, arg1, arg2):
    	pass

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        pass

    def compute_metrics(self, results: list) -> dict:
        pass

    def evaluate(self, size: int) -> dict:
        pass

 

정의된 custom_metric.py를 코드나 config 파일에 불러옵니다. 만약 코드에 직접불러온다면, 아래와 같이 import하면 됩니다.

 

from path import CustomMetric

 

또는 config 파일에 쓰려면 아래와 같이 써주면 됩니다.

 

custom_imports = dict(imports=['/Path/to/metrics'], allow_failed_imports=False)

val_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
test_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)

참고자료

[1] https://mmsegmentation.readthedocs.io/en/latest/advanced_guides/add_metrics.html

 

반응형