[PyTorch] Weight Initialization (기울기 초기화)

  딥러닝 모델 학습시 가중치를 초기화하고, 예측값과 실제 값의 차이를 loss function을 통해 구한다. 이후 loss를 줄이는 방향으로 가중치를 업데이트를 하게 된다. 결국 모델이 학습하는 대상은 optimal한 가중치이다.

 

첫 가중치는 랜덤한 수로 초기화 된다. 이때 초기화 되는 가중치 값이 모델 성능에 영향을 미친다. 아래 그림은 가중치 초기값에 따라 학습의 진행 방향이 달라지는 모습을 보여준다.


  Weight Initialization은 nn.init을 통해 설정할 수 있다. 이때 bias도 따로 설정해줘야 한다.

 

torch.nn.init.constant_(tensor, val): 상수로 설정
torch.nn.init.ones_(tensor): 1로 설정
torch.nn.init.zeros_(tensor): 0으로 설정
torch.nn.init.unifiom_(tensor, a=0.0, b=1.0): a부터 b사이의 값을 균일한 분포로 설정
torch.nn.init.normal_(tensor, mean=0.0, std=1.0): 평균이 0이고 표준편차가 1인 분포로 설정

 

Weight Initialization은 이미 연구가 많이 되어서 Layer의 특성에 맞추어 초기화 하는 방법이 좋다고 알려져있다.

크게 두가지 방법을 살펴보겠다.

 

1. Xavier Initialization

  • 활성화 함수로 Sigmoid 함수를 사용할 때 적용한다.
    (여기서 Tanh도 Sigmoid 함수에 포함됨)
  • ReLU에서는 출력 값이 0으로 수렴하는 현상을 발생시킨다
  • Xavier initialization은 이전 노드와 다음 노드의 개수에 의존한다.

 

Xavier-uniform: $ u (-a, a)$

$$ a = gain \times  \sqrt{\frac{6}{fan\_in + fan\_out}} $$

 

  • torch.nn.init.xavier_uniform_(tensor, gain=1.0)
  • torch.nn.init.xavier_normal_(tensor, gain=1.0)

 

아래와  같이 Custom Network를 정의하고, xavier_uniform_을 지정해준다.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.uniform_(m.weight.data)
          
    def forward(self, x):
        ...

 

2. He Initialization

  • 활성화 함수가 RuLU일 때 적용

 

kaiming_normal: $N(0, std^2)$

$$ bound = gain \times \sqrt{\frac{3}{fan\_mode}} $$

 

아래는 kaiming_uniform_을 함수로 만들어서 apply 한 코드이다.

 

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10,
                               kernel_size=5,
                               stride=1)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_bn = nn.BatchNorm2d(20)
        self.dense1 = nn.Linear(in_features=320, out_features=50)
        self.dense1_bn = nn.BatchNorm1d(50)
        self.dense2 = nn.Linear(50, 10)
 
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2))
        x = x.view(-1, 320) #reshape
        x = F.relu(self.dense1_bn(self.dense1(x)))
        x = F.relu(self.dense2(x))
        return F.log_softmax(x)
def initialize_weights(m):
  if isinstance(m, nn.Conv2d):
      nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
      if m.bias is not None:
          nn.init.constant_(m.bias.data, 0)
  elif isinstance(m, nn.BatchNorm2d):
      nn.init.constant_(m.weight.data, 1)
      nn.init.constant_(m.bias.data, 0)
  elif isinstance(m, nn.Linear):
      nn.init.kaiming_uniform_(m.weight.data)
      nn.init.constant_(m.bias.data, 0)
      
model=CNN()
model.apply(initialize_weights)

 

아래 코드는 최종적으로 여러 방법을 적용할 수 있는 코드이다.

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(1,16,3,padding=1),  # 28 x 28
            nn.ReLU(),
            nn.Conv2d(16,32,3,padding=1), # 28 x 28
            nn.ReLU(),
            nn.MaxPool2d(2,2),            # 14 x 14
            nn.Conv2d(32,64,3,padding=1), # 14 x 14
            nn.ReLU(),
            nn.MaxPool2d(2,2)             #  7 x 7
        )
        self.fc_layer = nn.Sequential(
            nn.Linear(64*7*7,100),
            nn.ReLU(),
            nn.Linear(100,10)
        )             
        
        # 초기화 하는 방법
        # 모델의 모듈을 차례대로 불러옵니다.
        for m in self.modules():
            # 만약 그 모듈이 nn.Conv2d인 경우
            if isinstance(m, nn.Conv2d):
                '''
                # 작은 숫자로 초기화하는 방법
                # 가중치를 평균 0, 편차 0.02로 초기화합니다.
                # 편차를 0으로 초기화합니다.
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
                
                # Xavier Initialization
                # 모듈의 가중치를 xavier normal로 초기화합니다.
                # 편차를 0으로 초기화합니다.
                init.xavier_normal(m.weight.data)
                m.bias.data.fill_(0)
                '''
                
                # Kaming Initialization
                # 모듈의 가중치를 kaming he normal로 초기화합니다.
                # 편차를 0으로 초기화합니다.
                init.kaiming_normal_(m.weight.data)
                m.bias.data.fill_(0)
            
            # 만약 그 모듈이 nn.Linear인 경우
            elif isinstance(m, nn.Linear):
                '''
                # 작은 숫자로 초기화하는 방법
                # 가중치를 평균 0, 편차 0.02로 초기화합니다.
                # 편차를 0으로 초기화합니다.
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
                
                # Xavier Initialization
                # 모듈의 가중치를 xavier normal로 초기화합니다.
                # 편차를 0으로 초기화합니다.
                init.xavier_normal(m.weight.data)
                m.bias.data.fill_(0)
                '''
                
                # Kaming Initialization
                # 모듈의 가중치를 kaming he normal로 초기화합니다.
                # 편차를 0으로 초기화합니다.
                init.kaiming_normal_(m.weight.data)
                m.bias.data.fill_(0)

    def forward(self,x):
        out = self.layer(x)
        out = out.view(batch_size,-1)
        out = self.fc_layer(out)
        return out
반응형