Google ColaboratoryでPytorchのDataLoaderを試してみる

本サイトは広告収入およびアフィリエイト収益を受けております。

スポンサーリンク
スポンサーリンク

お疲れ様です。きざきまるおです。

今回はPytorchのDataLoaderについて使い方を書いていこうと思います。
ざっくりとした説明になってしまいますがご了承ください。

それではどうぞ。

DataLoaderとは

DataLoaderとは与えられたDatasetをミニバッチ(小さいデータの集合)に分割するものです。
学習を回す際、このミニバッチ単位で学習を行います。

実装

事前準備

まずはDatasetを準備しましょう。
今回はscikit-learnのワインデータをDatasetとしてセットします。

import torch

import pandas as pd
from sklearn.datasets import load_wine

wine_df = load_wine()

wine_data = pd.DataFrame(wine_df.data, columns=wine_df.feature_names)
wine_target = pd.DataFrame(wine_df.target, columns=['target'])

class WineDataset(torch.utils.data.Dataset):
  def __init__(self, df, transform=None):
    self.features_values = df.data
    self.labels = df.target
    self.transform = transform
  
  def __len__(self):
    return len(self.features_values)

  def __getitem__(self, idx):
    features_x = torch.FloatTensor(self.features_values[idx])
    labels = torch.LongTensor([self.labels[idx]])
    return features_x, labels

Dataset = WineDataset(wine_df)

そしてDataLoaderの設定です。

DataLoader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=30, shuffle=True)

for i, (features, label) in enumerate(DataLoader):
  features = features
  labels = labels
  print(i)

これだけです。簡単ですね。
DataLoaderがデータをどれくらい分割しているかはfor分のenumerateで取得できます。

それではまた。

タイトルとURLをコピーしました