Interpolation join: infer missing rows when joining two tables

We illustrate the InterpolationJoiner, a type of join in which values from the second table are inferred with machine learning rather than looked up in the table. It is useful when exact matches are not available but we have rows that are close enough to make an educated guess. In this sense, it is a generalization of fuzzy_join().

The InterpolationJoiner is therefore a transformer that adds the outputs of one or more machine-learning models as new columns to the table it operates on.
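To make the idea concrete, here is a minimal sketch of an interpolation join written with NumPy and pandas only. It is not skrub's implementation: the helper interpolation_join is hypothetical, and it fits an ordinary least-squares linear model per column as a stand-in for whatever machine-learning models the real transformer uses. What it illustrates is that rows of the main table receive values even when no auxiliary row matches them exactly.

```python
import numpy as np
import pandas as pd


def interpolation_join(main, aux, key, suffix=""):
    """Hypothetical sketch (not skrub's code): add to `main` every non-key
    column of `aux`, predicted from the key columns by a linear model."""
    # Design matrices: the key columns plus a constant column for the intercept.
    X_aux = np.column_stack([aux[key].to_numpy(float), np.ones(len(aux))])
    X_main = np.column_stack([main[key].to_numpy(float), np.ones(len(main))])
    result = main.copy()
    for col in aux.columns.difference(key):
        # "fit": learn coefficients mapping the key columns to this column.
        coef, *_ = np.linalg.lstsq(X_aux, aux[col].to_numpy(float), rcond=None)
        # "transform": predict the column for every row of the main table.
        result[col + suffix] = X_main @ coef
    return result


# Made-up auxiliary table where y = 2x + 1.
aux = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "y": [1.0, 3.0, 5.0, 7.0]})
# Neither x = 0.5 nor x = 10 appears in the auxiliary table.
main = pd.DataFrame({"x": [0.5, 10.0]})
out = interpolation_join(main, aux, key=["x"], suffix="_predicted")
print(out)
```

The sketch returns approximately 2.0 and 21.0, values an exact lookup could never produce. The linear model is only a convenient choice for a tiny example, and the sketch assumes there are no missing values in the key or target columns.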

In this example we want our transformer to add weather data (temperature, rain, etc.) to the table it operates on. We have a table containing information about commercial flights, and we want to add information about the weather at the time and place where each flight took off. This could be useful to predict delays – flights are often delayed by bad weather.

We have a table of weather data containing measurements such as temperature, rain and snow, recorded at many weather stations and at many time points. Unfortunately, the weather stations are not located inside the airports, and the measurements are not timed according to the flight schedule. Therefore, a simple equi-join would not yield any matching pair of rows from our two tables. Instead, we use the InterpolationJoiner to infer the temperature at the airport at take-off time: we train supervised machine-learning models on the weather table, then query them with the times and locations in the flights table.
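The failure of an exact join is easy to reproduce with pandas. In this sketch the two tables are made-up toy examples in which a weather station sits a few hundredths of a degree away from an airport: merge finds no matching row, and a left join leaves the weather column empty.

```python
import pandas as pd

# Made-up toy tables: the station is close to, but not at, the airport.
flights = pd.DataFrame({"lat": [42.21], "long": [-83.35], "date": ["2008-02-24"]})
weather = pd.DataFrame(
    {"lat": [42.23], "long": [-83.33], "date": ["2008-02-24"], "TMAX": [1.1]}
)

# An equi-join requires all key values to be identical.
exact = flights.merge(weather, on=["lat", "long", "date"], how="inner")
print(len(exact))  # 0

# A left join keeps the flight but cannot fill in the temperature.
left = flights.merge(weather, on=["lat", "long", "date"], how="left")
print(left["TMAX"].isna().all())  # True
```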

Load weather data

We join the table containing the measurements to the table that contains the weather stations’ latitude and longitude. We subsample the large measurements table so that the example runs faster.

from skrub.datasets import fetch_figshare

weather = fetch_figshare("41771457").X
weather = weather.sample(100_000, random_state=0, ignore_index=True)
stations = fetch_figshare("41710524").X
weather = stations.merge(weather, on="ID")[
    ["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY", "TMAX", "PRCP", "SNOW"]
]

The 'TMAX' column is in tenths of a degree Celsius: a 'TMAX' of 297 means that the maximum temperature that day was 29.7℃. We convert it to degrees Celsius for readability.

weather["TMAX"] /= 10
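The conversion is a plain division by 10 (297 / 10 = 29.7). Note that the in-place /= above divides again each time the line is re-run, for example in a notebook. A minimal alternative sketch, on a made-up three-row table, that leaves the raw column untouched:

```python
import pandas as pd

# Made-up raw 'TMAX' values, in tenths of a degree Celsius.
raw = pd.DataFrame({"TMAX": [297, 0, -15]})

# Non-mutating variant: re-running this line cannot divide the column twice.
converted = raw.assign(TMAX=raw["TMAX"] / 10)
print(converted["TMAX"].tolist())  # [29.7, 0.0, -1.5]
print(raw["TMAX"].tolist())  # [297, 0, -15]: the raw table is unchanged
```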

InterpolationJoiner with a ground truth: joining the weather table on itself

As a first, simple example, we apply the InterpolationJoiner in a situation where the ground truth is known. We split the weather table in half: main_table holds the first half of the rows and aux_table the second half. We then join aux_table onto main_table. The columns coming from the right-side table (aux_table) are inferred, whereas the corresponding columns of the left-side table (main_table) contain the ground truth, so we can compare them.

   LATITUDE  LONGITUDE  YEAR/MONTH/DAY  TMAX  PRCP  SNOW
0    25.333     55.517      2008-07-26  42.6   0.0   NaN
1    25.333     55.517      2008-03-07  26.9   NaN   NaN
2    25.333     55.517      2008-09-13  41.6   NaN   NaN
3    25.255     55.364      2008-07-14  46.6   0.0   NaN
4    25.255     55.364      2008-10-10  36.1   0.0   NaN


       LATITUDE  LONGITUDE  YEAR/MONTH/DAY  TMAX  PRCP  SNOW
50000   39.3745   104.8145      2008-08-19   NaN   3.0   0.0
50001   39.3745   104.8145      2008-12-12   NaN   0.0   0.0
50002   39.4850   104.9089      2008-10-08   NaN   0.0   0.0
50003   39.4850   104.9089      2008-07-06   NaN   0.0   0.0
50004   39.4850   104.9089      2008-08-06   NaN   0.0   0.0
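The code that creates main_table and aux_table is not shown above. A positional split at the midpoint is consistent with the printed indices (the second table starts at row 50000). A minimal sketch, assuming that split, on a small made-up stand-in for weather:

```python
import numpy as np
import pandas as pd

# Made-up stand-in for the `weather` table.
weather = pd.DataFrame(
    {"LATITUDE": np.linspace(25.0, 45.0, 10), "TMAX": np.arange(10, dtype=float)}
)

# Positional split at the midpoint: first half -> main table, second -> aux table.
n_main = len(weather) // 2
main_table = weather.iloc[:n_main]
aux_table = weather.iloc[n_main:]

print(len(main_table), len(aux_table))  # 5 5
print(aux_table.index[0])  # 5: the original row labels are kept
```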


Joining the tables

Now we join our two tables and check how well the InterpolationJoiner can reconstruct the matching rows that are missing from the right-side table. Both tables have the same column names, so a single key argument is enough. To avoid clashes in the column names, we use the suffix parameter to append "_predicted" to the names of the columns coming from the right-side table.

from skrub import InterpolationJoiner

joiner = InterpolationJoiner(
    aux_table,
    key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
    suffix="_predicted",
).fit(main_table)
join = joiner.transform(main_table)
join.head()
   LATITUDE  LONGITUDE  YEAR/MONTH/DAY  TMAX  PRCP  SNOW  TMAX_predicted  PRCP_predicted  SNOW_predicted
0    25.333     55.517      2008-07-26  42.6   0.0   NaN       35.637157       82.189197       -0.039541
1    25.333     55.517      2008-03-07  26.9   NaN   NaN       26.834254      137.171376        0.237755
2    25.333     55.517      2008-09-13  41.6   NaN   NaN       31.871600       27.685462       -0.039896
3    25.255     55.364      2008-07-14  46.6   0.0   NaN       34.787195       49.989701       -0.039896
4    25.255     55.364      2008-10-10  36.1   0.0   NaN       29.502211       33.688246       -0.039896
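Some predicted 'SNOW' values above are slightly negative, which is physically impossible: a regression model does not know that snowfall cannot be negative. If needed, such predictions can be post-processed. A minimal sketch, assuming that clipping at zero is acceptable for the application (this step is not part of the InterpolationJoiner), using three values copied from the table above:

```python
import pandas as pd

# Predicted SNOW values copied from the table above.
snow_predicted = pd.Series([-0.039541, 0.237755, -0.039896])

# Clip to the physically valid range: snowfall cannot be negative.
snow_clipped = snow_predicted.clip(lower=0.0)
print(snow_clipped.tolist())  # [0.0, 0.237755, 0.0]
```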


Comparing the estimated values to the ground truth

from matplotlib import pyplot as plt

join = join.sample(2000, random_state=0, ignore_index=True)
fig, axes = plt.subplots(
    3,
    1,
    figsize=(5, 9),
    gridspec_kw={"height_ratios": [1.0, 0.5, 0.5]},
    layout="compressed",
)
for ax, col in zip(axes.ravel(), ["TMAX", "PRCP", "SNOW"]):
    ax.scatter(
        join[col].values,
        join[f"{col}_predicted"].values,
        alpha=0.1,
    )
    ax.set_aspect(1)
    ax.set_xlabel(f"true {col}")
    ax.set_ylabel(f"predicted {col}")
plt.show()
[Figure: predicted vs. true values of TMAX, PRCP and SNOW, one scatter plot per column]

We see that in this case the interpolation join works well for the temperature, but not for precipitation or snow. We will therefore only add the temperature to our flights table.

aux_table = aux_table.drop(["PRCP", "SNOW"], axis=1)
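The decision to keep only 'TMAX' rests on the scatter plots. A single score per column can make that comparison more explicit. A minimal sketch of the coefficient of determination R², skipping rows where the ground truth is missing (as happens in the 'PRCP' and 'SNOW' columns); the helper r2_ignoring_nan is hypothetical and the example values are made up:

```python
import numpy as np


def r2_ignoring_nan(y_true, y_pred):
    """Hypothetical helper: R² computed only on rows where y_true is observed."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mask = ~np.isnan(y_true)
    y_true, y_pred = y_true[mask], y_pred[mask]
    ss_res = np.sum((y_true - y_pred) ** 2)  # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot


# Made-up values: close predictions score near 1 (the NaN row is skipped) ...
print(r2_ignoring_nan([10.0, 20.0, np.nan, 30.0], [11.0, 19.0, 5.0, 30.0]))
# ... while always predicting the mean scores 0.
print(r2_ignoring_nan([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))
```

On the real data, one would call it as r2_ignoring_nan(join[col], join[f"{col}_predicted"]) for each of 'TMAX', 'PRCP' and 'SNOW'.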

Loading the flights table

We load the flights table and join it to the airports table using the flights’ 'Origin' column, which holds the departure airport’s IATA code. We use only a subset of the flights to speed up the example.

flights = fetch_figshare("41771418").X[["Year_Month_DayofMonth", "Origin", "ArrDelay"]]
flights = flights.sample(20_000, random_state=0, ignore_index=True)
airports = fetch_figshare("41710257").X[["iata", "airport", "state", "lat", "long"]]
flights = flights.merge(airports, left_on="Origin", right_on="iata")
# printing the first row is more readable than the head() when we have many columns
flights.iloc[0]
Year_Month_DayofMonth                  2008-02-24 00:00:00
Origin                                                 DTW
ArrDelay                                              35.0
iata                                                   DTW
airport                  Detroit Metropolitan-Wayne County
state                                                   MI
lat                                              42.212059
long                                            -83.348836
Name: 0, dtype: object
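Because the key columns have different names ('Origin' in the flights table, 'iata' in the airports table), the merge uses left_on and right_on, and both key columns are kept in the result. This is why the row above shows 'DTW' twice. A small illustration on made-up tables:

```python
import pandas as pd

# Made-up miniature versions of the flights and airports tables.
flights = pd.DataFrame({"Origin": ["DTW", "SEA", "DTW"], "ArrDelay": [35.0, -2.0, 7.0]})
airports = pd.DataFrame({"iata": ["DTW", "SEA"], "state": ["MI", "WA"]})

merged = flights.merge(airports, left_on="Origin", right_on="iata")
print(merged.columns.tolist())  # ['Origin', 'ArrDelay', 'iata', 'state']
```

merge defaults to how="inner", so flights whose origin is absent from the airports table would be dropped from the result.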

Joining the flights and weather data

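As before, we initialize our join transformer with the weather table. Then, we use it to transform the flights table: it adds a 'TMAX' column containing the predicted maximum daily temperature. Since the key columns have different names in the two tables, we pass main_key (columns of the flights table) and aux_key (columns of the weather table) instead of key.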

joiner = InterpolationJoiner(
    aux_table,
    main_key=["lat", "long", "Year_Month_DayofMonth"],
    aux_key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
)
join = joiner.fit_transform(flights)
join.head()
   Year_Month_DayofMonth  Origin  ArrDelay  iata                            airport  state        lat        long       TMAX
0             2008-02-24     DTW      35.0   DTW  Detroit Metropolitan-Wayne County     MI  42.212059  -83.348836   0.929497
1             2008-04-29     DTW     -11.0   DTW  Detroit Metropolitan-Wayne County     MI  42.212059  -83.348836  16.652344
2             2008-03-31     DTW       7.0   DTW  Detroit Metropolitan-Wayne County     MI  42.212059  -83.348836   6.490080
3             2008-03-19     DTW      84.0   DTW  Detroit Metropolitan-Wayne County     MI  42.212059  -83.348836   7.248217
4             2008-01-07     DTW     -11.0   DTW  Detroit Metropolitan-Wayne County     MI  42.212059  -83.348836   7.106750
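In the call above, main_key lists columns of the flights table and aux_key the corresponding columns of the weather table. Assuming, as the call suggests, that the two lists are matched element by element, 'lat' pairs with 'LATITUDE', 'long' with 'LONGITUDE', and 'Year_Month_DayofMonth' with 'YEAR/MONTH/DAY'. The sketch below only makes that correspondence explicit with pandas; the renaming is an illustration, not something the InterpolationJoiner requires:

```python
import pandas as pd

main_key = ["lat", "long", "Year_Month_DayofMonth"]
aux_key = ["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"]

# Pair the key columns position by position.
pairs = dict(zip(main_key, aux_key))
print(pairs["lat"])  # LATITUDE

# Renaming the main table's keys to the auxiliary names shows the pairing.
flights = pd.DataFrame(
    {"lat": [42.212059], "long": [-83.348836], "Year_Month_DayofMonth": ["2008-02-24"]}
)
renamed = flights.rename(columns=pairs)
print(renamed.columns.tolist())  # ['LATITUDE', 'LONGITUDE', 'YEAR/MONTH/DAY']
```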


Sanity checks

This time we have no ground truth for the temperatures, but we can still perform a few basic sanity checks.

state_temperatures = join.groupby("state")["TMAX"].mean().sort_values()

States with the lowest average predicted temperatures: Alaska, Washington, North Dakota, Montana, Minnesota.

state
AK   -2.597808
WA    0.834756
ND    0.899300
MT    0.927627
MN    1.461241
Name: TMAX, dtype: float64

States with the highest average predicted temperatures: Puerto Rico, Virgin Islands, Hawaii, Florida, Louisiana.

state
LA    21.441632
FL    24.618361
HI    28.353713
VI    30.155955
PR    30.581503
Name: TMAX, dtype: float64
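The two listings above correspond to the first and last five entries of state_temperatures, which is sorted in ascending order. The same groupby, mean and sort pattern on a made-up table of predictions:

```python
import pandas as pd

# Made-up predicted temperatures for a handful of flights.
join = pd.DataFrame(
    {
        "state": ["AK", "FL", "AK", "MN", "FL", "TX"],
        "TMAX": [-3.0, 25.0, -1.0, 2.0, 24.0, 20.0],
    }
)
state_temperatures = join.groupby("state")["TMAX"].mean().sort_values()
print(state_temperatures.head(2).index.tolist())  # coldest: ['AK', 'MN']
print(state_temperatures.tail(2).index.tolist())  # warmest: ['TX', 'FL']
```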

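Airports at higher latitudes (farther north) are colder. All airports in this dataset are in the United States and its territories, in the northern hemisphere, so a higher latitude means farther from the equator.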

fig, ax = plt.subplots()
ax.scatter(join["lat"], join["TMAX"])
ax.set_xlabel("Latitude (higher is farther north)")
ax.set_ylabel("TMAX")
plt.show()
[Figure: predicted TMAX vs. latitude, scatter plot]
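The trend in the scatter plot can be summarized by a Pearson correlation coefficient: a clearly negative value means that the predicted temperature decreases with latitude. A minimal sketch on made-up values; on the real data one would pass join["lat"] and join["TMAX"] (or call join["lat"].corr(join["TMAX"]) in pandas):

```python
import numpy as np

lat = np.array([18.4, 25.8, 33.6, 42.2, 61.2])  # made-up airport latitudes
tmax = np.array([30.0, 26.0, 18.0, 9.0, -3.0])  # made-up predicted TMAX

# Off-diagonal entry of the 2x2 correlation matrix.
r = np.corrcoef(lat, tmax)[0, 1]
print(f"Pearson r = {r:.2f}")  # strongly negative
```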

Winter months are colder than spring months: in the northern hemisphere, January is colder than April.

import seaborn as sns

join["month"] = join["Year_Month_DayofMonth"].dt.strftime("%m %B")
plt.figure(layout="constrained")
sns.barplot(data=join.sort_values(by="month"), y="month", x="TMAX")
plt.show()
[Figure: mean predicted TMAX per month, bar plot]
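The "%m %B" format puts the zero-padded month number in front of the month name, so sorting the labels alphabetically (as sort_values(by="month") does above) also sorts them chronologically. With "%B" alone, "April" would come before "January". A small check on three made-up dates; since month names depend on the locale, only the numeric prefix is relied upon:

```python
import pandas as pd

dates = pd.Series(pd.to_datetime(["2008-12-05", "2008-01-07", "2008-04-29"]))
labels = dates.dt.strftime("%m %B")

# Alphabetical order of the labels follows the calendar thanks to the "%m" prefix.
sorted_labels = labels.sort_values().tolist()
print(sorted_labels)

# Without the prefix, alphabetical order is not chronological.
print(dates.dt.strftime("%B").sort_values().tolist())
```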

Of course, these checks do not guarantee that the inferred values in the 'TMAX' column of our join table are accurate. But at least the InterpolationJoiner seems to have learned a few reasonable trends from its training table.

Conclusion

We have seen how to fit an InterpolationJoiner transformer: we give it a table (the weather data) and a set of matching columns (here date, latitude and longitude), and it learns to predict the values of the other columns (such as the maximum daily temperature). It then transforms tables by predicting the values that a matching row would contain, rather than by searching for an actual match. It is a generalization of fuzzy_join(): fuzzy_join() is the same thing as an InterpolationJoiner whose estimators are 1-nearest-neighbor estimators.
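The relationship with fuzzy_join() can be illustrated with a 1-nearest-neighbor "model": predicting a row's values by copying those of the closest auxiliary row is a nearest-match lookup. A minimal NumPy sketch of that idea; the helper nearest_neighbor_join is hypothetical, the tables are made up, and skrub's fuzzy_join() has its own implementation and options:

```python
import numpy as np
import pandas as pd


def nearest_neighbor_join(main, aux, key):
    """Hypothetical sketch: attach to each main row the non-key values of the
    closest aux row (Euclidean distance on the key columns)."""
    A = aux[key].to_numpy(float)
    M = main[key].to_numpy(float)
    # Pairwise distances, shape (len(main), len(aux)).
    dist = np.linalg.norm(M[:, None, :] - A[None, :, :], axis=2)
    nearest = dist.argmin(axis=1)
    values = aux.drop(columns=key).iloc[nearest].reset_index(drop=True)
    return pd.concat([main.reset_index(drop=True), values], axis=1)


aux = pd.DataFrame({"lat": [42.2, 47.4], "long": [-83.3, -122.3], "TMAX": [1.0, 9.0]})
main = pd.DataFrame({"lat": [42.0, 47.5], "long": [-83.0, -122.0]})
out = nearest_neighbor_join(main, aux, key=["lat", "long"])
print(out)
```

Every output value is copied from an existing auxiliary row. A regression model, by contrast, can produce values that appear nowhere in the auxiliary table, which is what lets the InterpolationJoiner interpolate between rows.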

