[Tech +]빠른 학습을 위한, 빠른 Augmentation!

안녕하세요. 머지리티 AI 팀에서 비전 딥러닝 시스템 구축을 담당하고 있는 Theo입니다.

오늘 제가 다룰 주제는 'Data augmentation'에 대한 부분인데요, 대용량 학습을 진행하다보면 Bottle neck 등으로 인해 속도 저하가 발생하기 쉬운데, 이를 해결하기 위해 제가 시도한 방법들을 공유 드릴까 해요! 

AI 학습을 담당하시는 분들께 작지만 유익한 콘텐츠가 되길 바라며.... 바로 시작할게요 😉 


빠른 학습을 위한 

빠른 Augmentation.


목차

  1. Vision Deep learining Data augmentatio의 필요성
  2. torchvision.transforms VS kornia
  3. Resnet과 Kornia를 활용하여 학습 하기
  4. 결론

 

1. Vision Deep learining Data augmentation의 필요성

최근 팀에서 대용량 학습을 진행하며 다양한 문제들이 발생했고, augmentation 과정 중 생기는 bottle neck 현상이 원인의 한가지로 확인되었다. 이를 해결하기 위해 Torchvision 에서 Albumentations을 사용하여 속도를 높였지만 보다 더 빠른 속도로 학습시키고 싶어졌고 그렇게 알아보게된  API Kornia를 소개 해보려한다. Kornia는 이미지 프로세싱에서 속도를 높이기 위해 사용하듯 CUDA를 사용하여  Data augmentation를 진행하였고 이를 통해 Bottle neck 현상을 현저하게 개선할 수 있었다.

본 콘텐츠에선 Data augmentation의 필요성에 대한 설명과 간략한 실습코드를 제시한다. 이를 통해 다양한 환경에서 training을 진행하고, 효율적인 training 전략을 수립할 수 있었으면 좋겠다.


1-1 Data augmentation, 왜 필요할까?

최근 사용되는 Transforemrs(ex. BERT)의 Parameter는 100B 이상 크기를 가지고 있다. 대용량의 하이퍼 파라미터를 최적화 시키기 위해선 더 많은 데이터가 요구되는데, 최근 Vision model에서 유행하는 Transformers model(ex. vit)은 대용량의 데이터를 요구하며 일정 이하의 데이터로 학습을 진행하면 ResNet model과 큰 성능차이를 느끼기 어렵다.

또한 딥러닝의 training data의 다양성이 줄어들면 overfiting 문제와 함께 모델의 generlize를 보장하기 어렵다. 이러한 문제점을 해결하기위해 더 많은 데이터를 수집하면 해결 가능하지만, 더 많은 데이터를 통해 학습했을때 모델의 inference 성능은 향상되기때문에 data augmentation을 사용하면 더 좋은 결과를 기대할 수 있다

최근 data augmentation에 대한 연구가 많이 진행되고 있는데 대표적으로 mixup , Cutout 등이 있다. 이러한 연구에서 보여주듯 단순하게 데이터 증가를 통해 overfitting을 방지하는것뿐 아니라 데이터 변환을 통해 training 난이도를 올리고 inference 성능을 높이고 있다.

1-2 visionData augmentation의 결과

pytorch에서 제공해주는 data augmentation의 결과는 pytorch doc에서 확인 가능하며 다양한 예제가 존재한다.


위의 예시 이미지와 같이 이미지 사이즈 변화, 색상 변화, 위상 변화 등을 줄 수 있으며 이를 통해 데이터의 다양성을 확보할 수 있음을 확인하였다.

1-3 visionData augmentation의 결과

아무리 좋은 Data Augmentation을 진행하였다한들 고민 없이 모두가 사용하는 것을 똑같이 활용한다면 생각한 것과 다른 결과가 나올 수 있다. 가령 Face detection을 진행하고 싶은데 위상을 변화 시켜 사람의 얼굴의 형태가 변환된다면 오히려 Inference에 문제를 발생 시킬 가능성이 높다. 

Deep learning은 Data로 시작해 Data로 끝난다고 해도 과언이 아니다. Data에 대한 충분한 고려와 검토를 통해 자신에게 알맞은 Data Augmentation을 사용해야 원하는 결과와 성능 향상을 기대할 수 있겠다.



2. torchvision.transforms VS kornia

2-1 torchvision.transforms과 Kornia의 차이점

Kornia 역시 pytorch 종속 API로 Data augmentation및 Vision 처리를 위해 만들어진 API다. torchvison.transforms로 Data augmentation를 사용하여 LageScale training을 한다면 Bottleneck 현상이 발생하는 것을 확인할 수 있는데 이를 해결하기 위해 GPU를 사용하여 이미지 연산을 진행하는 API Kornia를 사용하려 한다.

