.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/09_interpolation_join.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_09_interpolation_join.py>`
        to download the full example code, or to run this example in your browser via JupyterLite or Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_09_interpolation_join.py:


Interpolation join: infer missing rows when joining two tables
==============================================================

We illustrate the :class:`~skrub.InterpolationJoiner`, a type of join in which
values from the second table are inferred with machine learning rather than
looked up in the table. It is useful when exact matches are not available but
some rows are close enough to make an educated guess -- in this sense it is a
generalization of :func:`~skrub.fuzzy_join`.

The :class:`~skrub.InterpolationJoiner` is therefore a transformer that adds the
outputs of one or more machine-learning models as new columns to the table it
operates on.

In this example we want our transformer to add weather data (temperature, rain,
etc.) to the table it operates on. We have a table containing information about
commercial flights, and we want to add information about the weather at the time
and place where each flight took off. This could be useful to predict delays --
flights are often delayed by bad weather.

We have a table of weather data containing, at many weather stations,
measurements such as temperature, rain and snow at many time points.
Unfortunately, our weather stations are not inside the airports, and the
measurements are not timed according to the flight schedule. Therefore, a simple
equi-join would not yield any matching pair of rows from our two tables.
Instead, we use the :class:`~skrub.InterpolationJoiner` to *infer* the
temperature at the airport at take-off time.
We train supervised machine-learning models using the weather table, then query
them with the times and locations in the flights table.

.. GENERATED FROM PYTHON SOURCE LINES 32-37

Load weather data
-----------------
We join the table containing the measurements to the table that contains the
weather stations’ latitude and longitude. We subsample these large tables so
that the example runs faster.

.. GENERATED FROM PYTHON SOURCE LINES 37-47

.. code-block:: Python


    from skrub.datasets import fetch_figshare

    weather = fetch_figshare("41771457").X
    weather = weather.sample(100_000, random_state=0, ignore_index=True)
    stations = fetch_figshare("41710524").X
    weather = stations.merge(weather, on="ID")[
        ["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY", "TMAX", "PRCP", "SNOW"]
    ]

.. GENERATED FROM PYTHON SOURCE LINES 48-50

The ``'TMAX'`` is in tenths of a degree Celsius -- a ``'TMAX'`` of 297 means the
maximum temperature that day was 29.7℃. We convert it to degrees for readability.

.. GENERATED FROM PYTHON SOURCE LINES 50-53

.. code-block:: Python


    weather["TMAX"] /= 10

.. GENERATED FROM PYTHON SOURCE LINES 54-61

InterpolationJoiner with a ground truth: joining the weather table on itself
----------------------------------------------------------------------------
As a first simple example, we apply the :class:`~skrub.InterpolationJoiner` in a
situation where the ground truth is known. We split the weather table in half
and join the second half on the first half. Thus, the values from the right side
table of the join are inferred, whereas the corresponding columns from the left
side contain the ground truth, and we can compare them.

.. GENERATED FROM PYTHON SOURCE LINES 61-66

.. code-block:: Python


    n_main = weather.shape[0] // 2
    main_table = weather.iloc[:n_main]
    main_table.head()

.. raw:: html

           LATITUDE  LONGITUDE  YEAR/MONTH/DAY  TMAX  PRCP  SNOW
    0        25.333     55.517      2008-07-26  42.6   0.0   NaN
    1        25.333     55.517      2008-03-07  26.9   NaN   NaN
    2        25.333     55.517      2008-09-13  41.6   NaN   NaN
    3        25.255     55.364      2008-07-14  46.6   0.0   NaN
    4        25.255     55.364      2008-10-10  36.1   0.0   NaN


.. GENERATED FROM PYTHON SOURCE LINES 67-71

.. code-block:: Python


    aux_table = weather.iloc[n_main:]
    aux_table.head()

.. raw:: html

           LATITUDE  LONGITUDE  YEAR/MONTH/DAY  TMAX  PRCP  SNOW
    50000   39.3745   104.8145      2008-08-19   NaN   3.0   0.0
    50001   39.3745   104.8145      2008-12-12   NaN   0.0   0.0
    50002   39.4850   104.9089      2008-10-08   NaN   0.0   0.0
    50003   39.4850   104.9089      2008-07-06   NaN   0.0   0.0
    50004   39.4850   104.9089      2008-08-06   NaN   0.0   0.0


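As a side illustration of why an exact join is not enough in the situation
described in the introduction, the following minimal sketch uses two small
hypothetical tables (``stations_toy`` and ``flights_toy`` are made-up names and
values, not part of the datasets above). The keys are close but never
identical, so a pandas inner merge returns no rows.

.. code-block:: Python

    import pandas as pd

    # Hypothetical toy tables: station measurements and flight departures.
    stations_toy = pd.DataFrame(
        {
            "lat": [42.20, 32.90],
            "long": [-83.30, -97.00],
            "date": pd.to_datetime(["2008-02-23", "2008-03-01"]),
            "TMAX": [1.5, 21.0],
        }
    )
    flights_toy = pd.DataFrame(
        {
            "lat": [42.212059, 32.895951],
            "long": [-83.348836, -97.037200],
            "date": pd.to_datetime(["2008-02-24", "2008-03-02"]),
        }
    )

    # An exact (equi-) join requires identical key values in both tables,
    # so none of the nearby-but-different rows are matched.
    exact = flights_toy.merge(stations_toy, on=["lat", "long", "date"], how="inner")
    n_exact_matches = len(exact)
    print(n_exact_matches)

The merged table has the expected columns but no rows, which is the gap that
inferring values from nearby rows is meant to fill.
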
.. GENERATED FROM PYTHON SOURCE LINES 72-78

Joining the tables
------------------
Now we join our two tables and check how well the
:class:`~skrub.InterpolationJoiner` can reconstruct the matching rows that are
missing from the right side table. To avoid clashes in the column names, we use
the ``suffix`` parameter to append ``"_predicted"`` to the right side table
column names.

.. GENERATED FROM PYTHON SOURCE LINES 78-89

.. code-block:: Python


    from skrub import InterpolationJoiner

    joiner = InterpolationJoiner(
        aux_table,
        key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
        suffix="_predicted",
    ).fit(main_table)
    join = joiner.transform(main_table)
    join.head()

.. raw:: html

       LATITUDE  LONGITUDE  YEAR/MONTH/DAY  TMAX  PRCP  SNOW  TMAX_predicted  PRCP_predicted  SNOW_predicted
    0    25.333     55.517      2008-07-26  42.6   0.0   NaN       34.016554       60.221732        0.146801
    1    25.333     55.517      2008-03-07  26.9   NaN   NaN       27.389339       83.225358        0.440140
    2    25.333     55.517      2008-09-13  41.6   NaN   NaN       31.310463       19.196909        0.146801
    3    25.255     55.364      2008-07-14  46.6   0.0   NaN       33.429145       53.017047        0.146801
    4    25.255     55.364      2008-10-10  36.1   0.0   NaN       30.312918       36.310926        0.262846


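Conceptually, an interpolation join trains a model that maps the key columns of
the auxiliary table to its other columns, then predicts those columns for the
keys of the main table. The sketch below illustrates that idea only: it uses a
hand-written k-nearest-neighbors average in NumPy on made-up data, and the
estimators actually used by :class:`~skrub.InterpolationJoiner` may well be
different. The helper ``knn_interpolate`` and all values are hypothetical.

.. code-block:: Python

    import numpy as np


    def knn_interpolate(aux_keys, aux_values, main_keys, k=3):
        """Predict a value for each main row from its k nearest aux rows."""
        aux_keys = np.asarray(aux_keys, dtype=float)
        aux_values = np.asarray(aux_values, dtype=float)
        main_keys = np.asarray(main_keys, dtype=float)
        # Pairwise Euclidean distances, shape (n_main, n_aux).
        diffs = main_keys[:, None, :] - aux_keys[None, :, :]
        distances = np.sqrt((diffs**2).sum(axis=-1))
        # Indices of the k closest aux rows for every main row.
        nearest = np.argsort(distances, axis=1)[:, :k]
        # Average the values of those neighbors.
        return aux_values[nearest].mean(axis=1)


    # Hypothetical toy data: keys are (latitude, day of year), value is a temperature.
    aux_keys = [[40.0, 10.0], [40.0, 12.0], [41.0, 11.0], [25.0, 11.0]]
    aux_tmax = [1.0, 3.0, 2.0, 28.0]
    main_keys = [[40.2, 11.0], [25.1, 11.0]]

    predicted = knn_interpolate(aux_keys, aux_tmax, main_keys, k=3)
    nearest_only = knn_interpolate(aux_keys, aux_tmax, main_keys, k=1)

With ``k=1`` each prediction is simply the value of the closest auxiliary row,
which amounts to a nearest-match lookup; larger ``k`` averages over several
neighbors, which helps only when those neighbors are actually close.
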
.. GENERATED FROM PYTHON SOURCE LINES 90-92

Comparing the estimated values to the ground truth
--------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 92-114

.. code-block:: Python


    from matplotlib import pyplot as plt

    join = join.sample(2000, random_state=0, ignore_index=True)
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(5, 9),
        gridspec_kw={"height_ratios": [1.0, 0.5, 0.5]},
        layout="compressed",
    )
    for ax, col in zip(axes.ravel(), ["TMAX", "PRCP", "SNOW"]):
        ax.scatter(
            join[col].values,
            join[f"{col}_predicted"].values,
            alpha=0.1,
        )
        ax.set_aspect(1)
        ax.set_xlabel(f"true {col}")
        ax.set_ylabel(f"predicted {col}")
    plt.show()



.. image-sg:: /auto_examples/images/sphx_glr_09_interpolation_join_001.png
   :alt: 09 interpolation join
   :srcset: /auto_examples/images/sphx_glr_09_interpolation_join_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 115-117

We see that in this case the interpolation join works well for the temperature,
but not for precipitation or snow. So we will only add the temperature to our
flights table.

.. GENERATED FROM PYTHON SOURCE LINES 117-120

.. code-block:: Python


    aux_table = aux_table.drop(["PRCP", "SNOW"], axis=1)

.. GENERATED FROM PYTHON SOURCE LINES 121-126

Loading the flights table
-------------------------
We load the flights table and join it to the airports table using the flights’
``'Origin'`` column, which refers to the departure airport’s IATA code. We use
only a subset to speed up the example.

.. GENERATED FROM PYTHON SOURCE LINES 126-134

.. code-block:: Python


    flights = fetch_figshare("41771418").X[["Year_Month_DayofMonth", "Origin", "ArrDelay"]]
    flights = flights.sample(20_000, random_state=0, ignore_index=True)
    airports = fetch_figshare("41710257").X[["iata", "airport", "state", "lat", "long"]]
    flights = flights.merge(airports, left_on="Origin", right_on="iata")
    # printing the first row is more readable than the head() when we have many columns
    flights.iloc[0]

.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Year_Month_DayofMonth                  2008-02-24 00:00:00
    Origin                                                 DTW
    ArrDelay                                              35.0
    iata                                                   DTW
    airport                  Detroit Metropolitan-Wayne County
    state                                                   MI
    lat                                              42.212059
    long                                            -83.348836
    Name: 0, dtype: object



.. GENERATED FROM PYTHON SOURCE LINES 135-141

Joining the flights and weather data
------------------------------------
As before, we initialize our join transformer with the weather table. Then, we
use it to transform the flights table -- it adds a ``'TMAX'`` column containing
the predicted maximum daily temperature.

.. GENERATED FROM PYTHON SOURCE LINES 141-150

.. code-block:: Python


    joiner = InterpolationJoiner(
        aux_table,
        main_key=["lat", "long", "Year_Month_DayofMonth"],
        aux_key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
    )
    join = joiner.fit_transform(flights)
    join.head()

.. raw:: html

       Year_Month_DayofMonth  Origin  ArrDelay  iata  airport                            state  lat        long        TMAX
    0  2008-02-24             DTW         35.0  DTW   Detroit Metropolitan-Wayne County  MI     42.212059  -83.348836   1.867857
    1  2008-03-02             DFW         65.0  DFW   Dallas-Fort Worth International    TX     32.895951  -97.037200  21.181435
    2  2008-03-16             GSO        -15.0  GSO   Piedmont Triad International       NC     36.097747  -79.937297  17.180010
    3  2008-04-27             ORD         26.0  ORD   Chicago O'Hare International       IL     41.979595  -87.904464  16.801401
    4  2008-01-27             DFW        -10.0  DFW   Dallas-Fort Worth International    TX     32.895951  -97.037200  13.456575


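The key used above mixes two coordinates with a date, and regression models
need numeric inputs. The sketch below shows one simple way a date column could
be turned into numeric features with pandas. It is an assumption made for
illustration, not a description of how skrub encodes dates internally, and the
feature names are hypothetical.

.. code-block:: Python

    import pandas as pd

    # A few departure dates taken from the rows shown above.
    dates = pd.Series(pd.to_datetime(["2008-02-24", "2008-03-02", "2008-04-27"]))

    date_features = pd.DataFrame(
        {
            "year": dates.dt.year,
            "month": dates.dt.month,
            "day": dates.dt.day,
            # Days elapsed since the earliest date, so that ordering across
            # month boundaries is preserved in a single number.
            "days_since_start": (dates - dates.min()).dt.days,
        }
    )
    print(date_features)

Features such as ``month`` let a model pick up seasonal effects, while an
elapsed-time feature keeps neighboring days close to each other numerically.
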
.. GENERATED FROM PYTHON SOURCE LINES 151-155

Sanity checks
-------------
This time we do not have a ground truth for the temperatures, but we can still
perform a few basic sanity checks.

.. GENERATED FROM PYTHON SOURCE LINES 155-158

.. code-block:: Python


    state_temperatures = join.groupby("state")["TMAX"].mean().sort_values()

.. GENERATED FROM PYTHON SOURCE LINES 159-161

States with the lowest average predicted temperatures: Alaska, Montana,
Washington, North Dakota, Minnesota.

.. GENERATED FROM PYTHON SOURCE LINES 161-163

.. code-block:: Python


    state_temperatures.head()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    state
    AK   -3.732601
    MT    0.159955
    WA    0.739974
    ND    0.951557
    MN    1.410564
    Name: TMAX, dtype: float64



.. GENERATED FROM PYTHON SOURCE LINES 164-166

States with the highest average predicted temperatures: Puerto Rico, Virgin
Islands, Hawaii, Florida, Louisiana.

.. GENERATED FROM PYTHON SOURCE LINES 166-168

.. code-block:: Python


    state_temperatures.tail()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    state
    LA    21.831033
    FL    24.815112
    HI    27.092572
    VI    30.370775
    PR    30.971560
    Name: TMAX, dtype: float64



.. GENERATED FROM PYTHON SOURCE LINES 169-171

Higher latitudes (farther north) are colder -- the airports in this dataset are
in the United States.

.. GENERATED FROM PYTHON SOURCE LINES 171-177

.. code-block:: Python


    fig, ax = plt.subplots()
    ax.scatter(join["lat"], join["TMAX"])
    ax.set_xlabel("Latitude (higher is farther north)")
    ax.set_ylabel("TMAX")
    plt.show()



.. image-sg:: /auto_examples/images/sphx_glr_09_interpolation_join_002.png
   :alt: 09 interpolation join
   :srcset: /auto_examples/images/sphx_glr_09_interpolation_join_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 178-181

Winter months are colder than spring -- in the northern hemisphere, January is
colder than April.

.. GENERATED FROM PYTHON SOURCE LINES 181-189

.. code-block:: Python


    import seaborn as sns

    join["month"] = join["Year_Month_DayofMonth"].dt.strftime("%m %B")
    plt.figure(layout="constrained")
    sns.barplot(data=join.sort_values(by="month"), y="month", x="TMAX")
    plt.show()



.. image-sg:: /auto_examples/images/sphx_glr_09_interpolation_join_003.png
   :alt: 09 interpolation join
   :srcset: /auto_examples/images/sphx_glr_09_interpolation_join_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 190-194

Of course, these checks do not guarantee that the inferred values in our
``join`` table’s ``'TMAX'`` column are accurate. But at least the
:class:`~skrub.InterpolationJoiner` seems to have learned a few reasonable
trends from its training table.

.. GENERATED FROM PYTHON SOURCE LINES 197-207

Conclusion
----------
We have seen how to fit an :class:`~skrub.InterpolationJoiner` transformer: we
give it a table (the weather data) and a set of matching columns (here date,
latitude, longitude), and it learns to predict the other columns’ values (such
as the maximum daily temperature). Then, it transforms tables by *predicting*
the values that a matching row would contain, rather than by searching for an
actual match. It is a generalization of the :func:`~skrub.fuzzy_join`, as
:func:`~skrub.fuzzy_join` is the same thing as an
:class:`~skrub.InterpolationJoiner` where the estimators are 1-nearest-neighbor
estimators.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 5.611 seconds)


.. _sphx_glr_download_auto_examples_09_interpolation_join.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/skrub-data/skrub/main?urlpath=lab/tree/notebooks/auto_examples/09_interpolation_join.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../lite/lab/index.html?path=auto_examples/09_interpolation_join.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 09_interpolation_join.ipynb <09_interpolation_join.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 09_interpolation_join.py <09_interpolation_join.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 09_interpolation_join.zip <09_interpolation_join.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_