skrub.Expr.skb.train_test_split

Expr.skb.train_test_split(environment=None, splitter=<function train_test_split>, **splitter_kwargs)

Split an environment into a training environment and a testing environment.

Parameters:
environment : dict, optional

The environment (dict mapping variable names to values) containing the full data. If None (the default), the data is retrieved from the expression.

splitter : function, optional

The function used to split X and y once they have been computed. By default, sklearn.model_selection.train_test_split() is used.

splitter_kwargs

Additional named arguments to pass to the splitter.
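The default splitter fixes the calling convention that a replacement is expected to follow. Below is a minimal sketch, assuming skrub calls the splitter the same way scikit-learn's train_test_split is called: X (and y, when one is marked) as positional arguments, splitter_kwargs as keywords, and the pieces returned in the order X_train, X_test, y_train, y_test. chronological_split is a hypothetical function, not part of skrub or scikit-learn. It holds out the last rows instead of shuffling, which can suit time-ordered data. It relies on plain slicing, so it is shown on lists; for pandas objects, .iloc would be the explicit choice.

```python
import math


def chronological_split(X, y=None, *, test_size=0.25):
    """Hypothetical splitter: hold out the last `test_size` fraction of rows.

    Mimics the calling convention of sklearn.model_selection.train_test_split:
    data as positional arguments, options as keywords, and a result ordered as
    X_train, X_test (then y_train, y_test when y is given).
    """
    n_test = math.ceil(len(X) * test_size)  # round the test size up
    n_train = len(X) - n_test
    # No shuffling: earlier rows go to the training set, later rows to the test set.
    parts = [X[:n_train], X[n_train:]]
    if y is not None:
        parts += [y[:n_train], y[n_train:]]
    return parts


# Supervised case: four pieces, in scikit-learn's order.
X_train, X_test, y_train, y_test = chronological_split(
    ["a", "b", "c", "d"], [0, 0, 1, 0], test_size=0.25
)
print(X_train, X_test, y_train, y_test)  # ['a', 'b', 'c'] ['d'] [0, 0, 1] [0]

# Unsupervised case (no y): only the two X pieces come back.
print(chronological_split([1, 2, 3, 4], test_size=0.5))  # [[1, 2], [3, 4]]
```

Under the same assumption, a call such as delayed.skb.train_test_split(splitter=chronological_split, test_size=0.5) would pass test_size to the function through splitter_kwargs.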

Returns:
dict

The return value differs slightly from scikit-learn’s: rather than a tuple, it is a dictionary with the following keys:

  • train: a dictionary containing the training environment

  • test: a dictionary containing the test environment

  • X_train: the value of the variable marked with skb.mark_as_X() in the train environment

  • X_test: the value of the variable marked with skb.mark_as_X() in the test environment

  • y_train: the value of the variable marked with skb.mark_as_y() in the train environment, if there is one (there may be none, for example in unsupervised learning).

  • y_test: the value of the variable marked with skb.mark_as_y() in the test environment, if there is one (there may be none, for example in unsupervised learning).
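Because the result is a dictionary rather than scikit-learn's 4-tuple, code written against sklearn.model_selection.train_test_split has to pull the values out by key. The sketch below uses a hypothetical helper, as_sklearn_tuple, which is not part of skrub. The text above does not say whether the y keys are omitted or set to None when no y is marked, so the helper uses dict.get to handle both cases. The dictionary it is tried on is hand-built for illustration, not real skrub output.

```python
def as_sklearn_tuple(split):
    """Hypothetical helper: reorder a train_test_split result dict into
    scikit-learn's (X_train, X_test, y_train, y_test) order.

    y_train and y_test come back as None when no variable was marked as y,
    whether the keys are absent or present with a None value.
    """
    return (
        split["X_train"],
        split["X_test"],
        split.get("y_train"),
        split.get("y_test"),
    )


# Hand-built stand-in for an unsupervised result (illustrative values only).
fake_split = {
    "train": {"orders": [1, 2, 3]},
    "test": {"orders": [4]},
    "X_train": [1, 2, 3],
    "X_test": [4],
}
X_train, X_test, y_train, y_test = as_sklearn_tuple(fake_split)
print(X_train, X_test, y_train, y_test)  # [1, 2, 3] [4] None None
```

The "train" and "test" entries are left untouched by the helper: those environments are what the pipeline's fit and score methods take, as the example below shows.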

Examples

>>> import skrub
>>> from sklearn.dummy import DummyClassifier
>>> from sklearn.metrics import accuracy_score
>>> orders = skrub.var("orders", skrub.toy_orders().orders)
>>> X = orders.skb.drop("delayed").skb.mark_as_X()
>>> y = orders["delayed"].skb.mark_as_y()
>>> delayed = X.skb.apply(skrub.TableVectorizer()).skb.apply(
...     DummyClassifier(), y=y
... )
>>> split = delayed.skb.train_test_split(random_state=0)
>>> split.keys()
dict_keys(['train', 'test', 'X_train', 'X_test', 'y_train', 'y_test'])
>>> pipeline = delayed.skb.get_pipeline()
>>> pipeline.fit(split["train"])
SkrubPipeline(expr=<Apply DummyClassifier>)
>>> pipeline.score(split["test"])
0.0
>>> predictions = pipeline.predict(split["test"])
>>> accuracy_score(split["y_test"], predictions)
0.0