skrub.Expr.skb.get_grid_search

Expr.skb.get_grid_search(*, fitted=False, keep_subsampling=False, **kwargs)

Find the best parameters with grid search.

This function returns a ParamSearch, an object similar to scikit-learn's GridSearchCV. The main difference is that fit() and predict() accept a dictionary of inputs rather than X and y. The best pipeline is available as the .best_pipeline_ attribute.
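A minimal sketch of the calling convention, not skrub's implementation: the hypothetical DictInputSearch class below wraps scikit-learn's GridSearchCV so that fit() and predict() look up their inputs in a dictionary. In skrub the keys are presumably the variable names, here "X" and "y" as created by skrub.X() and skrub.y(). The best_pipeline_ name mirrors ParamSearch, while the wrapped object is simply GridSearchCV's best_estimator_.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV


class DictInputSearch:
    """Hypothetical wrapper: a grid search whose fit/predict take a dict of inputs."""

    def __init__(self, estimator, param_grid):
        self.search = GridSearchCV(estimator, param_grid)

    def fit(self, environment):
        # Look up the named inputs instead of receiving X and y positionally.
        self.search.fit(environment["X"], environment["y"])
        return self

    def predict(self, environment):
        # Only the "X" input is needed at prediction time.
        return self.search.predict(environment["X"])

    @property
    def best_pipeline_(self):
        # Named after the ParamSearch attribute; here it is GridSearchCV's best_estimator_.
        return self.search.best_estimator_


X_a, y_a = make_classification(random_state=0)
search = DictInputSearch(LogisticRegression(), {"C": [0.1, 10.0]})
search.fit({"X": X_a, "y": y_a})
predictions = search.predict({"X": X_a})
best = search.best_pipeline_
```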

Parameters:
fitted : bool (default=False)

If True, the grid search is fitted on the data provided when initializing the variables in this expression (the data returned by .skb.get_data()).

keep_subsampling : bool (default=False)

If True, and if subsampling has been configured (see Expr.skb.subsample()), fit on a subsample of the data. By default subsampling is not applied and all the data is used. Subsampling only applies to the initial fit of the grid search when fitted=True; subsequent uses of the grid search are not affected by it. Therefore it is an error to pass keep_subsampling=True together with fitted=False, because keep_subsampling=True would have no effect.

kwargs : dict

All other named arguments are forwarded to sklearn.model_selection.GridSearchCV.
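The hypothetical grid_search_sketch function below illustrates how the three parameters interact, under stated assumptions. It rejects keep_subsampling=True without fitted=True and forwards every other keyword (here cv=3) to GridSearchCV. Purely as an assumption for illustration, it takes the first n_subsample rows as the subsample; skrub configures subsampling through Expr.skb.subsample() and may select rows differently.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV


def grid_search_sketch(estimator, param_grid, X, y, *, fitted=False,
                       keep_subsampling=False, n_subsample=50, **kwargs):
    """Hypothetical stand-in mirroring the documented parameter rules."""
    if keep_subsampling and not fitted:
        # Subsampling only affects the initial fit, so without one the flag is an error.
        raise ValueError("keep_subsampling=True requires fitted=True")
    # Every other keyword argument is passed to GridSearchCV unchanged.
    search = GridSearchCV(estimator, param_grid, **kwargs)
    if fitted:
        if keep_subsampling:
            # Assumption for illustration: the subsample is the first n_subsample rows.
            X, y = X[:n_subsample], y[:n_subsample]
        search.fit(X, y)
    return search


X_a, y_a = make_classification(random_state=0)
grid = {"C": [0.1, 10.0]}
unfitted = grid_search_sketch(LogisticRegression(), grid, X_a, y_a)
full = grid_search_sketch(LogisticRegression(), grid, X_a, y_a, fitted=True, cv=3)
small = grid_search_sketch(
    LogisticRegression(), grid, X_a, y_a, fitted=True, keep_subsampling=True, cv=3
)
```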

Returns:
ParamSearch

An object implementing the hyperparameter search. Besides the usual fit and predict methods, the attributes and methods of interest are results_, plot_results(), and best_pipeline_.
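The results_ tables in the examples list one row per candidate, with its parameter values and its mean test score, best first. Assuming results_ summarizes cross-validation scores in that way, a similar table can be built from scikit-learn's GridSearchCV.cv_results_ with pandas. This sketch is not how skrub computes results_.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X_a, y_a = make_classification(random_state=0)
search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0, 10.0]}).fit(X_a, y_a)

# One row per candidate: its parameter values plus the mean cross-validation score.
cv = search.cv_results_
results = pd.DataFrame(list(cv["params"]))  # columns are named after the parameters
results["mean_test_score"] = cv["mean_test_score"]
# Best candidate first, as in the results_ tables shown in the examples.
results = results.sort_values("mean_test_score", ascending=False, ignore_index=True)
```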

See also

skrub.Expr.skb.get_randomized_search

Find the best parameters with randomized search.

Examples

>>> import skrub
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.dummy import DummyClassifier
>>> X_a, y_a = make_classification(random_state=0)
>>> X, y = skrub.X(X_a), skrub.y(y_a)
>>> logistic = LogisticRegression(C=skrub.choose_from([0.1, 10.0], name="C"))
>>> rf = RandomForestClassifier(
...     n_estimators=skrub.choose_from([3, 30], name="N 🌴"),
...     random_state=0,
... )
>>> classifier = skrub.choose_from(
...     {"logistic": logistic, "rf": rf, "dummy": DummyClassifier()}, name="classifier"
... )
>>> pred = X.skb.apply(classifier, y=y)
>>> print(pred.skb.describe_param_grid())
- classifier: 'logistic'
  C: [0.1, 10.0]
- classifier: 'rf'
  N 🌴: [3, 30]
- classifier: 'dummy'
>>> search = pred.skb.get_grid_search(fitted=True)
>>> search.results_
      C   N 🌴 classifier mean_test_score
0   NaN  30.0         rf             0.89
1   0.1   NaN   logistic             0.84
2  10.0   NaN   logistic             0.80
3   NaN   3.0         rf             0.65
4   NaN   NaN      dummy             0.50
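The five rows of results_ match the five combinations listed by describe_param_grid(): two values of C for the logistic regression, two values of N 🌴 for the random forest, and the dummy classifier on its own. Parameters that do not apply to a candidate appear as NaN. As a sketch, scikit-learn's ParameterGrid, which accepts a list of sub-grids, enumerates the same combinations; skrub's internal representation of the grid may differ.

```python
from sklearn.model_selection import ParameterGrid

# The grid printed by describe_param_grid(), written as one sub-grid per classifier.
# Parameters that do not apply to a classifier are simply absent from its sub-grid.
param_grid = [
    {"classifier": ["logistic"], "C": [0.1, 10.0]},
    {"classifier": ["rf"], "N 🌴": [3, 30]},
    {"classifier": ["dummy"]},
]
candidates = list(ParameterGrid(param_grid))
```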

If the expression contains some numeric ranges (choose_float, choose_int), either discretize them by providing the n_steps argument or use get_randomized_search instead of get_grid_search.

>>> logistic = LogisticRegression(
...     C=skrub.choose_float(0.1, 10.0, log=True, n_steps=5, name="C")
... )
>>> pred = X.skb.apply(logistic, y=y)
>>> print(pred.skb.describe_param_grid())
- C: choose_float(0.1, 10.0, log=True, n_steps=5, name='C')
>>> search = pred.skb.get_grid_search(fitted=True)
>>> search.results_
           C  mean_test_score
0   0.100000             0.84
1   0.316228             0.83
2   1.000000             0.81
3   3.162278             0.80
4  10.000000             0.80
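With log=True and n_steps=5, the values of C in the table are evenly spaced on a logarithmic scale, each roughly √10 ≈ 3.16 times the previous one. Assuming the discretization follows that pattern, NumPy's geomspace reproduces the same five values.

```python
import numpy as np

# Five log-spaced values from 0.1 to 10.0 inclusive: consecutive values share one ratio.
values = np.geomspace(0.1, 10.0, num=5)
ratios = values[1:] / values[:-1]
```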

Please refer to the examples gallery for an in-depth explanation.

© Copyright 2018-2023, the dirty_cat developers, 2023-2025, the skrub developers.
