Source code for pimmslearn.io.dataloaders
import pandas
import pandas as pd
from fastai.data.all import *
from fastai.data.core import DataLoaders
from fastai.data.load import DataLoader
from torch.utils.data import Dataset
from pimmslearn.io import datasets
from pimmslearn.io.datasets import DatasetWithTarget
from pimmslearn.transform import VaepPipeline
[docs]
def get_dls(
train_X: pandas.DataFrame,
valid_X: pandas.DataFrame,
transformer: VaepPipeline,
bs: int = 64,
num_workers=0,
) -> DataLoaders:
"""Create training and validation dataloaders
Parameters
----------
train_X : pandas.DataFrame
Training Data, index is ignored for data fetching
valid_X : pandas.DataFrame
Validation data, won't be shuffled.
transformer : VaepPipeline
Pipeline with separate encode and decode
bs : int, optional
batch size, by default 64
num_workers : int, optional
number of workers to use for data loading, by default 0
Returns
-------
fastai.data.core.DataLoaders
FastAI DataLoaders with train and valid Dataloder
Example
-------
import sklearn
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from pimmslearn.dataloader import get_dls
from pimmslearn.transform import VaepPipeline
dae_default_pipeline = sklearn.pipeline.Pipeline(
[('normalize', StandardScaler()),
('impute', SimpleImputer(add_indicator=False))
])
# train_X, val_X = None, None # pandas.DataFrames
transforms = VaepPipeline(df_train=train_X,
encode=dae_default_pipeline,
decode=['normalize'])
dls = get_dls(train_X, val_X, transforms, bs=4)
"""
train_ds = datasets.DatasetWithTarget(df=train_X, transformer=transformer)
if valid_X is not None:
valid_ds = datasets.DatasetWithTargetSpecifyTarget(
df=train_X, targets=valid_X, transformer=transformer
)
else:
# empty dataset will be ignored by fastai in training loops
valid_ds = datasets.DatasetWithTarget(df=pd.DataFrame())
# ! Need for script exection (as plain python file)
# https://pytorch.org/docs/stable/notes/windows.html#multiprocessing-error-without-if-clause-protection
drop_last = False
if (len(train_X) % bs) == 1:
# Batch-Normalization does not work with batches of size one
drop_last = True
return DataLoaders.from_dsets(
train_ds, valid_ds, bs=bs, drop_last=drop_last, num_workers=num_workers
)
# dls.test_dl
# needs to be part of setup procedure of a class
[docs]
def get_test_dl(
df: pandas.DataFrame,
transformer: VaepPipeline,
dataset: Dataset = DatasetWithTarget,
bs: int = 64,
):
"""[summary]
Parameters
----------
df : pandas.DataFrame
Test data in a DataFrame
transformer : pimmslearn.transform.VaepPipeline
Pipeline with separate encode and decode
dataset : torch.utils.data.Dataset, optional
torch Dataset to yield encoded samples, by default DatasetWithTarget
bs : int, optional
batch size, by default 64
Returns
-------
fastai.data.load.DataLoader
DataLoader from fastai for test data.
"""
ds = dataset(df, transformer)
return DataLoader(ds, bs=bs, shuffle=False)