Google Colaboratory でPytorchのDatasetについて試してみる

本サイトは広告収入およびアフィリエイト収益を受けております。

スポンサーリンク
スポンサーリンク

お疲れ様です。きざきまるおです。

PytorchのDatasetについて色々調査してたのですが、腑に落ちるものがなかったため、こちらについて簡潔に書いていこうと思います。

それではどうぞ。

Pytorch Datasetとは

PytorchでのDatasetとは特徴量(答えを出すために必要な情報)とラベル(答え)をまとめて定義するクラスになります。

こちらで定義したものをDataLoader(後日まとめます)に渡すことでデータの一部をランダムに取得して学習させることができます。

では早速実践してみましょう。

必要ライブラリのインポート

以下ライブラリをインポートしましょう。

import torch

import pandas as pd
from sklearn.datasets import load_wine

「import torch」がPytorchライブラリです。そのほか2つは今回Datasetを確認するために使用するサンプルデータをダウンロード/加工するためにインポートします。

サンプルデータ

次にサンプルデータを定義しましょう。

wine_df = load_wine()

wine_data = pd.DataFrame(wine_df.data, columns=wine_df.feature_names)
wine_target = pd.DataFrame(wine_df.target, columns=['target'])

scikit-learnに内包されているワインデータを使用します。
特徴量が「wine_data」でラベルが「wine_dataset」になります。

Dataset

では早速Detasetのクラスを作成してみましょう。

class WineDataset(torch.utils.data.Dataset):
  def __init__(self, df, transform=None):
    self.features_values = df.data
    self.labels = df.target
  
  def __len__(self):
    return len(self.features_values)

  def __getitem__(self, idx):
    features_x = torch.FloatTensor(self.features_values[idx])
    labels = torch.LongTensor([self.labels[idx]])
    return features_x, labels

まず、前提条件として、Datasetでは「__init__」「__len__」「__getitem__」を関数として定義する必要があります。
それぞれの関数が何を意味しているのか一つ一つ説明していきます。

__init__

クラスが呼び出された際に最初に実行される部分です。
クラスに渡された変数をクラス内で初期化します。

__len__

Datasetのデータ量を定義します。
len関数で定義することで柔軟に対応できるかと思います。

__getitem__

Datasetに定義されたデータを返すために加工する部分です。
こちらでtensor型へ変更して最終的にDataLoaderへ帰す値を定義します。

Datasetオブジェクト呼び出し

最後にDatasetオブジェクトを呼び出して結果がどうなるか見てみましょう。

Dataset = WineDataset(wine_df)

feature_data, label_data = Dataset[0]
print(feature_data, label_data)

結果

tensor([1.4230e+01, 1.7100e+00, 2.4300e+00, 1.5600e+01, 1.2700e+02, 2.8000e+00,
        3.0600e+00, 2.8000e-01, 2.2900e+00, 5.6400e+00, 1.0400e+00, 3.9200e+00,
        1.0650e+03]) tensor([0])

はい。ということで無事にDataLoaderへ渡せる形に加工されましたね。

Datasetを調査してみて難しい文章が多く、理解をするのに苦労されている方は結構いるのではないでしょうか?
そんな方々の助けになれればうれしいです。

それではまた。

タイトルとURLをコピーしました