Source code for pyrolite.util.skl.select
import pandas as pd
from ...geochem.ind import REE, _common_elements, _common_oxides
from ..log import Handle
logger = Handle(__name__)
try:
from sklearn.base import BaseEstimator, TransformerMixin
except ImportError:
msg = "scikit-learn not installed"
logger.warning(msg)
class TypeSelector(BaseEstimator, TransformerMixin):
    def __init__(self, dtype):
        """
        Select specific data types from a dataframe for further transformation.

        Parameters
        ----------
        dtype
            Data type specification; it is passed to
            :meth:`pandas.DataFrame.select_dtypes` as the single entry of the
            ``include`` list.
        """
        self.dtype = dtype

    def fit(self, X, y=None):
        """
        No-op fit; the selector does not learn anything from the data.

        Defined so the selector can be driven through ``fit_transform`` and
        scikit-learn pipelines, which call ``fit`` before ``transform``.

        Parameters
        ----------
        X
            Ignored.
        y
            Ignored.

        Returns
        -------
        TypeSelector
            This instance, unchanged.
        """
        return self

    def transform(self, X):
        """
        Return the columns of `X` whose dtype matches ``self.dtype``.

        Parameters
        ----------
        X : :class:`pandas.DataFrame`
            Dataframe to select columns from.

        Returns
        -------
        :class:`pandas.DataFrame`
            Result of ``X.select_dtypes(include=[self.dtype])``.

        Raises
        ------
        AssertionError
            If `X` is not a :class:`pandas.DataFrame`.
        """
        assert isinstance(X, pd.DataFrame)
        return X.select_dtypes(include=[self.dtype])
