SklearnTrainConfig

Configuration for scikit-learn training parameters.

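This class wraps the scikit-learn estimator to be trained. Additional keyword arguments for the estimator's fit() method can be supplied as extra attributes of this configuration, that is, as additional keyword arguments to the constructor (see the sample_weight example below). They are passed to fit() as keyword arguments.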

Parameters:

estimator : Predictable | ProbPredictable (required)
    A scikit-learn estimator instance that implements the fit method. It must also implement either the predict or the predict_proba method.
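The estimator requirement (fit plus either predict or predict_proba) can be expressed with structural typing. The sketch below is an assumption, not factrainer's code: the Predictable and ProbPredictable protocols here are hypothetical stand-ins for the library's types, and is_valid_estimator, MeanRegressor and NoPredict are illustrative names. It uses only the standard library so it runs without scikit-learn installed.

```python
from typing import Any, Protocol, runtime_checkable


# Hypothetical stand-ins for factrainer's Predictable / ProbPredictable;
# the library's real definitions may differ.
@runtime_checkable
class Predictable(Protocol):
    def fit(self, X: Any, y: Any, **kwargs: Any) -> Any: ...
    def predict(self, X: Any) -> Any: ...


@runtime_checkable
class ProbPredictable(Protocol):
    def fit(self, X: Any, y: Any, **kwargs: Any) -> Any: ...
    def predict_proba(self, X: Any) -> Any: ...


def is_valid_estimator(obj: object) -> bool:
    """Return True if obj has fit() plus predict() or predict_proba()."""
    # runtime_checkable only checks that the methods exist, not their signatures.
    return isinstance(obj, (Predictable, ProbPredictable))


class MeanRegressor:
    """Toy estimator: predicts the mean of the training targets."""

    def fit(self, X, y, **kwargs):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


class NoPredict:
    """Toy object with fit() only, so it does not satisfy either protocol."""

    def fit(self, X, y, **kwargs):
        return self
```

Any real scikit-learn regressor or classifier passes the same check, since it defines fit() and predict() (classifiers often also define predict_proba()).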
See Also

sklearn.base.BaseEstimator : The base class for all scikit-learn estimators.

Examples:

>>> from sklearn.ensemble import RandomForestRegressor
>>> from factrainer.sklearn import SklearnTrainConfig
>>> # Basic configuration
>>> estimator = RandomForestRegressor(n_estimators=100, random_state=42)
>>> config = SklearnTrainConfig(estimator=estimator)
>>> # With additional fit keyword arguments
>>> import numpy as np
>>> from sklearn.linear_model import SGDClassifier
>>> sample_weights = np.array([1, 2, 1, 1, 2])
>>> estimator = SGDClassifier()
>>> config = SklearnTrainConfig(
...     estimator=estimator,
...     sample_weight=sample_weights,  # passed as kwargs to fit()
... )
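The sample_weight example above relies on extra keyword arguments being forwarded to the estimator's fit(). The following is a minimal sketch of that idea under stated assumptions: TrainConfigSketch, its train() method and RecordingEstimator are hypothetical names, and this is not how factrainer implements SklearnTrainConfig (the real class exposes extras as attributes and performs training elsewhere).

```python
from typing import Any


class TrainConfigSketch:
    """Hypothetical sketch: keep the estimator plus extra fit() kwargs.

    Mimics the behaviour described for SklearnTrainConfig; it is not
    factrainer's implementation.
    """

    def __init__(self, estimator: Any, **fit_kwargs: Any) -> None:
        self.estimator = estimator
        # Store every extra keyword argument so it can be forwarded later.
        self.fit_kwargs = dict(fit_kwargs)

    def train(self, X: Any, y: Any) -> Any:
        # Forward the stored kwargs (e.g. sample_weight) to estimator.fit().
        return self.estimator.fit(X, y, **self.fit_kwargs)


class RecordingEstimator:
    """Toy estimator that records the sample_weight fit() received."""

    def fit(self, X, y, sample_weight=None):
        self.seen_sample_weight = sample_weight
        return self

    def predict(self, X):
        return [0 for _ in X]


# Usage: one weight per training sample, as with sample_weights above.
config = TrainConfigSketch(RecordingEstimator(), sample_weight=[1, 2, 1, 1, 2])
fitted = config.train([[0], [1], [2], [3], [4]], [0, 1, 0, 1, 1])
```

Keeping the extras in a separate dict keeps them distinct from the estimator itself, so forwarding is a single `**` unpacking. Note that sample_weight must have one entry per training sample, or the estimator's fit() will typically reject it.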

Attributes

estimator (instance attribute)

estimator: Predictable | ProbPredictable