Libraries
TorchVision
Albumentations
Kornia (GPU)
Kornia (GPU)
Kornia (GPU)
Batch Size11132128
RandomPerspective4.88±1.82
4.68±3.60
4.74±2.84
0.37±2.67 0.20±27.00
ColorJiggle4.40±2.88
3.58±3.66
4.14±3.85
0.90±24.68
0.83±12.96
RandomAffine
3.12±5.80
2.43±7.11
3.01±7.80
0.30±4.39
0.18±6.30
RandomVerticalFlip
0.32±0.08
0.34±0.16
0.35±0.82
0.02±0.13 0.01±0.35 
RandomHorizontalFlip
0.32±0.08 0.34±0.18 0.31±0.59
0.01±0.26
0.01±0.37
RandomRotate
1.82±4.70
1.59±4.33
1.58±4.44
0.25±2.09
0.17±5.69 
RandomCrop
4.09±3.414.03±4.94
3.84±3.07 0.16±1.17
0.08±9.42
RandomErasing
2.31±1.47
1.89±1.08
2.32±3.31 0.44±2.82
0.57±9.74
RandomGrayscale
0.41±0.18 0.43±0.60 0.45±1.20
0.03±0.11
0.03±7.10
RandomResizedCrop
4.23±2.86
3.80±3.61
4.07±2.67
0.23±5.27 0.13±8.04
RandomCenterCrop
2.93±1.29 2.81±1.38 2.88±2.34 0.13±2.20
0.07±9.41 


위 표에 나타나있듯 batch size가 증가할수록 강력한 성능을 보임을 확인할 수 있다.

torchvision.transforms에서 제공하는 함수 외에도 더 많은 기능을 제공하고 있으며 속도 또한 증가하여 효과적인 training과 inference를 진행할 수 있기에 본 콘텐츠에서는 kornia를 사용하여 toy dataset 학습을 목표로 진행하려고 한다.


3. Resnet과 Kornia를 활용하여 학습하기


3-1 사전 설치 api

Window에서 python 3.8.x을 활용하여 학습을 진행하였으며, CUDA 11.6을 진행 또한 Jupyer notbook을 활용하여 실행

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
pip install kornia
pip install timm
pip install numpy
pip install tqdm


3-1 training code 

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10
from torch.optim import Adam
from torch.nn.functional import cross_entropy

from timm.utils import AverageMeter
from kornia import image_to_tensor
from kornia.augmentation import RandomAffine, RandomHorizontalFlip, RandomVerticalFlip

import numpy as np
from tqdm import tqdm

학습에 사용될 API import

model = resnet18(False).cuda()
optim = Adam(model.parameters())

losses = AverageMeter()

학습에 사용되는 모델을 torchvision에서 불러왔으며, optimazation은 Adam을 이용하여 학습을 진행하였으며 이때 로그를 위해 timm에서 제공하는 AVerameter()를 사용

class DataAugmentation(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.transforms = nn.Sequential(
            RandomAffine([0.0, 359.9]),
            RandomHorizontalFlip(p=1),
            RandomVerticalFlip(p=1)
        )

    @torch.no_grad()
    def forward(self, x):
        x_out = self.transforms(x)

        return x_out

class Preprocess(nn.Module):
    """Module to perform pre-process using Kornia on torch tensors."""

    @torch.no_grad()  # disable gradients for effiency
    def forward(self, x) :
        x_tmp = np.array(x)  # HxWxC
        x_out = image_to_tensor(x_tmp, keepdim=True)  # CxHxW
        return x_out.float() / 255.0

Preprocess 는 torchvision에서 제공하는 transform과 달리 cuda를 사용해야 하기때문에 image를 tensor로 변환만 시켜준다. 필요하다면 image resize도 같이 진행하면 된다.

DataAugmentation은 본격적인 데이터 변환을 위한 클래스로서 cuda를 pytorch에서 활용하기 위해 신경망과 같이 nn.Module을 상속받아 사용하고 있다.

#
transform = Preprocess()


data = CIFAR10('data', transform=transform, download=True)
data_load = DataLoader(data, batch_size=100 )
tran = DataAugmentation()

for image, label in tqdm(data_load):
        image, label = image.cuda(), label.cuda()
    #cuda 활용을 위해 본격적으로 여기서 실행한다
        image = tran(image)
        optim.zero_grad()
        
        output = model(image)
        loss = cross_entropy(output, label)
        loss.backward()
    losses.update(loss)
        optim.step()
print(losses.avg)

에폭을 통해 모델이 정상적으로 학습됨을 확인 할 수 있다!


4. 결론

Batch size가 클 수록 더 효과적임을 확인할 수 있는데 속도 측면에서 매우 유리한 성능 지표를 가지고 있지만, GPU resource cost가 매우 올라갈 수 있음을 고려해야 한다.

하지만 이러한 고민을 제외하더라도 large scale training 진행할 경우 좋은 선택 사항이 될 수 있으며, 위 예시로든 function뿐 아니라 최신 augmentation 기법도 많이 첨부되어 있어 활용도는 더 높을 수 있다. 간단한 예제를 통해 kornia 사용법을 익혔으니 실제로 사용하여 training 소요시간을 줄여 진행하면 좋을 듯 하다.



참조

https://kornia.readthedocs.io/en/latest/augmentation.html

https://tutorials.pytorch.kr/beginner/blitz/cifar10_tutorial.html

https://pytorch.org/vision/stable/transforms.html

https://arxiv.org/pdf/1710.09412.pdf

https://arxiv.org/pdf/1708.04552.pdf



Contact us

일반문의 - official@mergerity.com

투자문의 - ir@mergerity.com
언론보도 – media@mergerity.com
파트너쉽 – partnership@mergerity.com

Social Media

3F, 211, Hakdong-ro, Gangnam-gu, Seoul, Republic of Korea  

 02-545-2091 | official@mergerity.com


Contact us

일반문의 - official@mergerity.com

투자문의 - ir@mergerity.com
언론보도 – media@mergerity.com
파트너쉽 – partnership@mergerity.com

Contact Info

official@mergerity.com

3F, 211, Hakdong-ro, Gangnam-gu, 

Seoul, Republic of Korea