파이토치에서는 DataLoader 클래스를 통해 데이터를 편리하게 불러올 수 있다.
(ex : 특정 batch size만큼씩 데이터 불러오기, shuffle, 미리 정의한 순서대로 데이터 불러오기 등)
이를 위해서는 Dataset 클래스로 데이터를 사전 정의해야 한다.
import torch
from torch.utils.data import Dataset
테스트를 위해 간단한 tensor 데이터를 정의하였다.
x = [[0,0,0], [0,0,1], [0,1,0],
[1,0,0], [0,1,1], [1,1,0],
[1,0,1], [1,1,1]]
y = [0, 1, 2, 3, 4, 5, 6, 7]
x = torch.Tensor(x).float()
y = torch.Tensor(y).long()
이제 Dataset 클래스를 이용해 custom dataset을 정의하자.
class Data(Dataset):
def __init__(self):
self.x_data = x
self.y_data = y
def __len__(self):
return len(self.x_data)
def __getitem__(self, idx):
return self.x_data[idx], self.y_data[idx]
Dataset을 정의할 때에는 다음 세 가지 메서드를 포함해야 한다.
- __init__ : 인스턴스 생성 시 바로 실행되는 부분
- __len__ : 데이터셋의 길이를 리턴하는 부분
- __getitem__ : 데이터의 특정 인덱스에서의 값을 리턴하는 부분
위의 예시에서는 데이터셋 인스턴스를 생성할 때, 즉 __init__() 함수가 호출될 때 미리 정의한 데이터를 불러오도록 하였다.
__len__과 __getitem__ 메서드가 필요한 이유는 DataLoader를 통해 데이터를 불러올 때 인덱싱을 이용하기 때문이다.
다음과 같이 인스턴스를 생성하면 데이터셋이 만들어진다.
dataset = Data()
위의 예시에서는 class 내부에서 곧바로 x, y를 불러오도록 하였으나 __init__ 부분에서 불러올 데이터를 인수로 받을 수도 있다.
(Train / Valid / Test 데이터가 미리 나뉘어져 있을 때 편함)
class Data(Dataset):
def __init__(self, X, Y):
self.x_data = X
self.y_data = Y
def __len__(self):
return len(self.x_data)
def __getitem__(self, idx):
return self.x_data[idx], self.y_data[idx]
이 경우 인스턴스 생성은 다음과 같이 할 수 있다.
dataset = Data(x,y)
Github :
'Deep Learning > Pytorch' 카테고리의 다른 글
[Pytorch] ImageFolder로 쉽게 Image Dataset 만들기 (0) | 2022.07.14 |
---|