[PyTorch] Yelp 데이터로 커스텀 데이터셋 만들기

Contents
1. Dataset and Dataloader 기초
1.1  내장된 데이터세트 로딩
1.2  내장된 데이터세트 시각화
1.3  Dataloader에 전달

2. Custom Dataset
2.1  Dataset 구성요소
2.2  Yelp 데이터를 이용하여 커스텀 데이터셋 만들기

 

1. Dataset and Dataloader 기초

PyTorch는 모델에 데이터를 입력하기 위해, torch.utils.data.Dataloader와 torch.utils.data.Dataset를 사용한다.

이때, 내장된 dataset 뿐만 아니라, custom dataset도 사용 가능하다.

 

- Dataset은 데이터의 샘플과 label을 저장한다.

- Dataloader는 데이터세트를 미니배치 단위로 iterable 객체로 감싼다.

 

1.1  내장된 데이터세트 로딩

TorchVision에서 Fasion-MNIST 데이터셋을 불러오는 실습이다.

저장된 데이터셋의 각 샘플은 각 샘플은 28x28 grayscale 이미지, 10개의 정답 클래스로 구성된다.

 

이때 내장된 데이터셋을 불러오는 매개 변수는 다음과 같다.

  • root: 저장되는 경로
  • train: 학습용 또는 테스트용 데이터세트 여부
  • dataload=True: root에 데이터가 없는 경우 인터넷에서 다운로드
  • transform or target_transform: featrue or label의 transform 지정
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

 

1.2  내장된 데이터세트 시각화

내장 데이터셋의 정답 레이블은 0부터 9까지의 정수로 구성되어 있기 때문에, 각 클래스의 이름을 매핑한다.

  • len(dataset)은 데이터셋에 존재하는 샘플의 총 개수를 반환한다.
  • squeeze() 함수는 차원이 1인 차원을 제거한다.
    아래 코드에서는 랜덤한 데이터 세트를 하나씩 불러온 후에, imshow 함수를 적용하기 위해, 배치 사이즈를 의미하는 차원을 제거한다.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

 

1.3  Dataloader에 전달

Dataset은 각 샘플의 feature를 가져오고, 하나의 샘플에 label을 지정하는 일을 한 번에 진행한다.
일반적으로 샘플들을 미니배치 단위로 전달하고, 매 에포크마다 데이터를 셔플하여, overfit을 막는다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

next(iter(dataloader))를 통해 다음 미니 배치 단위의 샘플들을 불러온다.

아래 코드는 첫번째 배치 단위 (64개의 샘플)에서 하나의 샘플을 불러와 시각화하는 코드다.

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
### Result ###

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])

 

2. Custom Dataset

2.1  Dataset 구성요소

커스텀 Dataset 클래스를 구성하기 위해서는 3개의 함수가 반드시 구현되어야 한다.

  • __init__
    init 함수는 Dataset 객체가 생성될 때 한 번만 실행된다.
    데이터 파일을 불러오거나, 데이터에 대한 변환을 진행하는 transform을 초기화 한다.

 

  • __len__
    데이터 세트의 샘플 개수를 반환한다.
    즉 총 샘플의 수를 반환하도록 만들어야 한다.

 

  • __getitem__
    주어진 인덱스 (idx)에 해당하는 샘플을 데이터셋에서 반환해야 한다.
    즉 데이터 셋의 특정 1개의 샘플을 가져오는 함수로 구성된다.

구조는 아래 코드와 같다.

class CustomDataset(torch.utils.data.Dataset): 
  def __init__(self):

  def __len__(self):

  def __getitem__(self, idx):

 

2.2  Yelp 데이터를 이용하여 커스텀 데이터셋 만들기

오픈 데이터인 Yelp 데이터를 pandas 데이터 프레임으로 불러온 후에, 커스텀 데이터셋 객체를 만드는 코드이다.

사용된 json 파일은 공식 홈페이지에서 다운 받을 수 있다 (https://www.yelp.com/dataset).

 

먼저 json 파일로 저장된 review 파일을 pandas 데이터 프레임 형태로 불러온다.

import pandas as pd
import json

def init_ds(json):
    ds= {}
    keys = json.keys()
    for k in keys:
        ds[k]= []
    return ds, keys

def read_json(file):
    dataset = {}
    keys = []
    with open(file) as file_lines:
        for count, line in enumerate(file_lines):
            data = json.loads(line.strip())
            if count ==0:
                dataset, keys = init_ds(data)
            for k in keys:
                dataset[k].append(data[k])

        return pd.DataFrame(dataset)
        
df = read_json('yelp_academic_dataset_review.json')

df.info()

데이터 구성은 다음과 같다.

총 699만개의 샘플과 9개의 컬럼 or feature로 이루어져 있다.

이 중 유저ID와 레스토랑ID, 평점 정보를 가져오는 커스텀 데이터 세트를 만든다.

 

### Result ###

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6990280 entries, 0 to 6990279
Data columns (total 9 columns):
 #   Column       Dtype  
---  ------       -----  
 0   review_id    object 
 1   user_id      object 
 2   business_id  object 
 3   stars        float64
 4   useful       int64  
 5   funny        int64  
 6   cool         int64  
 7   text         object 
 8   date         object 
dtypes: float64(1), int64(3), object(5)
memory usage: 480.0+ MB

 

Pandas DataFrame 객체인 df에서 'user_id' / 'business_id' / 'stars' 컬럼 데이터를 불러오는 코드이다.

  • __init__ 함수는 daframe을 입력받아, 각 컬럼의 데이터를 불러온다.
  • __len__ 함수에서 전체 데이터 세트의 개수를 반환하기 위해 user 정보의 길이를 반환한다.
  • __getitem__은 입력받은 idx에 대해 각 user, item, rating 정보를 가져온다.

마지막으로 데이터셋 객체에서, 정수로 이루어진 rating은 torch.tensor를 적용하였다.

class Rating_Datset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super(Rating_Datset, self).__init__()
        self.user = dataset['user_id']
        self.item = dataset['business_id']
        self.rating = dataset['stars']

    def __len__(self):
        return len(self.user)

    def __getitem__(self, idx):
        user = self.user[idx]
        item = self.item[idx]
        rating = self.rating[idx]

        return (
            user,
            item,
            torch.tensor(rating, dtype=torch.float))

 

커스텀 데이터세트 클래스를 불러온 후에, Dataloader에 전달한다.

batch_size는 4로 설정하였다.

이후 next(iter(dataloader))를 통해 데이터 로더의 첫번째 iteration을 불러온다.

dataset = Rating_Datset(df)

data_loader = DataLoader(dataset,
                         batch_size=4,
                         shuffle=True)
                         
next(iter(data_loader))

불러온 결과는 다음과 같다.

### Result ###

[('zKdkxy21Q-3zX03ObcVm9A',
  'CUh9kryOG18271MidjsESQ',
  '345cM-IMsRDwutYG-AkJNw',
  'NY_Moyh9hwAcORMbMZRmmg'),
 ('i9rqSiwl3EAXWU9XA1YyOw',
  'XCArZZsxUNsE7SmQqCYjxg',
  '9xLxbTsG2a-K1qkaU8M1aw',
  '0Mjhi7hYia7iv8da7NBL3Q'),
 tensor([3., 1., 4., 1.])]
반응형