Hyperparameter tuning with DataOps
A machine-learning pipeline typically contains values or choices that may influence its prediction performance, such as hyperparameters (e.g. the regularization parameter alpha of a RidgeClassifier, or the learning_rate of a HistGradientBoostingClassifier), which estimator to use (e.g. RidgeClassifier or HistGradientBoostingClassifier), or which steps to include (e.g. should we join a table to bring in additional information or not?).
We want to tune those choices by trying several options and keeping those that give the best performance on a validation set.
Skrub DataOps provide a convenient way to specify the range of possible values, by inserting it directly in place of the actual value. For example, we can write:
from sklearn.linear_model import RidgeClassifier
import skrub
RidgeClassifier(alpha=skrub.choose_from([0.1, 1.0, 10.0], name="α"))
RidgeClassifier(alpha=choose_from([0.1, 1.0, 10.0], name='α'))
instead of:
RidgeClassifier(alpha=1.0)
RidgeClassifier()
Skrub then inspects our DataOps plan to discover all the places where we used objects like choose_from(), and builds a grid of hyperparameters for us.
We will illustrate hyperparameter tuning on the “toxicity” dataset. This dataset contains 1,000 texts, and the task is to predict whether each one is flagged as toxic.
We start from a very simple pipeline without any hyperparameters.
from sklearn.ensemble import HistGradientBoostingClassifier
import skrub
import skrub.datasets
data = skrub.datasets.fetch_toxicity().toxicity
# This dataset is sorted -- all toxic tweets appear first, so we shuffle it
data = data.sample(frac=1.0, random_state=1)
texts = data[["text"]]
labels = data["is_toxic"]
We mark the texts column as the input variable and the labels column as the target variable (see the previous example for a more detailed explanation of skrub.X() and skrub.y()):
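X, y = skrub.X(texts), skrub.y(labels)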
We then encode the text with a MinHashEncoder and fit a HistGradientBoostingClassifier on the resulting features.
[Table report of X: 1,000 rows; a single column text (ObjectDType), 0 (0.0%) null values, 999 (99.9%) unique values.]
[Table report of y: a single column is_toxic (ObjectDType) with two values, Toxic and Not Toxic, and no null values.]
pred = X.skb.apply(skrub.MinHashEncoder()).skb.apply(
    HistGradientBoostingClassifier(), y=y
)
pred.skb.cross_validate(n_jobs=4)["test_score"]
0 0.635
1 0.590
2 0.645
3 0.595
4 0.585
Name: test_score, dtype: float64
In this example, we will focus on the n_components of the MinHashEncoder and the learning_rate of the HistGradientBoostingClassifier to illustrate choice objects.
When we use a scikit-learn hyperparameter tuner like GridSearchCV or RandomizedSearchCV, we need to specify the grid of hyperparameters separately from the estimator, with something similar to GridSearchCV(my_pipeline, param_grid={"encoder__n_components": [5, 10, 20]}). Instead, within a skrub DataOps plan we can use skrub.choose_from(...) directly where the actual value would normally go. Skrub then takes care of constructing the GridSearchCV’s parameter grid for us.
Note that skrub.choose_float() and skrub.choose_int() can be given a log argument to sample on a log scale, and that the number of steps can be specified with the n_steps argument.
X, y = skrub.X(texts), skrub.y(labels)

encoder = skrub.MinHashEncoder(
    n_components=skrub.choose_int(5, 15, n_steps=5, name="N components")
)
classifier = HistGradientBoostingClassifier(
    learning_rate=skrub.choose_float(0.01, 0.9, log=True, name="lr")
)
pred = X.skb.apply(encoder).skb.apply(classifier, y=y)
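We can check the parameter grid that skrub built from these choices with describe_param_grid(), the same method we use again further below:

print(pred.skb.describe_param_grid())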
From here, the pred DataOp can be used to perform hyperparameter search with .skb.make_grid_search() or .skb.make_randomized_search(). They accept the same arguments as their scikit-learn counterparts (e.g. scoring, cv, n_jobs). Also, like .skb.make_learner(), they accept a fitted argument: if fitted=True, the search is fitted on the data we provided when initializing our pipeline’s variables.
search = pred.skb.make_randomized_search(
    n_iter=8, n_jobs=4, random_state=1, fitted=True
)
search.results_
  | N components | lr | mean_test_score
---|---|---|---
0 | 15 | 0.059618 | 0.595
1 | 12 | 0.255675 | 0.593
2 | 15 | 0.520061 | 0.585
3 | 15 | 0.057289 | 0.576
4 | 10 | 0.105948 | 0.570
5 | 10 | 0.450710 | 0.551
6 | 8 | 0.038979 | 0.547
7 | 5 | 0.015151 | 0.533
If the plotly library is installed, we can visualize the results of the hyperparameter search with plot_results():
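search.plot_results()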
In the plot below, each line represents a combination of hyperparameters (in this case, only N components and learning rate), and each column of points represents either a hyperparameter or the score of a given combination of hyperparameters. The color of each line encodes that combination’s score. The plot is interactive: dragging the mouse vertically over a column selects a range and keeps only the combinations that fall inside it. This is particularly useful when there are many combinations of hyperparameters and we want to understand which ones have the largest impact on the score.
Finally, we can retrieve the best learner from the search results and serialize it with pickle. This learner contains the best hyperparameter configuration found during the search, and can be used to make predictions on new data.
import pickle
best_learner = search.best_learner_
saved_model = pickle.dumps(best_learner)
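To reuse it, we deserialize the learner and pass it new data. A minimal sketch, assuming the learner’s predict() accepts an environment dict keyed by the variable names created with skrub.X():

loaded = pickle.loads(saved_model)
# "X" is the name of the input variable created by skrub.X(texts).
predictions = loaded.predict({"X": texts})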
Default choice values
The goal of using the different choose_* functions is to tune choices on validation metrics with randomized or grid search. However, even when our plan contains such choices, we can still use it without tuning, for example in previews, or to get a quick first result before spending the computation time to run the search. When we use .skb.make_learner(), we get a pipeline that does not perform any tuning and uses those default values. This default pipeline is also used by .skb.eval().
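For example, a minimal sketch of building such an untuned pipeline from the plan above (fitted=True, as before, fits it on the data bound to our variables):

default_learner = pred.skb.make_learner(fitted=True)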
We can control what the default value for each choice should be. For choose_int(), choose_float() and choose_bool(), we can use the default parameter. For choose_from(), the default is the first item of the list or dict of outcomes we provide. For optional(), we can pass default=None to force the default to be the alternative outcome, None.
When we do not set an explicit default, skrub picks one depending on the kind of choice, as detailed in this table in the User Guide. As mentioned, we can control the default value:
skrub.choose_float(1.0, 100.0, default=12.0).default()
12.0
Choices can appear in many places
Choices are not limited to selecting estimator hyperparameters. They can also be used to choose between different estimators, or in place of any value used in our pipeline.
For example, here we pass a choice to the pandas DataFrame’s assign method. We want to add a feature that captures the length of the text, but we are not sure if it is better to count length in characters or in words. We do not want to add both because it would be redundant, so we add a single column to the dataframe, chosen between the length in characters and the length in words, as sketched below:
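A sketch of that step (the exact column expressions are an assumption; the preview below suggests the default, first outcome counts words):

X, y = skrub.X(texts), skrub.y(labels)

X = X.assign(
    length=skrub.choose_from(
        {
            # Each outcome is itself a DataOp computed from X.
            "words": X["text"].str.split().map(len),
            "chars": X["text"].str.len(),
        },
        name="length",
    )
)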
[Table report of X with the new column: text (ObjectDType, 999 (99.9%) unique values) and length (Int64DType, 108 (10.8%) unique values; mean ± std 25.3 ± 31.6, median ± IQR 17 ± 19, min | max 1 | 365).]
choose_from can be given a dictionary if we want to provide names for the individual outcomes, or a list when names are not needed: choose_from([1, 100], name='N'), choose_from({'small': 1, 'big': 100}, name='N').
Choices can be nested arbitrarily. For example, here we want to choose between two possible encoder types: the MinHashEncoder or the StringEncoder. Each of the possible outcomes itself contains a choice: the number of components.
X, y = skrub.X(texts), skrub.y(labels)

n_components = skrub.choose_int(5, 15, name="N components")
encoder = skrub.choose_from(
    {
        "minhash": skrub.MinHashEncoder(n_components=n_components),
        "lse": skrub.StringEncoder(n_components=n_components),
    },
    name="encoder",
)
X.skb.apply(encoder, cols="text")
[Table report of the result: ten Float32DType feature columns text_0 … text_9, produced by the default 'minhash' outcome of the encoder choice.]
In a similar vein, we might want to choose between a HistGradientBoostingClassifier and a RidgeClassifier, each with its own set of hyperparameters. We then define a choice for the classifier and a choice for the hyperparameters of each classifier.
from sklearn.linear_model import RidgeClassifier

hgb = HistGradientBoostingClassifier(
    learning_rate=skrub.choose_float(0.01, 0.9, log=True, name="lr")
)
ridge = RidgeClassifier(alpha=skrub.choose_float(0.01, 100, log=True, name="α"))
classifier = skrub.choose_from({"hgb": hgb, "ridge": ridge}, name="classifier")
pred = X.skb.apply(encoder).skb.apply(classifier, y=y)
print(pred.skb.describe_param_grid())
- encoder: 'minhash'
  N components: choose_int(5, 15, name='N components')
  classifier: 'hgb'
  lr: choose_float(0.01, 0.9, log=True, name='lr')
- encoder: 'minhash'
  N components: choose_int(5, 15, name='N components')
  classifier: 'ridge'
  α: choose_float(0.01, 100, log=True, name='α')
- encoder: 'lse'
  N components: choose_int(5, 15, name='N components')
  classifier: 'hgb'
  lr: choose_float(0.01, 0.9, log=True, name='lr')
- encoder: 'lse'
  N components: choose_int(5, 15, name='N components')
  classifier: 'ridge'
  α: choose_float(0.01, 100, log=True, name='α')
search = pred.skb.make_randomized_search(
    n_iter=16, n_jobs=4, random_state=1, fitted=True
)
search.plot_results()
Now that we have a more complex plan, we can draw more conclusions from the parallel coordinate plot. For example, we can see that the HistGradientBoostingClassifier performs better than the RidgeClassifier in most cases, that the StringEncoder outperforms the MinHashEncoder, and that the choice of the additional length feature does not have a significant impact on the score.
To conclude, we have seen how to use skrub’s choose_from objects to tune hyperparameters, choose optional configurations, and nest choices. We then looked at how the different choices affect the plan and the prediction scores.
There is more to say about skrub choices than what is covered in this example. In particular, choices are not limited to choosing estimators and their hyperparameters: they can be used anywhere DataOps are used, such as in the arguments of a deferred() function, or in the arguments of other DataOps’ methods or operators, as sketched below. Finally, choices can be inter-dependent. Please find more information in the user guide.
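For instance, a minimal sketch of a choice passed to a deferred() function (the function and its n_chars parameter are hypothetical):

@skrub.deferred
def truncate(df, n_chars):
    # At evaluation time df is a real DataFrame; keep the first n_chars characters.
    return df.assign(text=df["text"].str.slice(0, n_chars))

shortened = truncate(X, n_chars=skrub.choose_int(100, 300, name="n_chars"))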
Total running time of the script: (0 minutes 41.017 seconds)