skrub.Expr.skb.train_test_split
- Expr.skb.train_test_split(environment=None, splitter=<function train_test_split>, **splitter_kwargs)
Split an environment into a training environment and a testing environment.
- Parameters:
- environment : dict, optional
The environment (dict mapping variable names to values) containing the full data. If None (the default), the data is retrieved from the expression.
- splitter : function, optional
The function used to split X and y once they have been computed. By default, sklearn.model_selection.train_test_split is used.
- splitter_kwargs
Additional named arguments to pass to the splitter.
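As an illustration of the splitter parameter, here is a minimal sketch of a custom splitter. It assumes, without confirmation from this page, that the splitter is called like sklearn.model_selection.train_test_split (the arrays as positional arguments, followed by splitter_kwargs) and must return the same flat sequence of train/test pieces. The name ordered_split and its test_fraction argument are hypothetical.

```python
def ordered_split(*arrays, test_fraction=0.25):
    """Hypothetical splitter: keep row order, send the last rows to the test set.

    Mirrors the return layout of scikit-learn's train_test_split:
    [a_train, a_test, b_train, b_test, ...] for the arrays passed in.
    """
    if not arrays:
        raise ValueError("at least one array is required")
    n_rows = len(arrays[0])
    if any(len(a) != n_rows for a in arrays):
        raise ValueError("all arrays must have the same number of rows")
    # Keep at least one row in the test set.
    cut = n_rows - max(1, round(n_rows * test_fraction))
    pieces = []
    for a in arrays:
        # Plain slicing works for lists, NumPy arrays, and pandas objects.
        pieces.extend([a[:cut], a[cut:]])
    return pieces


# Stand-alone check on plain lists (made-up data).
X = [[0], [1], [2], [3], [4], [5], [6], [7]]
y = [0, 0, 1, 1, 0, 1, 0, 1]
X_train, X_test, y_train, y_test = ordered_split(X, y, test_fraction=0.25)

# Assumed usage with an expression such as `delayed` from the Examples section:
# split = delayed.skb.train_test_split(splitter=ordered_split, test_fraction=0.25)
```

A splitter like this can be useful when rows are ordered in time and shuffling would leak future information into the training set. With the default splitter, the same keyword mechanism passes arguments such as test_size or random_state through to scikit-learn.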
- Returns:
dict
The return value differs slightly from scikit-learn’s. Rather than a tuple, it is a dictionary with the following keys:
train: a dictionary containing the training environment
test: a dictionary containing the test environment
X_train: the value of the variable marked with skb.mark_as_X() in the train environment
X_test: the value of the variable marked with skb.mark_as_X() in the test environment
y_train: the value of the variable marked with skb.mark_as_y() in the train environment, if there is one (may not be the case for unsupervised learning)
y_test: the value of the variable marked with skb.mark_as_y() in the test environment, if there is one (may not be the case for unsupervised learning)
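Code written for scikit-learn often expects a tuple rather than a dictionary. The following sketch converts one form to the other. It uses stand-in dictionaries with made-up values instead of a real skrub split, and it assumes that the y entries are simply missing when no variable is marked as y; unpack_split is a hypothetical helper, not part of skrub.

```python
def unpack_split(split):
    """Return (X_train, X_test, y_train, y_test) from a split dictionary.

    The y entries come back as None when the dictionary has no y keys
    (assumed here to be the unsupervised case).
    """
    return (
        split["X_train"],
        split["X_test"],
        split.get("y_train"),
        split.get("y_test"),
    )


# Stand-in dictionaries using the documented keys; the values are made up.
supervised = {
    "train": {"orders": "training rows"},
    "test": {"orders": "test rows"},
    "X_train": [[1], [2], [3]],
    "X_test": [[4]],
    "y_train": [0, 1, 0],
    "y_test": [1],
}
unsupervised = {k: v for k, v in supervised.items() if not k.startswith("y_")}

X_train, X_test, y_train, y_test = unpack_split(supervised)
_, _, no_y_train, no_y_test = unpack_split(unsupervised)
```

The train and test entries are left out of the tuple on purpose: they are whole environments meant to be passed to a pipeline's fit, predict, or score, whereas the X and y entries are the individual values that scorers and plotting code typically need.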
Examples
>>> import skrub
>>> from sklearn.dummy import DummyClassifier
>>> from sklearn.metrics import accuracy_score

>>> orders = skrub.var("orders", skrub.toy_orders().orders)
>>> X = orders.skb.drop("delayed").skb.mark_as_X()
>>> y = orders["delayed"].skb.mark_as_y()
>>> delayed = X.skb.apply(skrub.TableVectorizer()).skb.apply(
...     DummyClassifier(), y=y
... )

>>> split = delayed.skb.train_test_split(random_state=0)
>>> split.keys()
dict_keys(['train', 'test', 'X_train', 'X_test', 'y_train', 'y_test'])
>>> pipeline = delayed.skb.get_pipeline()
>>> pipeline.fit(split["train"])
SkrubPipeline(expr=<Apply DummyClassifier>)
>>> pipeline.score(split["test"])
0.0
>>> predictions = pipeline.predict(split["test"])
>>> accuracy_score(split["y_test"], predictions)
0.0