filter() and filter_names() to select with user-defined criteria

filter() and filter_names() allow selecting columns based on arbitrary user-defined criteria. These are also used to implement many of the other selectors provided in this module.

filter() accepts a function that is called on each column (i.e., a pandas or Polars Series). This function, called a predicate, must return True if the column should be selected.

>>> import pandas as pd
>>> import skrub.selectors as s
>>> df = pd.DataFrame(
...     {
...         "height_mm": [297.0, 420.0],
...         "width_mm": [210.0, 297.0],
...         "kind": ["A4", "A3"],
...         "ID": [4, 3],
...     }
... )
>>> s.select(df, s.filter(lambda col: "A4" in col.tolist()))
  kind
0   A4
1   A3
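Conceptually, filter() keeps every column for which the predicate returns True. The plain-pandas sketch below illustrates that rule only. `select_columns_where` is a hypothetical helper written for illustration, not part of skrub, and skrub's own implementation (which also supports Polars) is not shown here.

```python
import pandas as pd


def select_columns_where(df, predicate):
    """Hypothetical helper: keep the columns for which predicate(column) is True."""
    # Each column is handed to the predicate as a pandas Series.
    kept = [name for name in df.columns if predicate(df[name])]
    return df[kept]


df = pd.DataFrame(
    {
        "height_mm": [297.0, 420.0],
        "width_mm": [210.0, 297.0],
        "kind": ["A4", "A3"],
        "ID": [4, 3],
    }
)

# Same predicate as in the skrub example: only "kind" contains the value "A4".
print(list(select_columns_where(df, lambda col: "A4" in col.tolist()).columns))
# ['kind']
```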

filter_names() accepts a predicate that receives the column name instead of the column itself.

>>> s.select(df, s.filter_names(lambda name: name.endswith('mm')))
   height_mm  width_mm
0      297.0     210.0
1      420.0     297.0
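For comparison, pandas offers a name-based filter of its own, DataFrame.filter, with a regex= parameter. The sketch below shows the same selection in plain pandas, once with a list comprehension over column names (mirroring what a name predicate does) and once with DataFrame.filter. It is pandas-only, whereas the skrub selector also works with Polars dataframes.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "height_mm": [297.0, 420.0],
        "width_mm": [210.0, 297.0],
        "kind": ["A4", "A3"],
        "ID": [4, 3],
    }
)

# A name predicate only sees the column name, never the data.
by_predicate = df[[name for name in df.columns if name.endswith("mm")]]

# pandas' built-in name filter; "mm$" anchors the match at the end of the name.
by_regex = df.filter(regex="mm$")

print(list(by_predicate.columns))  # ['height_mm', 'width_mm']
print(list(by_regex.columns))      # ['height_mm', 'width_mm']
```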

Any additional positional and keyword arguments are forwarded to the predicate. This makes it possible to avoid lambdas and local functions, so that the selector remains picklable.

>>> s.select(df, s.filter_names(str.endswith, 'mm'))
   height_mm  width_mm
0      297.0     210.0
1      420.0     297.0
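Forwarding arguments matters because Python's pickle module serializes functions by reference to an importable name. A built-in method such as str.endswith can be looked up again by name, whereas a lambda or a function defined inside another function cannot. The stdlib-only sketch below illustrates this; `is_picklable` is a hypothetical helper, and the (predicate, args) pair only loosely mimics what a selector has to store, not skrub's actual internals.

```python
import pickle


def is_picklable(obj):
    """Hypothetical helper: return True if pickle can serialize obj."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# A predicate plus the extra arguments passed after the column name.
predicate, args = str.endswith, ("mm",)
print(predicate("height_mm", *args))    # True
print(is_picklable((predicate, args)))  # True

# A lambda computes the same thing but cannot be pickled,
# because pickle cannot find it again by name.
print(is_picklable(lambda name: name.endswith("mm")))  # False
```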

Example of custom criteria in filter(): selecting columns with outliers

The filter() selector can be used to select columns based on custom criteria. For example, we can define a function that checks whether a column contains outliers using the interquartile range (IQR) method, and then pass this function to filter() to select such columns.

Specifically, the function computes the IQR of a column (the difference between its third and first quartiles) and checks whether any value lies more than 2 IQRs below the first quartile or more than 2 IQRs above the third quartile.
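In symbols, with Q1 and Q3 the first and third quartiles of the column, a value x is flagged as an outlier when it falls outside the fences below. The factor 2 is the choice made in this example; Tukey's classic rule uses 1.5.

```latex
\mathrm{IQR} = Q_3 - Q_1, \qquad
x < Q_1 - 2\,\mathrm{IQR} \quad \text{or} \quad x > Q_3 + 2\,\mathrm{IQR}
```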

>>> def has_outliers(column):
...     q1 = column.quantile(0.25)
...     q3 = column.quantile(0.75)
...     iqr = q3 - q1
...     lower_bound = q1 - 2 * iqr
...     upper_bound = q3 + 2 * iqr
...     outliers = (column < lower_bound) | (column > upper_bound)
...     return any(outliers)
...
>>> from skrub import SelectCols
>>> select = SelectCols(s.filter(has_outliers))
>>> data = pd.DataFrame({
...     "A": [10, 12, 14, 15, 100],  # Outlier in column A
...     "B": [20, 22, 21, 19, 20],   # No outliers in column B
...     "C": [30, 29, 31, 32, 300]   # Outlier in column C
... })
>>> select.fit_transform(data)
     A    C
0   10   30
1   12   29
2   14   31
3   15   32
4  100  300
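To see why column B is left out, the sketch below recomputes the quartile fences for each column with plain pandas (default linear interpolation in Series.quantile). It also generalizes the rule with a `factor` parameter, an assumption added for illustration since the function above hard-codes 2; `iqr_fences` is a hypothetical helper. Because filter() forwards extra arguments to the predicate, such a variant could presumably be used as `s.filter(has_outliers, factor=1.5)`, although that call is not part of the original example.

```python
import pandas as pd


def iqr_fences(column, factor=2.0):
    """Hypothetical helper: return the (lower, upper) fences factor * IQR beyond the quartiles."""
    q1 = column.quantile(0.25)
    q3 = column.quantile(0.75)
    iqr = q3 - q1
    return q1 - factor * iqr, q3 + factor * iqr


def has_outliers(column, factor=2.0):
    """Variant of the predicate above with a configurable factor (assumption)."""
    lower, upper = iqr_fences(column, factor)
    return bool(((column < lower) | (column > upper)).any())


data = pd.DataFrame({
    "A": [10, 12, 14, 15, 100],
    "B": [20, 22, 21, 19, 20],
    "C": [30, 29, 31, 32, 300],
})

# Print each column's fences and whether any value falls outside them.
for name in data.columns:
    lower, upper = iqr_fences(data[name])
    print(name, lower, upper, has_outliers(data[name]))
# A 6.0 21.0 True
# B 18.0 23.0 False
# C 26.0 36.0 True
```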