.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/06_ken_embeddings.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_06_ken_embeddings.py>`
        to download the full example code, or to run this example in your browser via JupyterLite or Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_06_ken_embeddings.py:


Wikipedia embeddings to enrich the data
=======================================

When the data comprises common entities (cities, companies or famous people),
bringing in new information assembled from external sources may be the key to
improving the analysis.

Embeddings, or vectorial representations of entities, are a convenient way to
capture and summarize the information on an entity.
Relational data embeddings are available for all the common entities of
Wikipedia [#]_; we will call them *KEN embeddings* in the following example.
We will see that these embeddings of common entities substantially improve our
results.

.. note::

   This example requires ``pyarrow`` to be installed.

.. [#] https://soda-inria.github.io/ken_embeddings/

.. |Pipeline| replace:: :class:`~sklearn.pipeline.Pipeline`

.. |OneHotEncoder| replace:: :class:`~sklearn.preprocessing.OneHotEncoder`

.. |ColumnTransformer| replace:: :class:`~sklearn.compose.ColumnTransformer`

.. |MinHash| replace:: :class:`~skrub.MinHashEncoder`

.. |HGBR| replace:: :class:`~sklearn.ensemble.HistGradientBoostingRegressor`

.. GENERATED FROM PYTHON SOURCE LINES 40-45

The data
--------

We will work with a video game sales dataset. Let's retrieve it:

.. GENERATED FROM PYTHON SOURCE LINES 45-56

.. code-block:: Python

    import pandas as pd

    X = pd.read_csv(
        "https://raw.githubusercontent.com/William2064888/vgsales.csv/main/vgsales.csv",
        sep=";",
        on_bad_lines="skip",
    )
    # Shuffle the rows (fixed seed for reproducibility)
    X = X.sample(frac=1, random_state=11, ignore_index=True)
    X.head(3)


.. csv-table::
   :header: "", "Rank", "Name", "Platform", "Year", "Genre", "Publisher", "NA_Sales", "EU_Sales", "JP_Sales", "Other_Sales", "Global_Sales"

   0, 6500, "Star Wars: Bounty Hunter", GC, 2002, Shooter, LucasArts, 0.20, 0.05, 0.0, 0.01, 0.26
   1, 13442, "Thrillville: Off the Rails", DS, 2007, Strategy, LucasArts, 0.03, 0.01, 0.0, 0.00, 0.05
   2, 15074, "Thomas and Friends: Steaming around Sodor", 3DS, 2015, Action, Avanquest Software, 0.00, 0.02, 0.0, 0.00, 0.02
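
The rows are shuffled right after loading. Scikit-learn's ``cross_validate``,
used further down, splits a regression dataset into contiguous, unshuffled
folds by default (5-fold ``KFold``), so if the rows were ordered by sales, each
fold would cover a very different range of targets. Whether the raw CSV is
actually ordered by ``Rank`` is an assumption, not something this example
states. The stdlib-only sketch below illustrates the effect on a hypothetical
target sorted in decreasing order; ``contiguous_folds`` and ``fold_means`` are
hypothetical helpers.

```python
import random
import statistics


def contiguous_folds(n_samples, n_splits):
    """Split indices into consecutive blocks, like an unshuffled K-fold.

    The first ``n_samples % n_splits`` folds get one extra sample.
    """
    base, extra = divmod(n_samples, n_splits)
    folds, start = [], 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)
        folds.append(range(start, start + size))
        start += size
    return folds


def fold_means(values, n_splits=5):
    """Mean target value inside each contiguous fold."""
    return [
        statistics.mean(values[i] for i in fold)
        for fold in contiguous_folds(len(values), n_splits)
    ]


# Hypothetical target sorted in decreasing order (e.g. rows ordered by rank).
ordered = [float(v) for v in range(1000, 0, -1)]
shuffled = ordered.copy()
random.Random(11).shuffle(shuffled)

print(fold_means(ordered))  # [900.5, 700.5, 500.5, 300.5, 100.5]
print([round(m, 1) for m in fold_means(shuffled)])  # all close to 500.5
```

Passing ``cv=KFold(n_splits=5, shuffle=True, random_state=0)`` to
``cross_validate`` would be an alternative to shuffling the dataframe itself.
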


.. GENERATED FROM PYTHON SOURCE LINES 57-58

Our goal will be to predict the sales amount (``y``, our target column):

.. GENERATED FROM PYTHON SOURCE LINES 58-62

.. code-block:: Python

    y = X["Global_Sales"]
    y

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0        0.26
    1        0.05
    2        0.02
    3        1.16
    4        0.03
             ...
    16567    0.25
    16568    0.49
    16569    0.22
    16570    0.53
    16571    0.11
    Name: Global_Sales, Length: 16572, dtype: float64

.. GENERATED FROM PYTHON SOURCE LINES 63-64

Let's take a look at the distribution of our target variable:

.. GENERATED FROM PYTHON SOURCE LINES 64-72

.. code-block:: Python

    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme(style="ticks")
    sns.histplot(y)
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_06_ken_embeddings_001.png
   :alt: 06 ken embeddings
   :srcset: /auto_examples/images/sphx_glr_06_ken_embeddings_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 73-74

The distribution is heavily right-skewed, so it seems better to work with the
log of the sales rather than their raw values:

.. GENERATED FROM PYTHON SOURCE LINES 74-80

.. code-block:: Python

    import numpy as np

    y = np.log(y)
    sns.histplot(y)
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_06_ken_embeddings_002.png
   :alt: 06 ken embeddings
   :srcset: /auto_examples/images/sphx_glr_06_ken_embeddings_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 81-82

Before moving further, let's carry out some basic preprocessing: we drop the
rows where ``Publisher`` or ``Global_Sales`` is missing.

.. GENERATED FROM PYTHON SOURCE LINES 82-89

.. code-block:: Python

    # Get a mask of the rows with missing values in "Publisher" or "Global_Sales"
    mask = X["Publisher"].isna() | X["Global_Sales"].isna()
    # Remove these rows from both the features and the target
    X.dropna(subset=["Publisher", "Global_Sales"], inplace=True)
    y = y[~mask]

.. GENERATED FROM PYTHON SOURCE LINES 90-97

Extracting entity embeddings
----------------------------

We will use KEN embeddings to enrich our data.

We start by listing the available tables with
:func:`~skrub.datasets.fetch_ken_table_aliases`:

.. GENERATED FROM PYTHON SOURCE LINES 97-101

.. code-block:: Python

    from skrub.datasets import fetch_ken_table_aliases

    fetch_ken_table_aliases()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'schools', 'companies', 'all_entities', 'games', 'movies', 'albums'}

.. GENERATED FROM PYTHON SOURCE LINES 102-105

The *games* table is the most relevant to our case. Let's see which types of
entities it contains with :func:`~skrub.datasets.fetch_ken_types`:

.. GENERATED FROM PYTHON SOURCE LINES 105-109

.. code-block:: Python

    from skrub.datasets import fetch_ken_types

    fetch_ken_types(embedding_table_id="games")


.. csv-table::
   :header: "", "Type"

   0, wikicat_1994_video_games
   1, wikicat_irem_games
   2, wikicat_ea_guingamp_players
   3, wikicat_video_game_companies_of_the_united_kin...
   4, wikicat_asian_games_medalists_in_swimming
   ..., ...
   636, wikicat_college_football_games
   637, wikicat_sonic_team_games
   638, wikicat_space_opera_video_games
   639, wikicat_boxers_at_the_2002_asian_games
   640, wikicat_motorcycle_video_games

641 rows × 1 columns
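
The *games* table mixes video-game types with unrelated ones, such as
``wikicat_asian_games_medalists_in_swimming``. A minimal sketch of selecting
types by pattern, in plain Python on a few names copied from the output above,
helps anticipate what an include/exclude filter keeps. It assumes the patterns
are regular expressions searched anywhere in the type name, so
``"companies|developer"`` means "contains either word"; ``filter_types`` is a
hypothetical helper, not part of skrub.

```python
import re

# A few type names copied from the ``fetch_ken_types`` output above
# (the companies one is truncated in that output and stays truncated here).
type_names = [
    "wikicat_1994_video_games",
    "wikicat_irem_games",
    "wikicat_ea_guingamp_players",
    "wikicat_video_game_companies_of_the_united_kin...",
    "wikicat_asian_games_medalists_in_swimming",
    "wikicat_college_football_games",
    "wikicat_sonic_team_games",
    "wikicat_space_opera_video_games",
    "wikicat_boxers_at_the_2002_asian_games",
    "wikicat_motorcycle_video_games",
]


def filter_types(names, search, exclude=None):
    """Hypothetical helper: keep names matching ``search`` but not ``exclude``.

    Both patterns are regular expressions searched anywhere in the name.
    """
    kept = [name for name in names if re.search(search, name)]
    if exclude is not None:
        kept = [name for name in kept if not re.search(exclude, name)]
    return kept


game_types = filter_types(type_names, search="game", exclude="companies|developer")
print(len(game_types))  # 8
print([name for name in type_names if name not in game_types])
# ['wikicat_ea_guingamp_players', 'wikicat_video_game_companies_of_the_united_kin...']
```

Under this assumption, searching for ``"game"`` also keeps the Asian Games
types, since ``"games"`` contains ``"game"``.
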



.. GENERATED FROM PYTHON SOURCE LINES 110-114

These types cover a broad range of topics, well beyond video games (for
instance football players and Asian Games medalists).
Next, we'll use :func:`~skrub.datasets.fetch_ken_embeddings` to extract the
embeddings of the entities we need:

.. GENERATED FROM PYTHON SOURCE LINES 114-116

.. code-block:: Python

    from skrub.datasets import fetch_ken_embeddings

.. GENERATED FROM PYTHON SOURCE LINES 117-128

KEN embeddings are classified by types. See the documentation of
:func:`~skrub.datasets.fetch_ken_embeddings` to understand how you can filter
the types you are interested in.

The :func:`~skrub.datasets.fetch_ken_embeddings` function lets us specify the
types to include and/or exclude, so that we do not load all the Wikipedia
entity embeddings of a table.

In a first table, we include all embeddings whose type name contains "game"
and exclude those whose type name contains "companies" or "developer".

.. GENERATED FROM PYTHON SOURCE LINES 128-134

.. code-block:: Python

    embedding_games = fetch_ken_embeddings(
        search_types="game",
        exclude="companies|developer",
        embedding_table_id="games",
    )

.. GENERATED FROM PYTHON SOURCE LINES 135-137

In a second table, we include all embeddings whose type name contains
"game_development_companies", "game_companies" or "game_publish":

.. GENERATED FROM PYTHON SOURCE LINES 137-149

.. code-block:: Python

    embedding_publisher = fetch_ken_embeddings(
        search_types="game_development_companies|game_companies|game_publish",
        embedding_table_id="games",
    )

    # We keep the names of the 200 embedding columns in lists (for the pipeline).
    # The columns joined from the second table will carry the "_aux" suffix.
    n_dim = 200
    emb_columns = [f"X{j}" for j in range(n_dim)]
    emb_columns2 = [f"X{j}_aux" for j in range(n_dim)]

.. GENERATED FROM PYTHON SOURCE LINES 150-159

Merging the entities
....................

We will now merge the entities from Wikipedia with their matches in our video
game sales table: the entities of the ``embedding_games`` table are joined on
the ``"Name"`` column, and those of the ``embedding_publisher`` table on the
``"Publisher"`` column. We use skrub's :class:`~skrub.Joiner`, which performs a
fuzzy join, so the names do not need to be spelled identically in both tables.

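
To give an intuition of what a fuzzy join does, here is a deliberately
simplified, stdlib-only sketch; it is not the ``Joiner``'s implementation.
Each string becomes a vector of character 3-gram counts, vectors are compared
by cosine similarity, and each key of the main table is matched to its most
similar key in the auxiliary table. The publisher names come from the table
above, while the auxiliary entity names are hypothetical, as are the helper
functions.

```python
import math
from collections import Counter


def char_ngrams(text, n=3):
    """Count the character n-grams of a lower-cased, space-padded string."""
    padded = f" {text.lower()} "
    return Counter(padded[i : i + n] for i in range(len(padded) - n + 1))


def cosine(a, b):
    """Cosine similarity between two n-gram count vectors."""
    dot = sum(count * b[gram] for gram, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def fuzzy_match(main_keys, aux_keys):
    """Map each main key to its most similar auxiliary key (nearest neighbour)."""
    aux_vectors = {key: char_ngrams(key) for key in aux_keys}
    matches = {}
    for key in main_keys:
        vector = char_ngrams(key)
        matches[key] = max(aux_keys, key=lambda aux: cosine(vector, aux_vectors[aux]))
    return matches


# Publishers from the table above, and hypothetical Wikipedia entity names.
publishers = ["LucasArts", "Avanquest Software"]
entities = ["Lucas Arts", "Avanquest", "Electronic Arts"]
print(fuzzy_match(publishers, entities))
# {'LucasArts': 'Lucas Arts', 'Avanquest Software': 'Avanquest'}
```

This sketch always accepts the nearest match; the real ``Joiner`` can also
reject matches that are too distant (see its ``max_dist`` parameter).
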
.. GENERATED FROM PYTHON SOURCE LINES 159-167

.. code-block:: Python

    from skrub import Joiner

    # Join the game embeddings on the game names...
    fa1 = Joiner(embedding_games, aux_key="Entity", main_key="Name")
    # ...and the publisher embeddings on the publisher names
    fa2 = Joiner(embedding_publisher, aux_key="Entity", main_key="Publisher", suffix="_aux")

    X_full = fa1.fit_transform(X)
    X_full = fa2.fit_transform(X_full)

.. GENERATED FROM PYTHON SOURCE LINES 168-174

Prediction with base features
-----------------------------

Let us set the KEN embeddings aside for now and build a typical learning
pipeline that predicts sales using only the base features of the initial
table.

.. GENERATED FROM PYTHON SOURCE LINES 176-179

We first use scikit-learn's |ColumnTransformer| to select the columns used for
learning and to encode the categorical variables with the |MinHash| and the
|OneHotEncoder|:

.. GENERATED FROM PYTHON SOURCE LINES 179-194

.. code-block:: Python

    from sklearn.compose import make_column_transformer
    from sklearn.preprocessing import OneHotEncoder

    from skrub import MinHashEncoder

    min_hash = MinHashEncoder(n_components=100)
    ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=False)

    # "Year" is used as is, "Genre" is one-hot encoded and "Platform" is
    # MinHash-encoded; all other columns are dropped.
    encoder = make_column_transformer(
        ("passthrough", ["Year"]),
        (ohe, ["Genre"]),
        (min_hash, "Platform"),
        remainder="drop",
    )

.. GENERATED FROM PYTHON SOURCE LINES 195-197

We incorporate our |ColumnTransformer| into a |Pipeline|. As the predictor, we
use the |HGBR|, which is fast and reliable on large datasets.

.. GENERATED FROM PYTHON SOURCE LINES 197-203

.. code-block:: Python

    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.pipeline import make_pipeline

    hgb = HistGradientBoostingRegressor(random_state=0)
    pipeline = make_pipeline(encoder, hgb)

.. GENERATED FROM PYTHON SOURCE LINES 204-205

We now evaluate the |Pipeline| on the dataframe with 5-fold cross-validation,
recording the R² score and the root mean squared error (RMSE):

.. GENERATED FROM PYTHON SOURCE LINES 205-226

.. code-block:: Python

    from sklearn.model_selection import cross_validate

    # We save the results in dictionaries:
    all_r2_scores = dict()
    all_rmse_scores = dict()

    cv_results = cross_validate(
        pipeline, X_full, y, scoring=["r2", "neg_root_mean_squared_error"]
    )

    all_r2_scores["Base features"] = cv_results["test_r2"]
    # The scorer returns the negated RMSE (higher is better), so we flip its sign
    all_rmse_scores["Base features"] = -cv_results["test_neg_root_mean_squared_error"]

    print("With base features:")
    print(
        f"Mean R2 is {all_r2_scores['Base features'].mean():.2f} +-"
        f" {all_r2_scores['Base features'].std():.2f} and the RMSE is"
        f" {all_rmse_scores['Base features'].mean():.2f} +-"
        f" {all_rmse_scores['Base features'].std():.2f}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    With base features:
    Mean R2 is 0.21 +- 0.01 and the RMSE is 1.30 +- 0.01

.. GENERATED FROM PYTHON SOURCE LINES 227-232

Prediction with KEN Embeddings
------------------------------

We will now build a second learning pipeline using only the KEN embeddings
from Wikipedia.

.. GENERATED FROM PYTHON SOURCE LINES 234-235

We keep only the embedding columns:

.. GENERATED FROM PYTHON SOURCE LINES 235-239

.. code-block:: Python

    encoder2 = make_column_transformer(
        ("passthrough", emb_columns), ("passthrough", emb_columns2), remainder="drop"
    )

.. GENERATED FROM PYTHON SOURCE LINES 240-241

We redefine the |Pipeline|:

.. GENERATED FROM PYTHON SOURCE LINES 241-243

.. code-block:: Python

    pipeline2 = make_pipeline(encoder2, hgb)

.. GENERATED FROM PYTHON SOURCE LINES 244-245

Let's look at the results:

.. GENERATED FROM PYTHON SOURCE LINES 245-260

.. code-block:: Python

    cv_results = cross_validate(
        pipeline2, X_full, y, scoring=["r2", "neg_root_mean_squared_error"]
    )

    all_r2_scores["KEN features"] = cv_results["test_r2"]
    all_rmse_scores["KEN features"] = -cv_results["test_neg_root_mean_squared_error"]

    print("With KEN Embeddings:")
    print(
        f"Mean R2 is {all_r2_scores['KEN features'].mean():.2f} +-"
        f" {all_r2_scores['KEN features'].std():.2f} and the RMSE is"
        f" {all_rmse_scores['KEN features'].mean():.2f} +-"
        f" {all_rmse_scores['KEN features'].std():.2f}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    With KEN Embeddings:
    Mean R2 is 0.36 +- 0.01 and the RMSE is 1.17 +- 0.01

.. GENERATED FROM PYTHON SOURCE LINES 261-263

With the embeddings alone, the mean R² rises from 0.21 to 0.36 and the RMSE
drops from 1.30 to 1.17: the embeddings are clearly relevant for the
prediction task at hand!

.. GENERATED FROM PYTHON SOURCE LINES 265-271

Prediction with KEN Embeddings and base features
------------------------------------------------

Having scored the base features alone and the embeddings alone, we now make a
final prediction with all variables included.

.. GENERATED FROM PYTHON SOURCE LINES 273-274

We include both the embeddings and the base features:

.. GENERATED FROM PYTHON SOURCE LINES 274-283

.. code-block:: Python

    encoder3 = make_column_transformer(
        ("passthrough", emb_columns),
        ("passthrough", emb_columns2),
        ("passthrough", ["Year"]),
        (ohe, ["Genre"]),
        (min_hash, "Platform"),
        remainder="drop",
    )

.. GENERATED FROM PYTHON SOURCE LINES 284-285

We redefine the |Pipeline|:

.. GENERATED FROM PYTHON SOURCE LINES 285-287

.. code-block:: Python

    pipeline3 = make_pipeline(encoder3, hgb)

.. GENERATED FROM PYTHON SOURCE LINES 288-289

Let's look at the results:

.. GENERATED FROM PYTHON SOURCE LINES 289-304

.. code-block:: Python

    cv_results = cross_validate(
        pipeline3, X_full, y, scoring=["r2", "neg_root_mean_squared_error"]
    )

    all_r2_scores["Base + KEN features"] = cv_results["test_r2"]
    all_rmse_scores["Base + KEN features"] = -cv_results["test_neg_root_mean_squared_error"]

    print("With KEN Embeddings and base features:")
    print(
        f"Mean R2 is {all_r2_scores['Base + KEN features'].mean():.2f} +-"
        f" {all_r2_scores['Base + KEN features'].std():.2f} and the RMSE is"
        f" {all_rmse_scores['Base + KEN features'].mean():.2f} +-"
        f" {all_rmse_scores['Base + KEN features'].std():.2f}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    With KEN Embeddings and base features:
    Mean R2 is 0.49 +- 0.01 and the RMSE is 1.04 +- 0.01

.. GENERATED FROM PYTHON SOURCE LINES 305-309

Plotting the results
....................

Finally, we plot the cross-validated R² scores of the three pipelines on a
boxplot:

.. GENERATED FROM PYTHON SOURCE LINES 309-316

.. code-block:: Python

    plt.figure(figsize=(5, 3))
    # sphinx_gallery_thumbnail_number = -1
    ax = sns.boxplot(data=pd.DataFrame(all_r2_scores), orient="h")
    plt.xlabel("Prediction accuracy ", size=15)
    plt.yticks(size=15)
    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_06_ken_embeddings_003.png
   :alt: 06 ken embeddings
   :srcset: /auto_examples/images/sphx_glr_06_ken_embeddings_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 317-326

There is a clear improvement when including the KEN embeddings among the
explanatory variables: the mean R² goes from 0.21 with the base features
alone, to 0.36 with the embeddings alone, and to 0.49 with both.
In this case, the embeddings from Wikipedia bring in background information on
the game and on its publisher that would otherwise be missed.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 6.648 seconds)


.. _sphx_glr_download_auto_examples_06_ken_embeddings.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/skrub-data/skrub/0.3.0?urlpath=lab/tree/notebooks/auto_examples/06_ken_embeddings.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../lite/lab/index.html?path=auto_examples/06_ken_embeddings.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 06_ken_embeddings.ipynb <06_ken_embeddings.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 06_ken_embeddings.py <06_ken_embeddings.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 06_ken_embeddings.zip <06_ken_embeddings.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_