SklearnTrainConfig
Configuration for scikit-learn training parameters.
This class encapsulates the estimator to be trained. Additional keyword arguments for the estimator's fit() method can be passed as extra keyword arguments when constructing this configuration. They are stored as attributes of the configuration and forwarded to fit().
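A minimal sketch of how extra keyword arguments could be collected at construction time and forwarded to `fit()`, assuming a plain-Python design. `FitKwargsConfig`, `train`, and `RecordingEstimator` are hypothetical names used only for illustration. This is not factrainer's implementation. The toy estimator stands in for a real scikit-learn model so the sketch runs without scikit-learn installed.

```python
from typing import Any


class FitKwargsConfig:
    """Hypothetical stand-in for SklearnTrainConfig (not the factrainer implementation)."""

    def __init__(self, estimator: Any, **fit_kwargs: Any) -> None:
        self.estimator = estimator
        # Extra keyword arguments are kept together so they can be forwarded to fit().
        self.fit_kwargs = dict(fit_kwargs)


def train(config: FitKwargsConfig, X: Any, y: Any) -> Any:
    """Hypothetical helper: fit the configured estimator, forwarding the extra kwargs."""
    return config.estimator.fit(X, y, **config.fit_kwargs)


class RecordingEstimator:
    """Toy estimator that records what fit() received (stands in for a real sklearn model)."""

    def fit(self, X, y, sample_weight=None):
        self.fit_args_ = (X, y)
        self.sample_weight_ = sample_weight
        return self  # scikit-learn convention: fit() returns self

    def predict(self, X):
        return [0 for _ in X]


config = FitKwargsConfig(RecordingEstimator(), sample_weight=[1, 2, 1])
model = train(config, X=[[0.0], [1.0], [2.0]], y=[0, 1, 0])
print(model.sample_weight_)  # [1, 2, 1]
```

Keeping the extra arguments in one dictionary means the training step does not need to know which fit() parameters a given estimator supports; it simply unpacks them.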
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| estimator | Predictable \| ProbPredictable | A scikit-learn estimator instance that implements the fit method. It must also implement either the predict or the predict_proba method. | required |
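The interface requirement on `estimator` can be checked structurally. The sketch below assumes that `Predictable` means "has `fit` and `predict`" and that `ProbPredictable` means "has `fit` and `predict_proba`". The protocol classes `PredictableSketch` and `ProbPredictableSketch` and the helper `check_estimator_interface` are hypothetical, and factrainer's own type definitions may differ. Note that `isinstance()` against a `runtime_checkable` protocol only verifies that the methods exist, not their signatures.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PredictableSketch(Protocol):
    """Assumed shape of an estimator with point predictions: fit + predict."""

    def fit(self, X: Any, y: Any, **kwargs: Any) -> Any: ...
    def predict(self, X: Any) -> Any: ...


@runtime_checkable
class ProbPredictableSketch(Protocol):
    """Assumed shape of an estimator with probability predictions: fit + predict_proba."""

    def fit(self, X: Any, y: Any, **kwargs: Any) -> Any: ...
    def predict_proba(self, X: Any) -> Any: ...


def check_estimator_interface(estimator: Any) -> None:
    """Hypothetical check: raise TypeError unless fit and one prediction method exist."""
    if not isinstance(estimator, (PredictableSketch, ProbPredictableSketch)):
        raise TypeError(
            f"{type(estimator).__name__} must implement fit() and "
            "either predict() or predict_proba()"
        )


class OnlyFit:
    def fit(self, X, y):
        return self


class FitPredict(OnlyFit):
    def predict(self, X):
        return X


check_estimator_interface(FitPredict())  # passes silently

try:
    check_estimator_interface(OnlyFit())
    rejected = False
except TypeError:
    rejected = True
print(rejected)  # True
```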
See Also
sklearn.base.BaseEstimator : The base class for all scikit-learn estimators.
Examples:
>>> from sklearn.ensemble import RandomForestRegressor
>>> from factrainer.sklearn import SklearnTrainConfig
>>> # Basic configuration
>>> estimator = RandomForestRegressor(n_estimators=100, random_state=42)
>>> config = SklearnTrainConfig(estimator=estimator)
>>> # With additional fit keyword arguments
>>> import numpy as np
>>> from sklearn.linear_model import SGDClassifier
>>> sample_weights = np.array([1, 2, 1, 1, 2])
>>> estimator = SGDClassifier()
>>> config = SklearnTrainConfig(
...     estimator=estimator,
...     sample_weight=sample_weights,  # passed as kwargs to fit()
... )