skrub.Expr.skb.get_grid_search

Expr.skb.get_grid_search(*, fitted=False, keep_subsampling=False, **kwargs)
Find the best parameters with grid search.

This method returns a ParamSearch, an object similar to scikit-learn's GridSearchCV. The main difference is that methods such as fit() and predict() accept a dictionary of inputs rather than X and y. Please refer to the examples gallery for an in-depth explanation.

If the expression contains numeric ranges (choose_float, choose_int), either discretize them by providing the n_steps argument or use get_randomized_search instead of get_grid_search.
Parameters:

- fitted : bool (default=False)
  If True, the grid search is fitted on the data provided when initializing variables in this expression (the data returned by .skb.get_data()).

- keep_subsampling : bool (default=False)
  If True, and if subsampling has been configured (see Expr.skb.subsample()), fit on a subsample of the data. By default subsampling is not applied and all the data is used. Subsampling is only applied when fitting the grid search with fitted=True; subsequent use of the grid search is not affected by it. It is therefore an error to pass keep_subsampling=True with fitted=False (because keep_subsampling=True would have no effect).

- kwargs : dict
  All other named arguments are forwarded to sklearn.model_selection.GridSearchCV.
Returns:

- ParamSearch
  An object implementing the hyperparameter search. Besides the usual fit and predict, attributes of interest are results_ and plot_results().
See also
skrub.Expr.skb.get_randomized_search
Find the best parameters with randomized search.
Examples
>>> import skrub
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.dummy import DummyClassifier

>>> X_a, y_a = make_classification(random_state=0)
>>> X, y = skrub.X(X_a), skrub.y(y_a)
>>> logistic = LogisticRegression(C=skrub.choose_from([0.1, 10.0], name="C"))
>>> rf = RandomForestClassifier(
...     n_estimators=skrub.choose_from([3, 30], name="N 🌴"),
...     random_state=0,
... )
>>> classifier = skrub.choose_from(
...     {"logistic": logistic, "rf": rf, "dummy": DummyClassifier()}, name="classifier"
... )
>>> pred = X.skb.apply(classifier, y=y)
>>> print(pred.skb.describe_param_grid())
- classifier: 'logistic'
  C: [0.1, 10.0]
- classifier: 'rf'
  N 🌴: [3, 30]
- classifier: 'dummy'

>>> search = pred.skb.get_grid_search(fitted=True)
>>> search.results_
      C  N 🌴 classifier  mean_test_score
0   NaN  30.0         rf             0.89
1   0.1   NaN   logistic             0.84
2  10.0   NaN   logistic             0.80
3   NaN   3.0         rf             0.65
4   NaN   NaN      dummy             0.50