FastAPI 앱을 실행하거나 종료할 때 로직을 넣고 싶을 경우 사용합니다. 예를 들어, 앱이 시작할 때 머신런이 모델을 로드하고 앱을 종료하면 db 연결을 정리하면서 각각 상황마다 출력하도록 구현할 수 있습니다. 아래 코드는 예제를 구현한 것으로 기본적인 FastAPI 구조는 아래 글을 참고해보시기 바랍니다.
앱이 커짐에 따라 get, post가 많아질텐데 이를 관리하기 위해 API Router를 활용할 수 있습니다. @app.get, @app.post을 사용하지 않고 별도의 router 파일을 따로 설정하고 app에 가져와서 사용하는 방식입니다. 다음 예제는 user라는 라우터를 별도로 만들어 app에 연결하는 코드입니다. 우선은 user.py로 만드는 코드입니다.
다음으로 저희가 만든 앱에 해당 라우터를 불러옵니다. 같은 디렉토리에 user.py가 있다고 가정하고 import해온 후에 이를 FastAPI 앱에 include_router를 통해 추가해줍니다. 역서 참고로 include_router는 스크립트가 실행되는 __name__ == '__main__'에 들어가게 되면 app에 라우터가 포함되지 않은 상태로 정의되기 때문에 제대로 처리하지 못합니다.
from fastapi import FastAPI
import uvicorn
import user
app = FastAPI()
app.include_router(user.user_router)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)
아래와 같이 잘 작동하고 있는 것을 확인할 수 있습니다.
Error Handler
에러가 발생한 경우 관련 기록을 남기고, 메시지를 클라이언트에 보내는 것은 애플리케이션을 운영하는 관점에서 필요한 일입니다. FastAPI에서는 이러한 에러 처리를 위해 HTTP Exception을 통해 가능하도록 하고 있습니다. 아래는 그 예시 코드입니다.
from fastapi import HTTPException
@app.get("/v/{item_id}")
async def find_by_id(item_id: int):
try :
item = items[items_id]
except KeyError:
raise HTTPException(status_code=404,
detail=f"아이템을 찾을 수 없습니다. id : {item_id}")
return item
아래 이미지를 보면 제대로 입력된 경우(왼쪽) 원하는 값을 반환하지만, 사전에 정의되지 않은 오류 발생의 경우(오른쪽) 에러 메시지를 띄우는 것을 확인할 수 있습니다.
Background Task
여러 작업 중 시간이 오래 소요되는 것들에 대해서는 백그라운드로 실행할 수 있습니다. 다음은 주어진 대기 시간(wait_time) 동안 작업을 수행하고 생성한 작업 결과를 저장해 필요할 때 조회하는 코드입니다.
from time import sleep
from fastapi import BackgroundTasks
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
# 입력 데이터 모델
class TaskInput(BaseModel):
id_: UUID = Field(default_factory=uuid4)
wait_time: int
# 대기 및 작업 저장
def cpu_bound_task(id_ : UUID, wait_time: int):
sleep(wait_time)
result = f"task done after {wait_time}"
task_repo[id_] = result
# 비동기 작업
@app.post("/task", status_code=202) # 비동기 작업에 대해서 status code 202를 반환
async def create_task_in_background(task_input: TaskInput, background_tasks: BackgroundTasks):
background_tasks.add_task(cpu_bound_task, id_=task_input.id_, wait_time=task_input.wait_time)
return "ok"
# 작업 결과 조회
@app.get("/task/{task_id}")
def get_task_result(task_id: UUID):
try:
return task_repo[task_id]
except KeyError:
return None
참고자료
[1] 변성윤. "[Product Serving] Fast API (2)". boostcamp AI Tech.
Pydantic은 데이터의 유효성을 검증할 수 있는 대표적인 라이브러리입니다. 파이썬의 타입 힌트를 활용해 데이터의 유효성을 검사하고 오류 메시지를 제공하는데 사용됩니다. JSON과 쉽게 연동되기 때문에 FastAPI 같은 웹 프레임워크와 사용해 API의 입력 데이터와 출력 데이터를 검증하는 데 유용합니다.
일반적인 사용 사례는 다음과 같습니다.
from pydantic import BaseModel
class User(BaseModel):
name: str
age: int
# 유효한 데이터 생성
user1 = User(name="John", age=30)
print(user1)
# 유효하지 않은 데이터 생성
try:
user2 = User(name="Jane", age="twenty")
except Exception as e:
print(e)
위 코드에서 User 클래스는 pydantic의 BaseModel을 상속받아 name과 age 속성을 정의합니다. user1은 유효한 데이터를 생성하지만, user2는 age가 문자열이기 때문에 오류가 발생합니다. 위 사례는 조금 단순한 사례이기 때문에 조금 더 복잡한 사례로 연습해보겠습니다.
Case1 : Web Input Validation
웹을 구축하면서 요청과 응답에 대해 크게 3가지를 체크하는 것을 생각해보겠습니다.
올바른 url 입력
1~10 사이의 정수 입력
올바른 폴더 이름 입력
python 버전 3.7 이상을 사용한다면 dataclasses라는 라이브러리를 활용할 수 있습니다. 라이브러리 내의 dataclass를 데코레이터를 사용해 init 메서드로 별도로 정의할 필요가 없고, __post_init__ 메서드로 검증을 수행하는 로직을 생성 시점에서 수행하게끔 합니다.
from dataclasses import dataclass
class ValidationError(Exception):
pass
@dataclass
class ModelInput:
url : str
rate : int
target_dir : str
def _validate_url(self, url: str) -> bool:
from urllib.parse import urlparse
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except:
return False
def _validate_directory(self, dir: str) -> bool:
import os
return os.path.isdir(dir)
def validate(self) -> bool:
validation_results = [
self._validate_url(self.url),
1 <= self.rate <= 10,
self._validate_directory(self.target_dir)
]
return all(validation_results) # 모두 True일 때 True 반환
def __post_init__(self):
if not self.validate():
raise ValidationError("Incorrect input")
if __name__ == '__main__':
try:
dataclasses_test = ModelInput(**INPUT) # INPUT은 딕셔너리 형태의 입력 데이터
except ValidationError as exc:
print('Error : ', exc.json())
pass
원하는 데로 파이썬의 저수준에서 검증 로직을 만든다는 점에서는 좋으나, 검증 로직을 쌓아가야 하기 때문에 코드가 길어질 수 있습니다. 이러한 부분을 보완하기에 적합한 것이 pydantic 입니다. 이미 만들어져 있는 클래스와 함수들을 가져다 씀으로 동일한 기능을 훨씬 짧은 코드로 구현이 가능합니다.
from pydantic import BaseModel, HttpUrl, Field, DirectoryPath, ValidationError
class ModelInput:
url : HttpUrl
rate : int = Field(ge=1, le=10)
target_idr : DirectoryPath
if __name__ == '__main__':
try:
pydantic_test = ModelInput(**INPUT)
except ValidationError as exc:
print('Error : ', exc.json())
pass
Case2 : Config 관리
Config란 Configuration(환경 설정)의 약자로 코드에 필요한 여러 변수들을 저장해두고 사용하는 것을 말합니다. 이를 위해 코드 내에서 활용하거나 yaml과 같은 파일을 만들어서 읽어주거나 pydantic을 활용할 수 있습니다.
다음 예제는 pydantic을 사용해서 애플리케이션의 설정을 관리하고 오버라이드(덮어 쓰기)하는 예제입니다. 우선 기본적인 환경 설정을 정의합니다. 과거에는 pydantic에 있었으나 pydantic-settings로 옮겨진 BaseSettings 클래스를 상속받습니다.
from pydantic import Field
from pydantic_settings import BaseSettings
from enum import Enum
class ConfigEnv(str, Enum):
DEV = "dev"
PROD = "prod"
class DBConfig(BaseSettings):
host: str = Field(default="localhost", env="db_host")
port: int = Field(default=3306, env="db_port")
username: str = Field(default="user", env="db_username")
password: str = Field(default="user", env="db_password")
database: str = Field(default="dev", env="db_database")
class AppConfig(BaseSettings):
env: ConfigEnv = Field(default="dev", env="env")
db: DBConfig = DBConfig()
Field 함수는 모델의 필드에 메타데이터를 추가하거나 바꾸는 기능을 수행합니다. 여기서는 default와 env 인자를 사용했는데, 각각 다음과 같이 정리할 수 있습니다.
default : 해당 필드의 기본 값을 지정. host : str = Field(default="localhost")는 host 필드의 기본값이 localhost 임.
env : 해당 필드의 환경 변수로 지정해서 해당 변수에서 값을 가져오도록 설정.
전체 애플리케이션에 대한 기본 설정을 담은 yaml 파일이 있다고 가정하고, 이를 로드해 AppConfig 클래스에 전달합니다. 그리고 이를 검증하는 과정을 거칩니다.
with open("dev_config.yaml", "r") as f:
config = load(f, FullLoader)
config_with_pydantic = AppConfig(**config)
assert config_with_pydantic.env == "dev"
assert config_with_pydantic.db.model_dump() == expected
만약 환경 변수로 설정 오버라이딩을 원하는 경우 아래와 같이 수정할 수 있습니다. 필요하다면 검증도 추가적으로 할 수 있습니다.
그리고 uvicorn을 통해 파이썬 파일을 실행시킵니다. uvicorn은 Asynchronous Server Gateway Interface(ASGI)라 불리는 비동기 코드를 처리할 수 있는 파이썬 웹 서버인 프레임워크 간 표준 인터페이스 입니다. 실행하는 방법은 아래와 같이 CLI 입력하면 됩니다. 만약 아래처럼 입력하지 않고 코드 하단에 uvicorn.run을 추가하면 됩니다.
구글 폼과 같이 데이터를 입력받기 위해 Form 클래스를 사용하면, 요청하는 형식의 함수인 Request의 데이터를 가져올 수 있습니다. GET을 통해 해당 페이지를 조회하고 입력받은 데이터를 보내는 POST 메서드를 각각 구현합니다.
또한, 프론트의 구현을 위해 Jinja2를 활용했는데 템플릿을 활용해 로그인 폼을 구현합니다. Jinja2Templates에서는 {{}}에 들어있는 데이터를 사용할 수 있게 됩니다. (여기서는 login_form.html은 있다고 가정) 추가적으로 덧붙인다면, 아래와 같은 코드 구현을 위해서는 사전에 Jinja2와 python-multipart 설치가 필요합니다.
파일을 업로드하고 싶다면 File과 UploadFile을 이용할 수 있습니다. main 함수에서는 입력받을 수 있는 요소를 넣기 위해 HTML을 입력해 HTMLResponse를 통해 구현했습니다. 그리고 파일을 업로드하면 각각 사이즈와 파일명을 반환할 수 있는 함수들입니다.
from typing import List
from fastapi import File, UploadFile
from fastapi.responses import HTMLResponse
@app.post("/files/")
def create_files(files: List[bytes] = File(...)):
return {"file_sizes": [len(file) for file in files]}
@app.post("/uploadfiles/")
def create_upload_files(files: List[UploadFile] = File(...)):
return {"filename": [file.filename for file in files]}
@app.get("/")
def main():
content = """
<body>
<form action="/files/" enctype="multipart/form-data" method="post">
<input name="files" type="file" multiple>
<input type="submit">
</form>
<form action="/uploadfiles/" enctype="multipart/form-data" method="post">
<input name="files" type="file" multiple>
<input type="submit">
</form>
</body>
"""
return HTMLResponse(content=content)
참고자료
[1] 변성윤. "[Product Serving] Fast API (1)". boostcamp AI Tech.
MMCV에서 정의된 여러 변형 기법들이 있지만[1], 사용하기에 적합하지 않거나 새로운 기법을 적용하고 싶을 수 있습니다. 이럴 경우 새로운 사용자 정의 변형(transformation) 기법을 만들어 등록하는 과정을 거칩니다. 아래 예시는 프로젝트에서 utils라는 하위 디렉토리를 만들고 그 안에 function.py를 만든 후 RandomSharpen이라는 transformation 클래스를 만들었습니다.
새로운 사용자 정의 데이터셋을 mmseg/datasets/example.py로 만듭니다. 새로운 모듈을 추가하기 위해서는 registry에 정의되어 있는 클래스의 register_module을 통해 등록합니다. 여기서는 MMSegmentation에서 제공하고 있는 BaseSegDataset 클래스를 상속받아 작성하는 것을 가정했습니다.
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class ExampleDataset(BaseSegDataset):
METAINFO = dict(
classes=('xxx', 'xxx', ...),
palette=[[x, x, x], [x, x, x], ...])
def __init__(self, arg1, arg2):
pass
2. 불러오기
새롭게 정의된 데이터셋을 mmseg/datasets/__init__.py에 아래와 같이 추가해줍니다.
from .example import ExampleDataset
3. config 파일에 추가
새롭게 만들어진 데이터셋에 대한 config를 configs/_base_/datasets/example_dataset.py에 아래와 같이 작성해서 추가해줍니다.
새로운 지표 클래스(CustomMetric)를 만들어 해당 클래스를 mmseg/evaluation/metrics/custom_metric.py에 두었다고 가정하겠습니다. 새로운 모듈을 추가하기 위해서는 registry에 정의되어 있는 클래스의 register_module을 통해 등하게 됩니다. (일반적으로 데코레이터를 통해 등록을 진행합니다)
위 클래스는 BaseMetric이란 것을 상속받아서 만들어지게 되는데, 3가지 메서드(process, compute_metrics, evaluate)를 포함해야 하는데, 여기서는 이에 대한 설명은 생략하도록 하겠습니다. 자세히 알고 싶은 경우 참고문헌의 [1]에서 확인하시기 바랍니다.
2. 사용자 지표 가져오기
새롭게 정의된 지표를 mmseg/evaluation/metrics/__init__.py에 아래와 같이 추가해줍니다.
from .custom_metric import CustomMetric
__all__ = ['CustomMetric', ...]
3. config 파일에 추가
OpenMMLab에서 만든 라이브러리를 사용하면 .py로 쓰여진 Configuration 파일을 만들어야 합니다. (예를 들어, config.py 처럼 말이죠) 해당 파일에 새롭게 정의된 사용자 지표를 아래와 같이 작성해서 config.py에 추가해줍니다.
위 접근방식은 MMSegmentation에서 제공하고 있는 BaseMetric 등 코드를 활용하는 방법입니다. 만약 다른 방식으로 지표를 정의하고 싶다면 아래와 같은 방식으로 만들 수도 있습니다. 먼저 custom_metric.py를 만듭니다. 여기서 경로는 임의로 'path/custom_metric.py'라고 하겠습니다.
Python Imaging Library(PIL)은 이미지 처리를 위한 오픈소스 라이브러리입니다. Pillow는 PIL의 확장된 버전으로 이미지 파일 입출력 및 기본적인 처리 작업을 지원합니다. 다양한 이미지 포맷(JPEG, PNG, BMP 등)을 지원하고, 크기 조정이나 회전 등 간단한 이미지 처리가 가능합니다. [1]
PIL에서는 데이터를 PIL.Image 객체로 관리해 필요할 때마다 불러옵니다. 파이썬에서 많이 사용하는 형식인 NumPy와 변환을 위해서 아래와 같은 방법을 이용할 수 있습니다.
OpenCV는 2000년에 최초로 공개된 이미지 처리를 위한 소프트웨어로 C++로 구현되어 있지만, 파이썬이나 자바 등 다양한 프로그래밍 언어를 지원합니다. 파이썬에서 이미지를 다루기 위해서 최근에 사용해보신 분이라면 cv2라는 라이브러리 형태로 사용해본 적이 있을 겁니다.
OpenCV의 강점 중 하나는 다양한 이미지 분석 및 컴퓨터 비전 알고리즘을 제공한다는 점일 것입니다. 필터링, 엣지 검출, 변환 등 옛날부터 연구된 이미지 처리 알고리즘이나 CNN을 통해서나 접했던 특징 추출을 알고리즘으로도 가능합니다. [2]
cv2를 통한 이미지는 NumPy 다차원 배열(ndarray)을 통해 반환됩니다. NumPy는 기본적으로 [H, W, C] 구조로 저장을 하게 되는데, 그중에서 cv2는 채널(C)에 대해 통상적으로 익숙한 RGB가 아니라 BGR 순서로 처리합니다. 그렇기 때문에 RGB 채널이 필요할 경우 아래와 같은 방법을 통해 변환이 필요합니다.
참고로, EDA를 수행할 때 Jupyter Notebook을 많이 사용할텐데, cv2에서 이미지를 보여주는 기능인 cv2.imshow()는 잘 동작하지 않습니다.[3] 따라서, 저같은 경우는 NumPy 배열을 동시에 잘 활용할 수 있는 matplotlib 라이브러리를 통해 시각화를 많이 수행하고 있습니다. 이 경우에 제대로 색으로 이미지를 보기 위해선 반드시 RGB로 변환이 필요합니다.
3. PyTorch (Torchvision)
PyTorch는 대표적인 딥러닝 프레임워크로 데이터 전처리부터 모델 학습까지 모두 통합된 작업 환경을 제공하고 있습니다. Torchvision은 PyTorch 프로젝트의 한 부분으로 유명한 이미지 데이터셋, 모델 구조, 공통적인 변환을 위한 컴퓨터 비전 기능을 제공하고 있습니다. 편하게 GPU 학습을 진행하고 텐서 연산을 진행하기 위해서는 PyTorch 프레임워크에 맞는 데이터로 전환하는 과정이 Dataset을 통해 이뤄져야 합니다.
PyTorch 프레임워크에서는 이미지 데이터를 텐서(torch.Tensor) 형태로 이용합니다. 따라서, 위 두가지 라이브러리에서 이용하는 NumPy 데이터를 텐서로 바꾸는 과정이 필요합니다. 역으로 학습이 된 데이터를 NumPy 배열로 바꿔야 할 수도 있습니다.
아래 코드는 PyTorch와 NumPy 사이 변환을 나타내는 예제 코드입니다. 텐서는 주로 [C, H, W] 구조로 이뤄져 있기 때문에 NumPy의 [H, W, C]에 맞게끔 변환하기 위해 permute를 사용했습니다. [4]
Albumentations는 엄밀하게는 이미지를 읽는데 필요하진 않지만, 데이터 증강에 필요한 많은 기능을 제공하는 대표적인 데이터 증강 라이브러리 입니다. Torchvision에서도 제공하지만 기능이 상대적으로 적은 편인 것 같고, cv2보다는 훨씬 쓰기 편한 것 같아 자주 애용합니다. (특히, 객체 탐지에서 bbox나 세그멘테이션의 마스크 관련 처리 기능을 함께 제공해서 매우 편하게 사용할 수 있는 라이브러리라 생각합니다)
Albumentations는 NumPy 배열을 사용합니다. 따라서 증강 이후 훈련에 사용해야 하는 On-line Augmentation을 수행해야 할 때는 PyTorch로 바꿔야 합니다. 방법은 아래와 같이 가능합니다.
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Albumentations 증강 정의
transform = A.Compose([
A.Resize(256, 256),
A.HorizontalFlip(p=0.5),
A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
ToTensorV2() # NumPy → PyTorch 텐서
])
# NumPy 이미지에 증강 적용
augmented = transform(image=image_np)
image_tensor = augmented['image']
앙상블(Ensemble)이란 일련의 예측 모델들로부터 예측 결과를 수집하여 더 좋은 예측을 도출하기 위한 방법입니다. 예를 들면, 훈련 데이터셋에서 무작위 다른 부분 데이터셋을 만들어 일련의 Decision Tree를 만들고 개별 트리의 예측을 구하여 종합해 가장 많은 선택을 받은 클래스를 최종 예측을 삼는 방법이 있을 수 있습니다.
여기선 기본이 되는 여러 앙상블 기법들을 살펴보고, 딥러닝에서 어떻게 활용할 수 있을지 간단히 코드를 통해 알아보려고 합니다.
1. 투표 기반
여러 분류기들의 예측을 모아서 가장 많이 선택된(투표 결과 높은) 클래스를 예측하는 방법을 직접 투표(hard voting)라고 합니다. 즉, 다수결 투표로 정해지는 것입니다. 머신러닝에서 많이 활용하는 사이킷런에서는 sklearn.ensemble에서 VotingClassifer 클래스를 불러오고 voting 매개변수를 'hard'로 불러올 수 있습니다.
사이킷런을 활용하지 않을 경우 hard voting의 원리에 따라 각 결과들을 저장하고 예측 결과(클래스)들에 대해 하나씩 세주는 Counter를 활용해 세주는 방식으로 구현할 수 있습니다.
import numpy as np
from collections import Counter
def hard_voting_ensemble(classifiers, X):
predictions = []
for classifier in classifiers:
predictions.append(classifier.predict(X))
final_predictions = []
for sample_predictions in zip(*predictions):
vote_count = Counter(sample_predictions)
final_predictions.append(vote_count.most_common(1)[0][0])
return np.array(final_predictions)
하지만, 이런 투표방식은 각 클래스의 예측 확률을 반영하지 못해 비교적 낮은 확률이라도 개수가 조금이라도 많으면 채택될 수 있습니다. 확률을 결과값으로 가져올 수 있다면 이들을 평균 내어서 확률이 가장 높은 클래스를 선택할 경우 조금 더 세밀한 앙상블 기법이 될 수 있습니다.
이를 간접 투표(soft voting)이라고 하며 위에서 사이킷런을 통해서 구현할 때는 voting만 'soft'로만 바꾸면 됩니다. 하지만, 사이킷런을 사용하지 않을 경우 아래 코드와 같이 사용할 수 있습니다.
import torch
def soft_voting_ensemble(models, inputs):
predictions = []
for model in models:
model.eval()
with torch.no_grad():
output = model(inputs) # 여기서 모델은 확률을 반환
predictions.append(output)
ensemble_pred = torch.mean(torch.stack(predictions), dim=0)
return torch.argmax(ensemble_pred, dim=1)
# Ensemble 예측
ensemble_models = [model1, model2, model3]
final_prediction = soft_voting_ensemble(ensemble_models, test_inputs)
2. 배깅과 페이스팅
앞서 다른 알고리즘으로 학습된 결과를 합치는 방법도 있지만, 같은 알고리즘을 통해 훈련할 때 훈련 데이터셋을 무작위로 서브셋을 구성해 모델을 각기 다르게 학습시키고 합치는 방법도 있습니다. 배깅(Bootstrap aggregating; Bagging)은 훈련 데이터셋을 중복 허용해 샘플링하는 방식이고, 중복을 허용하지 않고 샘플링하는 것을 페이스팅(pasting)이라고 합니다. 만약에 한 모델을 위해 적용한다면 배깅이 적합한 방식이라고 합니다. [2]
사이킷런으로는 아래와 같이 구현할 수 있습니다.
from sklearn.ensemble import BaggingClassifier
from sklearn.neighbors import KNeighborsClassifier
bagging = BaggingClassifier(KNeighborsClassifier(),
max_samples=0.5, max_features=0.5)
또는 딥러닝에서 활용할 수 있도록 pytorch로 구현할 수 있습니다.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler
import numpy as np
class BaggingEnsemble(nn.Module):
def __init__(self, base_model, n_estimators, input_size, hidden_size, output_size):
super(BaggingEnsemble, self).__init__()
self.base_model = base_model
self.n_estimators = n_estimators
self.models = nn.ModuleList([base_model(input_size, hidden_size, output_size) for _ in range(n_estimators)])
def forward(self, x):
outputs = [model(x) for model in self.models]
return torch.stack(outputs).mean(dim=0)
def train_bagging(model, train_loader, criterion, optimizer, n_epochs):
model.train()
for epoch in range(n_epochs):
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def predict_bagging(model, test_loader):
model.eval()
predictions = []
with torch.no_grad():
for data, _ in test_loader:
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
predictions.extend(pred.numpy())
return np.array(predictions)
# 사용 예시
if __name__ == "__main__":
X = torch.randn(1000, 10)
y = torch.randint(0, 2, (1000,))
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset, batch_size=32, shuffle=False)
input_size, hidden_size, output_size = 10, 50, 2
bagging_model = BaggingEnsemble(BaseModel, n_estimators=10, input_size=input_size, hidden_size=hidden_size, output_size=output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(bagging_model.parameters())
train_bagging(bagging_model, train_loader, criterion, optimizer, n_epochs=5)
predictions = predict_bagging(bagging_model, test_loader)
print("Predictions shape:", predictions.shape)
랜덤 포레스트는 배깅이나 페이스팅을 이용해 하나의 베이스모델인 Decision Tree를 적용한 것들의 앙상블입니다. [3] 다만, 딥러닝 보다는 전통적인 머신러닝 기법이니만큼 굳이 pytorch의 필요성이 없기 때문에 참고자료의 내용으로 대신합니다.
3. 스태킹
모든 모델들의 예측을 모으는 과정을 간단함 함수(예를 들어, 투표 방식) 대신에 모델을 통해 앙상블을 수행하는 방식입니다. 개별 모델의 예측은 함께 모여서 마지막 모델에 예측을 위한 입력으로 사용되며, 마지막 모델은 교차 검증(cross-validation)을 통해 학습됩니다. 여기서 마지막 모델은 블렌더(Blender) 또는 메타 학습기(Meta learner)라고 불립니다.
사이킷런을 사용한다면 StackingClassifier 또는 StackingRegressor를 사용하면 되지만, 딥러닝은 조금 더 구현이 필요합니다. 메타 학습기(메타 모델)을 구현하고, 기존 모델 예측값을 합친 후에 다시 메타 모델에 넣어서 최종 예측값을 도출하는 과정이 추가되면 됩니다. 다소 길긴 하지만, 아래처럼 구현할 수 있습니다.
이렇게 데이터셋을 만드는 것은 좋은 모델을 구축하기에 필요한 과정이지만, 시간과 비용이 많이 필요한 작업입니다. 따라서, 필요할 경우 직접 데이터셋을 합성하는 것도 고려해볼 수 있습니다. 최근에는 디퓨전과 같이 성능이 좋은 합성 모델이 많지만, 이러한 모델들 역시 기본적으로 데이터가 많은 상태에서 많은 학습을 거쳐야 효과적인 것으로 보입니다. (예전에 30장 이미지 기준 100번의 학습을 진행해도 원하는 모습이 나오진 않았던 경험이 있습니다.)
모델을 학습시키기 충분한 양의 데이터를 비교적 적은 비용과 시간으로 얻기 위해 이미지 합성을 생각했고, 비교적 정형화된 모습을 갖고 있는 OCR의 경우 정해진 틀에 Rule-based의 방식으로도 충분히 다른 형태의 객체를 만들어낼 수 있을 것이라 생각했습니다.
이를 해내기 위해 기존의 데이터셋에서 객체의 정보를 가져와 이미 학습된 번역기에 전달해서 나온 결과를 기존 이미지에 입히는 방식을 생각했고, 실제 이를 구현하였습니다.
1. 2. 고려 사항
구현하는 과정에서 크게 3개의 고려 사항이 있었습니다.
번역된 문자를 넣을 깨끗한 영수증 이미지 필요
번역기 문자 수 제한 때문에 번역이 필요한 문자만 추출 필요
번역하면서 너무 길어지는 경우 잘라낼 필요
1.3. 설계
1.2에서 언급한 문제나 고려 사항들을 감안하여 다음과 같은 방향으로 진행하고자 설계하였으며, 고려 사항을 해결하기 위해 특별히 필요한 기능이 있다면 기재했습니다,
Import data
Clean image
Translate : 필요한 문자만 추출
Insert text : 글자 길이 조정
또한 구현을 위한 기초적인 환경 세팅은 다음과 같습니다. (파이썬 표준 라이브러리는 제외)
우선 COCO 데이터와 이미지 데이터를 불러와야 합니다. COCO 데이터 포맷의 경우 pycocotools라는 라이브러리가 있긴 하지만, 여기서는 사용하지 않았습니다. 아래 코드는 json 데이터를 불러옵니다.
import json
from pathlib import Path
def read_json(path: str):
with Path(path).open(encoding='utf8') as file:
data = json.load(file)
return data
json_data = read_json('./instances_default.json')
이미지도 불러오는 코드입니다. 여기서는 OpenCV를 활용했고, 코드를 돌리면 잘 불러오는 것을 확인할 수 있습니다.
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
img = cv2.imread('./receipt.jpg')
annotations = json_data['annotations']
fig, ax = plt.subplots(1)
for ann in annotations:
x, y, w, h = ann['bbox']
coordinates = [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]
polygon = patches.Polygon(coordinates, closed=True, linewidth=0.5, edgecolor='red', facecolor='none')
ax.add_patch(polygon)
plt.axis('off')
plt.imshow(img)
2.2. Clean image
구상했던 이미지 생성방법은 기존 영수증 이미지에서 글자들을 지우고 같은 자리에 번역된 글자를 넣는 방식입니다. 그러다보니, 기존 이미지에서 텍스트만을 지우는 방식을 고민했고 Stackoverflow에 제안된 방식을 활용했습니다. [1]
해당 방식의 가장 핵심적인 것은 모폴로지 연산이라고 생각하는데, 모폴로지 연산이란 이미지의 형태에 기반한 연산들을 말하며, 형태에 집중해야 하기 때문에 일반적으로 흑백 이미지에 적용됩니다. [2] 이를 위해 그레이 스케일로 변환하고 닫힘(Closing) 연산을 통해 텍스트 영역의 작은 구멍을 제거 및 팽창(Dilation) 연산으로 텍스트 영역을 확장합니다.
그 다음 팽창된 이미지에서 텍스트 영역의 윤곽선을 검출해 윤곽선 면적 일정 범위 내에 있는 것만 선택해 실제 텍스트 영역으로 간주합니다. 선택된 윤곽선을 기반으로 마스크 이미지를 생성해 마스크 이미지에서 흰색은 텍스트 / 검은색은 배경 영역으로 사각형을 그립니다. 원본 이미지와 생성된 마스크 이미지를 비교하면서 주변 픽셀 정보를 활용해 자연스럽게 채워넣는 기술인 inpainting으로 처리합니다.
import numpy as np
def inpaint_text_areas(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 3))
close = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, close_kernel, iterations=1)
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 3))
dilate = cv2.dilate(close, dilate_kernel, iterations=1)
cnts = cv2.findContours(dilate, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if len(cnts) == 2 else cnts[1]
mask = np.zeros(image.shape[:2], dtype=np.uint8)
for c in cnts:
area = cv2.contourArea(c)
if area > 100 and area < 30000:
x, y, w, h = cv2.boundingRect(c)
cv2.rectangle(mask, (x, y), (x + w, y + h), 255, -1)
inpainted_image = cv2.inpaint(image, mask, inpaintRadius=1, flags=cv2.INPAINT_TELEA)
return inpainted_image
inpainted_img = inpaint_text_areas(img)
plt.axis('off')
plt.imshow(inpainted_img)
아래 이미지를 보면 꽤 자연스럽게 잘 지워진 것을 확인할 수 있습니다.
2.3. Translate
직접적으로 번역을 위한 모델을 구축할 수는 없기에, 번역을 위해서는 DeepL의 API를 사용했습니다. 기존에 알고있는 다른 번역 API들은 유료였기 때문에 간단히 테스트용으로 사용하기는 부적합하다고 판단했고, DeepL은 500,000자까지는 무료로 제공하기 때문에 채택했습니다. 아래 코드는 DeepL 문서에서 알려주고 있는 코드를 활용해 함수를 만들었습니다. [3]
다만, 500,000자 제한이 있기 때문에 모든 문자를 번역하는 것은 다소 아까운 일입니다. 그래서 정규식을 이용해 해당 문자가 한국어를 포함하는 경우만 번역하도록 했습니다. 구현한 모든 함수들을 활용해 한국어가 포함된 경우에만 일본어로 번역하고 이를 동일한 annotations로 반환하도록 했습니다.
import re
def check_korean(text):
korean_pattern = re.compile(r'[\u3131-\u3163\uac00-\ud7a3]+')
matches = re.findall(korean_pattern, text)
if len(matches) < 1:
return False
return True
translator = authenticate()
translated_annotations = []
for ann in annotations:
transcription = ann['attributes']['transcription']
if check_korean(transcription):
translated_text = translate_text(translator, transcription, 'JA') # JA : Japanese
ann_copy = ann.copy()
ann_copy['attributes']['transcription'] = translated_text
translated_annotations.append(ann_copy)
else:
translated_annotations.append(ann)
2.4. Insert text
2.2.에서 만든 백지 이미지에 번역된 글자들을 추가하는 코드입니다. PIL에서 제공하고 있는 Draw를 통해 텍스트를 삽입하는 방식인데, 사용하는 메서드에 대한 설명은 다음 링크에서 확인하시면 됩니다.
add_new_text의 입력 매개변수 중 다른 것보다 특이한 font가 있습니다. 이는 PIL 라이브러리의 ImageFont. truetype으로 생성한 인스턴스로 글씨체와 사이즈를 지정합니다. 여기 프로젝트에서는 많은 세계 각국의 언어를 지원하는 Google과 Adobe가 협업해서 만들었다고 알려진 Noto Fonts 를 사용했습니다.
from PIL import Image, ImageDraw, ImageFont
def add_new_text(image, bbox, text, font):
x, y, w, h = bbox
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
text_pos_x = x
text_pos_y = y + h // 2
draw.text((text_pos_x, text_pos_y), text, fill=(0, 0, 0), font=font, anchor='lm')
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
하지만, 여전히 부족한 부분이 있는데, 번역을 통해 길어진 텍스트에 대한 문제 입니다. 다음 사진을 보면 일본어로 번역했을 때 바운딩 박스를 넘어가는 것을 확인할 수 있습니다. 이는 언어마다 표현하는 길이가 달라지면서 발생하는 문제입니다.
이를 해결하기 위해 3가지 방법을 시도했습니다.
폰트 사이즈 조절 : 바운딩 박스에 맞추는 방향으로 폰트 사이즈를 조절했으나, 너무 작아지는 텍스트가 만들어져 학습에 적절하지 않은 데이터가 생성된다고 판단했습니다.
Pseudo Character Center : 적당히 잘라내기 위해 PCC를 찾았지만, 단순한 방법으로는 문자마다 다른 길이로 표현되기 때문에 너무 많은 문자가 소실이 발생했습니다.
PIL getbbox : PIL font 인스턴스에 있는 getbbox를 활용하고 있는데, 이는 비교적 최신 라이브러리에서 지원하는 메서드이기 때문에 버전을 확인할 필요가 있습니다.
가장 효과적이라고 판단한 getbbox를 사용해 문자들의 길이를 측정하고, 그에 맞춰서 번역한 텍스트를 반환해서 문자를 기입하는 코드는 다음과 같습니다.
def get_char_widths(text, font:ImageFont.truetype):
char_widths = []
for char in text:
bbox = font.getbbox(char)
char_width = bbox[2] - bbox[0]
char_widths.append((char, char_width))
return char_widths
def get_text_in_box(char_widths):
text_in_box = ''
text_width = 0
for char, width in char_widths:
text_width += width
if text_width <= w:
text_in_box += char
else:
break
return text_in_box
fig, ax = plt.subplots(1)
for ann in translated_annotations:
x, y, w, h = ann['bbox']
coordinates = [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]
font_size = h
font = ImageFont.truetype(font='./font/NotoSansJP-Regular.ttf', size=font_size)
polygon = patches.Polygon(coordinates, closed=True, linewidth=0.5, edgecolor='red', facecolor='none')
ax.add_patch(polygon)
char_widths = get_char_widths(ann['attributes']['transcription'], font)
text_in_box = get_text_in_box(char_widths)
inpainted_img=add_new_text(inpainted_img, ann['bbox'], text_in_box, font)
plt.axis('off')
plt.imshow(inpainted_img)
완벽하게 기존 것처럼 재현한 것은 아니지만, 꽤 그럴싸하게 만들어진 것 같습니다.
반응형
3. 최종 코드
import json
from pathlib import Path
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import deepl
import re
import getpass
from PIL import Image, ImageDraw, ImageFont
def read_json(path: str):
with Path(path).open(encoding='utf8') as file:
data = json.load(file)
return data
def inpaint_text_areas(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 3))
close = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, close_kernel, iterations=1)
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 3))
dilate = cv2.dilate(close, dilate_kernel, iterations=1)
cnts = cv2.findContours(dilate, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if len(cnts) == 2 else cnts[1]
mask = np.zeros(image.shape[:2], dtype=np.uint8)
for c in cnts:
area = cv2.contourArea(c)
if area > 100 and area < 30000:
x, y, w, h = cv2.boundingRect(c)
cv2.rectangle(mask, (x, y), (x + w, y + h), 255, -1)
inpainted_image = cv2.inpaint(image, mask, inpaintRadius=1, flags=cv2.INPAINT_TELEA)
return inpainted_image
def authenticate():
auth_key = getpass.getpass("Enter API Key : ")
translator = deepl.Translator(auth_key)
return translator
def translate_text(translator, text, target_lang):
result = translator.translate_text(text, target_lang=target_lang)
return result.text
def check_korean(text):
korean_pattern = re.compile(r'[\u3131-\u3163\uac00-\ud7a3]+')
matches = re.findall(korean_pattern, text)
if len(matches) < 1:
return False
return True
def add_new_text(image, bbox, text, font):
x, y, w, h = bbox
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
text_pos_x = x
text_pos_y = y + h // 2
draw.text((text_pos_x, text_pos_y), text, fill=(0, 0, 0), font=font, anchor='lm')
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
def get_char_widths(text, font:ImageFont.truetype):
char_widths = []
for char in text:
bbox = font.getbbox(char)
char_width = bbox[2] - bbox[0]
char_widths.append((char, char_width))
return char_widths
def get_text_in_box(char_widths):
text_in_box = ''
text_width = 0
for char, width in char_widths:
text_width += width
if text_width <= w:
text_in_box += char
else:
break
return text_in_box
json_data = read_json('./instances_default.json')
img = cv2.imread('./receipt.jpg')
annotations = json_data['annotations']
inpainted_img = inpaint_text_areas(img)
translator = authenticate()
translated_annotations = []
for ann in annotations:
transcription = ann['attributes']['transcription']
if check_korean(transcription):
translated_text = translate_text(translator, transcription, 'JA') # JA : Japanese
ann_copy = ann.copy()
ann_copy['attributes']['transcription'] = translated_text
translated_annotations.append(ann_copy)
else:
translated_annotations.append(ann)
fig, ax = plt.subplots(1)
for ann in translated_annotations:
x, y, w, h = ann['bbox']
coordinates = [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]
font_size = h
font = ImageFont.truetype(font='./font/NotoSansJP-Regular.ttf', size=font_size)
polygon = patches.Polygon(coordinates, closed=True, linewidth=0.5, edgecolor='red', facecolor='none')
ax.add_patch(polygon)
char_widths = get_char_widths(ann['attributes']['transcription'], font)
text_in_box = get_text_in_box(char_widths)
inpainted_img=add_new_text(inpainted_img, ann['bbox'], text_in_box, font)
plt.axis('off')
plt.imshow(inpainted_img)
4. 한계 및 배운 점
4.1. 한계
(외부 툴) 직접 번역하는 모델을 구축하기 어려운만큼 외부 툴을 사용할 수 밖에 없었지만, 제한되는 환경이라서 실제로 사용하기에는 다소 무리가 있을 것 같습니다.
(완성도) 중간중간 나온 실패작보다는 괜찮은 결과물이지만, 생성한 데이터만으로 영수증을 떠올리기는 쉽지 않은 것 같습니다. 빈 영수증을 만드는 과정에서 그래도 자연스럽게 만들어졌다고 생각하지만, 로고까지 지워지지 않도록 규제하는 방향을 고민해볼 필요가 있습니다. 그리고 번역으로 인해 잘리는 문제가 발생하게 되는데, 위치만 찾는 Text Detector로는 기능할 수 있지만 글자까지 인식해야 하는 Text Recognizer까지 모델이 확장된다면 이러한 데이터 생성은 적절하지 않아 보입니다.
4.2. 배운 점
(OpenCV) Computer Vision에서 많이 사용한다는 OpenCV에 대해 간략하게만 알고 있었지만, 꽤나 강력한 도구들을 제공하고, 제대로 배워보면 이미지를 다루는데 큰 무기를 얻을 수 있겠다는 생각이 들었습니다. 물론 OpenCV 뿐 아니라 PIL도 꽤 유용한 것들이 많아서 앞으로 계속 프로젝트를 해보면서 익숙해져야 할 것 같습니다.
(함수) 프로젝트라는 이름을 붙일 수준은 아니지만, 여러 개의 함수들을 만들고 이들을 최대한 간결하게 작성해보려는 시도를 하면서 앞으로 어떤 방식으로 코드를 작성해야 할지에 대한 감을 늘릴 수 있었습니다. 최근에 본 몇몇 영상과 부스트 캠프 과정에서 배운 내용들을 상기해보면서 몇 번 재배치를 하고, 어떻게 함수를 만드는 것이 나중에 유지보수를 최소화할 수 있을지에 대해 잠깐이지만 고민해볼 기회였습니다.