Spatial join for flight data: Joining across multiple columns

Joining tables may be difficult if an entry in one table does not have an exact match in the other.

This problem becomes even harder when several columns must be matched jointly, as in spatial joins on two columns, typically latitude and longitude.

Joiner() is a scikit-learn compatible transformer that performs fuzzy (approximate) joins across one or multiple keys, regardless of their data type (numerical, string or mixed).
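To make the idea of a fuzzy multiple-key join concrete, here is a minimal sketch written with pandas and NumPy only. It is not skrub's implementation: it assumes numeric keys, standardizes each key column so that the keys contribute on a comparable scale, and matches every row of the main table to its nearest auxiliary row by Euclidean distance. The helper name `nearest_key_join` and the toy tables are hypothetical.

```python
import numpy as np
import pandas as pd


def nearest_key_join(main, aux, main_key, aux_key):
    """Hypothetical helper: attach to each main row its nearest aux row on several numeric keys."""
    main_vals = main[main_key].to_numpy(dtype=float)
    aux_vals = aux[aux_key].to_numpy(dtype=float)

    # Standardize with the auxiliary table's statistics so that each key
    # contributes on a comparable scale (assumes no constant key column).
    mean = aux_vals.mean(axis=0)
    std = aux_vals.std(axis=0)
    main_scaled = (main_vals - mean) / std
    aux_scaled = (aux_vals - mean) / std

    # Pairwise Euclidean distances, shape (n_main, n_aux).
    dist = np.linalg.norm(main_scaled[:, None, :] - aux_scaled[None, :, :], axis=2)
    best = dist.argmin(axis=1)

    # Glue the best-matching aux row next to each main row.
    matched = aux.iloc[best].reset_index(drop=True)
    out = pd.concat([main.reset_index(drop=True), matched], axis=1)
    out["match_distance"] = dist[np.arange(len(main)), best]
    return out


# Toy tables: coordinates never match exactly, but the nearest airport is obvious.
stations = pd.DataFrame({"LATITUDE": [40.64, 33.95], "LONGITUDE": [-73.76, -118.40]})
airports = pd.DataFrame(
    {"iata": ["JFK", "LAX", "ORD"], "lat": [40.64, 33.94, 41.98], "long": [-73.78, -118.41, -87.90]}
)
joined = nearest_key_join(stations, airports, ["LATITUDE", "LONGITUDE"], ["lat", "long"])
print(joined[["LATITUDE", "LONGITUDE", "iata", "match_distance"]])
```

The full distance matrix costs O(n_main × n_aux) memory, which is fine for a toy example; for large tables a tree-based nearest-neighbour index would be the usual choice.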

The following example uses US domestic flights data to illustrate how space and time information from a pool of tables is combined for machine learning.

Flight-delays data

The goal is to predict flight delays. We have a pool of tables that we will use to improve our prediction.

The following tables are at our disposal:

The main table: flights dataset

  • The flights dataset. It contains the date, origin and destination airports, and scheduled departure, arrival and elapsed times of US domestic flights. Here, we consider only flights from 2008.

import pandas as pd

from skrub.datasets import fetch_figshare

seed = 1
flights = fetch_figshare("41771418").X

# Sampling for faster computation.
flights = flights.sample(5_000, random_state=seed, ignore_index=True)
flights.head()
Year_Month_DayofMonth DayOfWeek CRSDepTime CRSArrTime UniqueCarrier FlightNum TailNum CRSElapsedTime ArrDelay Origin Dest Distance
0 2008-01-13 7 1900-01-01 18:35:00 1900-01-01 20:08:00 CO 150 N17244 213.0 1.0 IAH ONT 1334.0
1 2008-02-21 4 1900-01-01 14:30:00 1900-01-01 16:06:00 NW 807 N590NW 216.0 2.0 MSP SEA 1399.0
2 2008-04-17 4 1900-01-01 09:40:00 1900-01-01 13:15:00 WN 1684 N642WN 155.0 -13.0 SEA DEN 1024.0
3 2008-01-03 4 1900-01-01 08:40:00 1900-01-01 12:03:00 CO 287 N21723 383.0 46.0 EWR SNA 2433.0
4 2008-01-31 4 1900-01-01 12:50:00 1900-01-01 14:10:00 MQ 3157 N848AE 80.0 -14.0 SJC SNA 342.0


Let us look at the distribution of arrival delays (in minutes) in the dataset:

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="ticks")

ax = sns.histplot(data=flights, x="ArrDelay")
ax.set_yscale("log")
plt.show()

Most delays are relatively short (under 100 minutes), but a few are very long.

Airport data: an auxiliary table from the same database

  • The airports dataset, with information about each airport such as its IATA code, name and location (latitude, longitude).

airports = fetch_figshare("41710257").X
airports.head()
iata airport city state country lat long
0 00M Thigpen Bay Springs MS USA 31.953765 -89.234505
1 00R Livingston Municipal Livingston TX USA 30.685861 -95.017928
2 00V Meadow Lake Colorado Springs CO USA 38.945749 -104.569893
3 01G Perry-Warsaw Perry NY USA 42.741347 -78.052081
4 01J Hilliard Airpark Hilliard FL USA 30.688012 -81.905944


Weather data: auxiliary tables from external sources

  • The weather table, with weather details by measurement station. This table and the stations table below both come from the Global Historical Climatology Network. Here, we consider only weather measurements from 2008.

weather = fetch_figshare("41771457").X
# Sampling for faster computation.
weather = weather.sample(10_000, random_state=seed, ignore_index=True)
weather.head()
ID YEAR/MONTH/DAY TMAX PRCP SNOW
0 ASN00041037 2008-06-18 NaN 0.0 NaN
1 USC00164696 2008-01-04 39.0 0.0 0.0
2 US1ILSP0008 2008-08-19 NaN 0.0 0.0
3 USC00164931 2008-10-05 NaN 0.0 0.0
4 NOE00111309 2008-08-18 NaN 0.0 NaN


  • The stations dataset. It provides the location of weather measurement stations worldwide, not only in the US (the first rows are in Antigua and the United Arab Emirates).

stations = fetch_figshare("41710524").X
stations.head()
ID LATITUDE LONGITUDE ELEVATION STATE NAME GSN FLAG HCN/CRN FLAG WMO ID
0 ACW00011604 17.1167 -61.7833 10.1 ST JOHNS COOLIDGE FLD None None NaN NaN
1 ACW00011647 17.1333 -61.7833 19.2 ST JOHNS None None NaN NaN
2 AE000041196 25.3330 55.5170 34.0 SHARJAH INTER. AIRP None GSN 41196.0 NaN
3 AEM00041194 25.2550 55.3640 10.4 DUBAI INTL None None 41194.0 NaN
4 AEM00041217 24.4330 54.6510 26.8 ABU DHABI INTL None None 41217.0 NaN


Joining: feature augmentation across tables

First we join the stations with the weather measurements on the station ID (exact join):

aux = pd.merge(stations, weather, on="ID")
aux.head()

ID LATITUDE LONGITUDE ELEVATION STATE NAME GSN FLAG HCN/CRN FLAG WMO ID YEAR/MONTH/DAY TMAX PRCP SNOW
0 AEM00041218 24.2620 55.6090 264.9 AL AIN INTL None None 41218.0 NaN 2008-02-18 310.0 NaN NaN
1 AG000060590 30.5667 2.8667 397.0 EL-GOLEA None GSN 60590.0 NaN 2008-10-09 278.0 0.0 NaN
2 AGE00147716 35.1000 -1.8500 83.0 NEMOURS (GHAZAOUET) None None 60517.0 NaN 2008-09-23 301.0 0.0 NaN
3 AGE00147718 34.8500 5.7200 125.0 BISKRA None None 60525.0 NaN 2008-08-09 429.0 0.0 NaN
4 AGM00060353 36.8170 5.8830 6.0 JIJEL-PORT None None 60353.0 NaN 2008-07-01 286.0 0.0 NaN
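An exact join keeps only rows whose keys match character for character. The toy sketch below, with made-up station IDs and values, uses `pd.merge` with its default `how="inner"` and shows that a station without weather records (and a weather record without a station) is silently dropped; an outer merge with `indicator=True` makes those unmatched keys visible.

```python
import pandas as pd

# Hypothetical toy tables; the IDs and values are made up for illustration.
stations = pd.DataFrame(
    {"ID": ["USW001", "USW002", "USW003"], "LATITUDE": [40.6, 33.9, 41.9]}
)
weather = pd.DataFrame(
    {"ID": ["USW001", "USW003", "USW999"], "TMAX": [250.0, 180.0, 300.0]}
)

# Inner join (pandas' default): only IDs present in both tables survive.
exact = pd.merge(stations, weather, on="ID")
print(exact)

# An outer join with indicator=True labels each row as "both",
# "left_only" or "right_only", which helps spot missing keys.
audit = pd.merge(stations, weather, on="ID", how="outer", indicator=True)
print(audit[["ID", "_merge"]])
```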


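Then we join this table with the airports, so that all auxiliary tables are combined into one. Station and airport coordinates rarely coincide exactly, so this is a fuzzy join on two columns: each station is matched to the airport whose coordinates are closest.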

from skrub import Joiner

joiner = Joiner(airports, aux_key=["lat", "long"], main_key=["LATITUDE", "LONGITUDE"])

aux_augmented = joiner.fit_transform(aux)

aux_augmented.head()
ID LATITUDE LONGITUDE ELEVATION STATE NAME GSN FLAG HCN/CRN FLAG WMO ID YEAR/MONTH/DAY TMAX PRCP SNOW iata airport city state country lat long skrub_Joiner_distance skrub_Joiner_rescaled_distance skrub_Joiner_match_accepted
0 AEM00041218 24.2620 55.6090 264.9 AL AIN INTL None None 41218.0 NaN 2008-02-18 310.0 NaN NaN ROP Prachinburi None None Thailand 14.078333 101.378334 2.345567 3.214517 True
1 AG000060590 30.5667 2.8667 397.0 EL-GOLEA None GSN 60590.0 NaN 2008-10-09 278.0 0.0 NaN X96 Cruz Bay Harbor Seaplane Base Cruz Bay VI USA 18.336898 -64.799583 3.303558 4.527409 True
2 AGE00147716 35.1000 -1.8500 83.0 NEMOURS (GHAZAOUET) None None 60517.0 NaN 2008-09-23 301.0 0.0 NaN ACK Nantucket Memorial Nantucket MA USA 41.253052 -70.060181 3.073160 4.211656 True
3 AGE00147718 34.8500 5.7200 125.0 BISKRA None None 60525.0 NaN 2008-08-09 429.0 0.0 NaN ACK Nantucket Memorial Nantucket MA USA 41.253052 -70.060181 3.402099 4.662456 True
4 AGM00060353 36.8170 5.8830 6.0 JIJEL-PORT None None 60353.0 NaN 2008-07-01 286.0 0.0 NaN EPM Eastport Municipal Eastport ME USA 44.910111 -67.012694 3.332759 4.567428 True
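Note that every row shown has skrub_Joiner_match_accepted set to True, even though stations such as AL AIN INTL (United Arab Emirates) are paired with airports thousands of kilometres away. For latitude/longitude keys, a great-circle distance is a natural alternative to a distance on raw degrees. The sketch below swaps in the haversine formula, which is not what the Joiner uses, matches each station to its closest airport, and rejects matches beyond a hypothetical 50 km cutoff. The helper names, the cutoff and the toy coordinates are assumptions for illustration.

```python
import numpy as np
import pandas as pd

EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km; inputs in degrees, broadcastable arrays."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = (
        np.sin((lat2 - lat1) / 2) ** 2
        + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    )
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def nearest_airport(stations, airports, max_km=50.0):
    """Hypothetical helper: closest airport per station, flagged as rejected beyond max_km."""
    # Broadcast stations (n, 1) against airports (1, m) -> distances (n, m).
    d = haversine_km(
        stations["LATITUDE"].to_numpy()[:, None],
        stations["LONGITUDE"].to_numpy()[:, None],
        airports["lat"].to_numpy()[None, :],
        airports["long"].to_numpy()[None, :],
    )
    best = d.argmin(axis=1)
    out = stations.copy()
    out["iata"] = airports["iata"].to_numpy()[best]
    out["distance_km"] = d[np.arange(len(stations)), best]
    out["match_accepted"] = out["distance_km"] <= max_km
    return out


# Toy coordinates: one station near JFK, one at Al Ain with no nearby US airport.
stations = pd.DataFrame(
    {"ID": ["S1", "S2"], "LATITUDE": [40.66, 24.262], "LONGITUDE": [-73.80, 55.609]}
)
airports = pd.DataFrame(
    {"iata": ["JFK", "ACK"], "lat": [40.6398, 41.2531], "long": [-73.7789, -70.0602]}
)
result = nearest_airport(stations, airports)
print(result)
```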


Joining the airports with the flights data: we instantiate another multiple-key joiner, this time on the date and the airport code:

joiner = Joiner(
    aux_augmented,
    aux_key=["YEAR/MONTH/DAY", "iata"],
    main_key=["Year_Month_DayofMonth", "Origin"],
)

flights.drop(columns=["TailNum", "FlightNum"])
Year_Month_DayofMonth DayOfWeek CRSDepTime CRSArrTime UniqueCarrier CRSElapsedTime ArrDelay Origin Dest Distance
0 2008-01-13 7 1900-01-01 18:35:00 1900-01-01 20:08:00 CO 213.0 1.0 IAH ONT 1334.0
1 2008-02-21 4 1900-01-01 14:30:00 1900-01-01 16:06:00 NW 216.0 2.0 MSP SEA 1399.0
2 2008-04-17 4 1900-01-01 09:40:00 1900-01-01 13:15:00 WN 155.0 -13.0 SEA DEN 1024.0
3 2008-01-03 4 1900-01-01 08:40:00 1900-01-01 12:03:00 CO 383.0 46.0 EWR SNA 2433.0
4 2008-01-31 4 1900-01-01 12:50:00 1900-01-01 14:10:00 MQ 80.0 -14.0 SJC SNA 342.0
... ... ... ... ... ... ... ... ... ... ...
4995 2008-02-14 4 1900-01-01 09:00:00 1900-01-01 11:30:00 AS 150.0 -2.0 SEA LAS 866.0
4996 2008-03-25 2 1900-01-01 23:20:00 1900-01-01 06:21:00 US 241.0 -5.0 LAS CLT 1916.0
4997 2008-01-20 7 1900-01-01 06:00:00 1900-01-01 07:30:00 AQ 90.0 -13.0 LAS OAK 407.0
4998 2008-03-28 5 1900-01-01 09:20:00 1900-01-01 11:10:00 AA 170.0 93.0 DFW SLC 988.0
4999 2008-04-15 2 1900-01-01 23:59:00 1900-01-01 08:00:00 B6 301.0 0.0 SEA JFK 2421.0

5000 rows × 10 columns
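For comparison, when both keys are expected to match exactly (same day, same IATA code), a plain pandas merge on two pairs of key columns is enough; the Joiner additionally tolerates near matches. The toy rows below are hypothetical but reuse the column names from this example. With `how="left"`, every flight is kept, and a flight with no weather row for its airport and day gets NaN.

```python
import pandas as pd

# Hypothetical toy rows mimicking the column names used above.
flights = pd.DataFrame(
    {
        "Year_Month_DayofMonth": pd.to_datetime(["2008-01-13", "2008-01-13", "2008-02-21"]),
        "Origin": ["IAH", "MSP", "MSP"],
    }
)
weather_by_airport = pd.DataFrame(
    {
        "YEAR/MONTH/DAY": pd.to_datetime(["2008-01-13", "2008-02-21"]),
        "iata": ["IAH", "MSP"],
        "TMAX": [180.0, -50.0],
    }
)

# Exact join on two key pairs; both date columns share the same dtype.
merged = flights.merge(
    weather_by_airport,
    left_on=["Year_Month_DayofMonth", "Origin"],
    right_on=["YEAR/MONTH/DAY", "iata"],
    how="left",
)
print(merged[["Year_Month_DayofMonth", "Origin", "TMAX"]])
```

The unmatched MSP flight on 2008-01-13 is the kind of gap a fuzzy joiner fills by borrowing the closest available auxiliary row, which may or may not be a good proxy depending on how close that row really is.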



Training data is then passed through a Pipeline:

  • We will combine all the information from our pool of tables into “flights”, our main table.
  • We will use this main table to model the prediction of flight delay.

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.pipeline import make_pipeline

from skrub import TableVectorizer

tv = TableVectorizer()
hgb = HistGradientBoostingClassifier()

pipeline_hgb = make_pipeline(joiner, tv, hgb)

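We isolate our target variable. Note that the identifier columns TailNum and FlightNum are still in X: the flights.drop(...) call above displayed a copy without them but did not reassign flights.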

y = flights["ArrDelay"]
X = flights.drop(columns=["ArrDelay"])

We frame this as a classification problem: suppose that your company is obliged to reimburse the ticket price if the flight is delayed.

This gives a binary classification problem: the flight arrived late (ArrDelay > 0, label 1) or not (label 0).

y = (y > 0).astype(int)
y.value_counts()
ArrDelay
0    2646
1    2354
Name: count, dtype: int64
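The two classes are close to balanced, which matters when reading the accuracy below. A classifier that always predicts the majority class ("not delayed") would reach the following accuracy on data with this class balance; the counts are copied from the value_counts() output above, and the test split may be balanced slightly differently.

```python
# Class counts copied from the value_counts() output above.
counts = {0: 2646, 1: 2354}
total = sum(counts.values())

# Accuracy of always predicting the most frequent label.
majority_baseline = max(counts.values()) / total
print(f"majority-class baseline accuracy: {majority_baseline:.4f}")  # 0.5292
```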

We split the data into training and test sets, fit the pipeline, and report its accuracy on the test set:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)
pipeline_hgb.fit(X_train, y_train).score(X_test, y_test)
/home/circleci/project/.pixi/envs/doc/lib/python3.12/site-packages/sklearn/preprocessing/_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros
  warnings.warn(

0.5512

Conclusion

In this example, we have combined multiple tables with complex joins on imprecise and multiple-key correspondences. This is made easy by skrub’s Joiner() transformer.

Our final accuracy on the held-out test set is 0.55.
