# Source code for pimmslearn.sklearn
"""Scikit-learn related functions for the project for ALD part.
Might be moved to a separate package in the future.
"""
import logging
from njab.sklearn import run_pca
from sklearn.impute import SimpleImputer
from pimmslearn.io import add_indices
logger = logging.getLogger(__name__)
# [docs]
def get_PCA(df, n_components=2, imputer=SimpleImputer):
    """Impute missing values in ``df`` and run a PCA on the completed data.

    Parameters
    ----------
    df
        Data to decompose; it is handed to the imputer's ``fit_transform``
        and to ``add_indices``. Presumably a ``pandas.DataFrame`` that may
        contain missing values -- confirm against callers.
    n_components : int, default 2
        Forwarded unchanged to ``run_pca``.
    imputer : callable, default ``SimpleImputer``
        Class (or zero-argument factory) that is instantiated without
        arguments; the resulting object must provide ``fit_transform``.

    Returns
    -------
    The first element of the pair returned by ``run_pca`` (presumably the
    principal-component scores -- confirm against ``njab.sklearn.run_pca``).

    Raises
    ------
    AssertionError
        If the imputed data still contains missing values.
    """
    imputer_ = imputer()
    X = imputer_.fit_transform(df)
    # Wrap the imputer output together with ``df`` via the project helper
    # (presumably restoring the index/column labels of ``df``).
    X = add_indices(X, df)
    # Reduce over the *values* of the mask. The builtin ``all(X.notna())``
    # iterates over the column labels of a DataFrame instead, so it would
    # neither detect remaining NaNs nor tolerate a falsy label such as ``0``.
    assert X.notna().to_numpy().all(), "Missing values remain after imputation."
    PCs, _ = run_pca(X, n_components=n_components)
    return PCs