class ColumnSelector(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        """
        Select specific columns from a dataframe for further transformation.

        Parameters
        ----------
        columns
            Column labels to select; used as the column indexer in
            ``X.loc[:, columns]``.
        """
        self.columns = columns

    def fit(self, X, y=None):
        """
        No-op fit; the selector does not learn anything from the data.

        Defined so the selector can be driven through ``fit_transform`` and
        scikit-learn pipelines, which call ``fit`` before ``transform``.

        Parameters
        ----------
        X
            Ignored.
        y
            Ignored.

        Returns
        -------
        ColumnSelector
            This instance, unchanged.
        """
        return self

    def transform(self, X):
        """
        Return the configured columns of `X`.

        Parameters
        ----------
        X : :class:`pandas.DataFrame`
            Dataframe to select columns from.

        Returns
        -------
        :class:`pandas.DataFrame`
            Result of ``X.loc[:, self.columns]``.

        Raises
        ------
        AssertionError
            If `X` is not a :class:`pandas.DataFrame`.
        KeyError
            If the lookup fails; the message lists the configured columns
            which are absent from `X`.
        """
        assert isinstance(X, pd.DataFrame)
        try:
            return X.loc[:, self.columns]
        except KeyError as err:
            # Report exactly which requested labels are missing, and keep the
            # original pandas error attached as the cause for debugging.
            cols_error = list(set(self.columns) - set(X.columns))
            raise KeyError(
                "The DataFrame does not include the columns: %s" % cols_error
            ) from err
class CompositionalSelector(BaseEstimator, TransformerMixin):
    def __init__(self, columns=None, inverse=False):
        """
        Select the oxide and element components from a dataframe.

        Parameters
        ----------
        columns
            Collection of column names treated as compositional components.
            When `None`, defaults to ``_common_elements | _common_oxides``.
        inverse : bool
            If truthy, select the columns which are *not* in `columns`
            instead of those which are.
        """
        if columns is None:
            columns = _common_elements | _common_oxides
        self.columns = columns
        self.inverse = inverse

    def fit(self, X, y=None):
        """
        No-op fit; the selector does not learn anything from the data.

        Defined so the selector can be driven through ``fit_transform`` and
        scikit-learn pipelines, which call ``fit`` before ``transform``.

        Parameters
        ----------
        X
            Ignored.
        y
            Ignored.

        Returns
        -------
        CompositionalSelector
            This instance, unchanged.
        """
        return self

    def transform(self, X):
        """
        Return the compositional (or, if inverted, non-compositional) columns.

        Columns are returned in the order they occur in `X`; configured names
        which are absent from `X` are silently skipped.

        Parameters
        ----------
        X : :class:`pandas.DataFrame`
            Dataframe to select columns from.

        Returns
        -------
        :class:`pandas.DataFrame`
            Subset of `X` restricted to the selected columns.

        Raises
        ------
        AssertionError
            If `X` is not a :class:`pandas.DataFrame`.
        """
        assert isinstance(X, pd.DataFrame)
        if self.inverse:
            out_cols = [col for col in X.columns if col not in self.columns]
        else:
            out_cols = [col for col in X.columns if col in self.columns]
        return X.loc[:, out_cols]
class MajorsSelector(BaseEstimator, TransformerMixin):
    def __init__(self, components=None):
        """
        Select the major element oxides from a dataframe.

        Parameters
        ----------
        components
            Collection of column names to treat as major element oxides.
            When `None`, defaults to ``_common_oxides``.

        Notes
        -----
        The argument is stored unmodified as ``self.components`` so that the
        parameter introspection inherited from ``BaseEstimator``
        (``get_params``, ``clone``, ``repr``) can find it under the name used
        in the signature. The resolved names are stored as ``self.columns``,
        which is what :meth:`transform` reads. ``self.columns`` is resolved
        once here, so later changes to ``self.components`` do not update it.
        """
        self.components = components
        if components is None:
            components = _common_oxides
        self.columns = components

    def fit(self, X, y=None):
        """
        No-op fit; the selector does not learn anything from the data.

        Defined so the selector can be driven through ``fit_transform`` and
        scikit-learn pipelines, which call ``fit`` before ``transform``.

        Parameters
        ----------
        X
            Ignored.
        y
            Ignored.

        Returns
        -------
        MajorsSelector
            This instance, unchanged.
        """
        return self

    def transform(self, X):
        """
        Return the columns of `X` which are listed in ``self.columns``.

        Columns are returned in the order they occur in `X`; configured names
        which are absent from `X` are silently skipped.

        Parameters
        ----------
        X : :class:`pandas.DataFrame`
            Dataframe to select columns from.

        Returns
        -------
        :class:`pandas.DataFrame`
            Subset of `X` restricted to the selected columns.

        Raises
        ------
        AssertionError
            If `X` is not a :class:`pandas.DataFrame`.
        """
        assert isinstance(X, pd.DataFrame)
        out_cols = [col for col in X.columns if col in self.columns]
        return X.loc[:, out_cols]
class ElementSelector(BaseEstimator, TransformerMixin):
    def __init__(self, components=None):
        """
        Select the (trace) elements from a dataframe.

        Parameters
        ----------
        components
            Collection of column names to treat as elements.
            When `None`, defaults to ``_common_elements``.

        Notes
        -----
        The argument is stored unmodified as ``self.components`` so that the
        parameter introspection inherited from ``BaseEstimator``
        (``get_params``, ``clone``, ``repr``) can find it under the name used
        in the signature. The resolved names are stored as ``self.columns``,
        which is what :meth:`transform` reads. ``self.columns`` is resolved
        once here, so later changes to ``self.components`` do not update it.
        """
        self.components = components
        if components is None:
            components = _common_elements
        self.columns = components

    def fit(self, X, y=None):
        """
        No-op fit; the selector does not learn anything from the data.

        Defined so the selector can be driven through ``fit_transform`` and
        scikit-learn pipelines, which call ``fit`` before ``transform``.

        Parameters
        ----------
        X
            Ignored.
        y
            Ignored.

        Returns
        -------
        ElementSelector
            This instance, unchanged.
        """
        return self

    def transform(self, X):
        """
        Return the columns of `X` which are listed in ``self.columns``.

        Columns are returned in the order they occur in `X`; configured names
        which are absent from `X` are silently skipped.

        Parameters
        ----------
        X : :class:`pandas.DataFrame`
            Dataframe to select columns from.

        Returns
        -------
        :class:`pandas.DataFrame`
            Subset of `X` restricted to the selected columns.

        Raises
        ------
        AssertionError
            If `X` is not a :class:`pandas.DataFrame`.
        """
        assert isinstance(X, pd.DataFrame)
        out_cols = [col for col in X.columns if col in self.columns]
        return X.loc[:, out_cols]
class REESelector(BaseEstimator, TransformerMixin):
    def __init__(self, components=None):
        """
        Select the Rare Earth Elements (REE) from a dataframe.

        Parameters
        ----------
        components
            Ordered collection of column names to treat as REE. When `None`,
            defaults to the names returned by ``REE()`` with ``"Pm"`` removed.

        Notes
        -----
        The argument is stored unmodified as ``self.components`` so that the
        parameter introspection inherited from ``BaseEstimator``
        (``get_params``, ``clone``, ``repr``) can find it under the name used
        in the signature. The resolved names are stored as ``self.columns``,
        which is what :meth:`transform` reads. ``self.columns`` is resolved
        once here, so later changes to ``self.components`` do not update it.
        """
        self.components = components
        if components is None:
            # Default list only: drop "Pm" (presumably because promethium is
            # not normally reported in geochemical datasets - TODO confirm).
            components = [el for el in REE() if not el == "Pm"]
        self.columns = components

    def fit(self, X, y=None):
        """
        No-op fit; the selector does not learn anything from the data.

        Defined so the selector can be driven through ``fit_transform`` and
        scikit-learn pipelines, which call ``fit`` before ``transform``.

        Parameters
        ----------
        X
            Ignored.
        y
            Ignored.

        Returns
        -------
        REESelector
            This instance, unchanged.
        """
        return self

    def transform(self, X):
        """
        Return the REE columns present in `X`.

        Unlike the other selectors in this module, the output follows the
        order of ``self.columns`` rather than the column order of `X`.
        Configured names which are absent from `X` are silently skipped.

        Parameters
        ----------
        X : :class:`pandas.DataFrame`
            Dataframe to select columns from.

        Returns
        -------
        :class:`pandas.DataFrame`
            Subset of `X` restricted to the selected columns.

        Raises
        ------
        AssertionError
            If `X` is not a :class:`pandas.DataFrame`.
        """
        assert isinstance(X, pd.DataFrame)
        out_cols = [col for col in self.columns if col in X.columns]
        return X.loc[:, out_cols]