[PyTorch] Early Stopping & LRScheduler

딥러닝 모델 성능을 최적화 하기 위해 콜백 함수를 사용한다.
콜백 함수는 개발자가 명시적으로 함수를 호출하는 것이 아니라, 함수를 등록하고 특정 이벤트 발생에 의해 함수를 호출하고 처리하도록 하는 함수이다.
대표적으로 조기 종료를 뜻하는 Early Stopping과 학습률을 조정하는 LRScheduler가 있다.

아래에서 조기종료, 콜백 클래스를 구현한 코드를 살펴보겠다.

# google colab gpu 환경

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms, datasets

import matplotlib
import matplotlib.pyplot as plt
import time
import argparse
from tqdm import tqdm
matplotlib.style.use('ggplot')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

파이토치의 ImageFolder 라이브러리를 활용하여 데이터셋을 정의한다.
데이터는 캐글의 hotdog | not hotdog 데이터를 사용한다. https://www.kaggle.com/datasets/thedatasith/hotdog-nothotdog

train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])


train_dataset = datasets.ImageFolder(
    root=train_path,
    transform=train_transform
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=16, shuffle=True,
)
val_dataset = datasets.ImageFolder(
    root=test_path,
    transform=val_transform
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=16, shuffle=False,
)

1. LRScheduler

학습률 조정을 위한 LRScheduler 클래스를 아래와 같이 구현하였다. PyTorch에서 제공하는 torch.optim.lr_scheduler를 활용한다.
LRScheduler 클래스는 학습이 진행되는 과정에서 주어진 'patience' 횟수만큼 검증 데이터셋에 대한 오차 감소가 없으면 주어진 'factor'만큼 학습률을 감소시켜 모델 최적화가 가능하게 도와준다. 클래스의 call 함수에서 인자로 val_loss를 넣어주었다.

class LRScheduler():
    def __init__(self, optimizer, patience=5, min_lr=1e-6, factor=0.5):
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor # LR을 factor배로 감소시킴
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=self.patience,
            factor=self.factor,
            min_lr=self.min_lr,
            verbose=True
        )

    def __call__(self, val_loss):
        self.lr_scheduler.step(val_loss)
  • torch.optim.lr_scheduler.ReduceLRonPlateau는 val_loss에 대해서 변동이 없으면 학습률을 factor배로 감소시킨다.
  • optimizer는 학습률과 관련되어 모델 파라미터를 갱신시키는 부분이다.
  • mode는 언제 학습률을 조정할지 지정하는 기준이다. 이때 인자가 min이 되어야 하는지, max가 되어야 하는 알려주는 파라미터이다.
  • patience: 몇 번에 에포크만큼 기다릴 것인지 지정한다.
  • factor: 학습률을 얼마나 감소시킬지 지정하는 파라미터이다.
  • min_lr: 말 그대로 학습률의 최소 값이다. 지속적으로 콜백함수가 호출되어 학습률이 감소할 때 최소 한계를 지정해준다.
  • verbose: 0 or 1

2. EarlyStopping

EarlyStopping 클래스는 특정 에포크 동안 오차가 개선되지 않을 때 훈련을 조기 종료한다.

class EarlyStopping():
    def __init__(self, patience=5, verbose=False, delta=0, path=path):

        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False # 조기 종료를 의미하며 초기값은 False로 설정
        self.delta = delta # 오차가 개선되고 있다고 판단하기 위한 최소 변화량
        self.path = path
        self.val_loss_min = np.Inf

    def __call__(self, val_loss, model):
        # 에포크 만큼 한습이 반복되면서 best_loss가 갱신되고, bset_loss에 진전이 없으면 조기종료 후 모델을 저장
        score = -val_loss
        
        if best_score is None:
            self.bset_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'Early Stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
    
    def save_checkpoint(self, val_loss, model):
        if self.verbos:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}. Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

  • patience: 몇 번 에포크만큼 기다릴지 지정한다.
  • delta: 오차가 개선되었다고 판단하기 위한 최소 변화량을 나타낸다.

3. ArgumentParser()

콜백 함수를 인수로 지정하여 활용하기 위해 argparse 라이브러리를 사용한다.

parser = argparse.ArgumentParser()
parser.add_argument('--lr-scheduler', dest='lr_scheduler', action='store_true')
parser.add_argument('--early-stopping', dest='early_stopping', action='store_true')
parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1") #주피터 노트북에서 실행할때 필요합니다
args = vars(parser.parse_args())

parser.add_argument()

  • 첫 번째 파라미터: 옵션 문자열의 이름으로 명령을 실행할 때 사용한다.
    ex) python main.py --lr-scheduler
  • dest: 입력 값이 저장되는 변수 이름
  • action=True: dest 파라미터에 의해 생선된 변수에 지정된 입력 값을 저장한다.

학습 목적을 위해 pretrained 모델을 활용한다.

model = models.resnet50(pretrained=True).to(device)
lr = 0.001
epochs = 100
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# 모델 학습 함수
def training(model, train_dataloader, train_dataset, optimizer, criterion):
    print('Training')
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    total = 0
    prog_bar = tqdm(enumerate(train_dataloader), total=int(len(train_dataset)/train_dataloader.batch_size))
    for i, data in prog_bar:
        counter += 1
        data, target = data[0].to(device), data[1].to(device)
        total += target.size(0)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        train_running_loss += loss.item()
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == target).sum().item()
        loss.backward()
        optimizer.step()
        
    train_loss = train_running_loss / counter
    train_accuracy = 100. * train_running_correct / total
    return train_loss, train_accuracy
# 콜백을 적용할 검증 함수
def validate(model, test_dataloader, val_dataset, criterion):
    print('Validating')
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    counter = 0
    total = 0
    prog_bar = tqdm(enumerate(test_dataloader), total=int(len(val_dataset)/test_dataloader.batch_size))
    with torch.no_grad():
        for i, data in prog_bar:
            counter += 1
            data, target = data[0].to(device), data[1].to(device)
            total += target.size(0)
            outputs = model(data)
            loss = criterion(outputs, target)
            
            val_running_loss += loss.item()
            _, preds = torch.max(outputs.data, 1)
            val_running_correct += (preds == target).sum().item()
        
        val_loss = val_running_loss / counter
        val_accuracy = 100. * val_running_correct / total
        return val_loss, val_accuracy
# 모델 학습
train_loss, train_accuracy = [], []
val_loss, val_accuracy = [], []
start = time.time()
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_accuracy = training(
        model, train_dataloader, train_dataset, optimizer, criterion
    )
    val_epoch_loss, val_epoch_accuracy = validate(
        model, val_dataloader, val_dataset, criterion
    )
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    
    if args['lr_scheduler']:
        lr_scheduler(val_epoch_loss)
    if args['early_stopping']:
        early_stopping(val_epoch_loss, model)
        if early_stopping.early_stop:
            break
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f}")
    print(f'Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f}')
end = time.time()
print(f"Training time: {(end-start)/60:.3f} minutes")

코드를 파이썬 (.py)로 저장한 후 프롬프트 명령에서 아래와 같이 인수를 불러온다.

python 파일명.py --lr-scheduler
python 파일명.py --early-stopping


사실 간단하게 training 단계에서 에포크마다 scheduler.step을 지정해주는게 편하다.
본 포스팅은 [길벗] 딥러닝 파이토치 교과서 서적 을 참고하여 공부를 위해 작성했다.

반응형