お疲れ様です。きざきまるおです。
今回はPytorchのDataLoaderについて使い方を書いていこうと思います。
ざっくりとした説明になってしまいますがご了承ください。
それではどうぞ。
DataLoaderとは
DataLoaderとは与えられたDatasetをミニバッチ(小さいデータの集合)に分割するものです。
学習を回す際、このミニバッチ単位で学習を行います。
実装
事前準備
まずはDatasetを準備しましょう。
今回はscikit-learnのワインデータをDatasetとしてセットします。
import torch
import pandas as pd
from sklearn.datasets import load_wine
wine_df = load_wine()
wine_data = pd.DataFrame(wine_df.data, columns=wine_df.feature_names)
wine_target = pd.DataFrame(wine_df.target, columns=['target'])
class WineDataset(torch.utils.data.Dataset):
def __init__(self, df, transform=None):
self.features_values = df.data
self.labels = df.target
self.transform = transform
def __len__(self):
return len(self.features_values)
def __getitem__(self, idx):
features_x = torch.FloatTensor(self.features_values[idx])
labels = torch.LongTensor([self.labels[idx]])
return features_x, labels
Dataset = WineDataset(wine_df)
そしてDataLoaderの設定です。
DataLoader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=30, shuffle=True)
for i, (features, label) in enumerate(DataLoader):
features = features
labels = labels
print(i)
これだけです。簡単ですね。
DataLoaderがデータをどれくらい分割しているかはfor分のenumerateで取得できます。
それではまた。