Crossvalidation and Regularization¶

Peter Ralph

https://uodsci.github.io/dsci345

In [1]:
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = (15, 8)
import numpy as np
import pandas as pd
from dsci345 import pretty

rng = np.random.default_rng(123)

$$\renewcommand{\P}{\mathbb{P}} \newcommand{\E}{\mathbb{E}} \newcommand{\var}{\text{var}} \newcommand{\sd}{\text{sd}} \newcommand{\cov}{\text{cov}} \newcommand{\cor}{\text{cor}}$$ This is here so we can use \P and \E and \var and \cov and \cor and \sd in LaTeX below.

A cautionary tale¶

In [2]:
import datetime
full_covid = pd.read_csv("data/United_States_COVID-19_Cases_and_Deaths_by_State_over_Time.csv")[["state", "tot_cases", "submission_date"]]
covid = full_covid.rename(columns={'submission_date': 'date', "tot_cases": "cases"})
covid.date = pd.to_datetime(covid.date)
covid = covid[covid.date < np.datetime64("2021-01-01")]
covid = covid[covid.date.dt.dayofweek == 0]
states = ['AK', 'AL', 'AR', 'AZ', 'CA', 'CO', 'CT', 'DC', 'DE', 'FL',
       'GA', 'GU', 'HI', 'IA', 'ID', 'IL', 'IN', 'KS', 'KY', 'LA',
       'MA', 'MD', 'ME', 'MI', 'MN', 'MO', 'MS', 'MT', 'NC', 'ND',
       'NE', 'NH', 'NJ', 'NM', 'NV', 'NY', 'OH', 'OK', 'OR', 'PA',
       'PR', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT', 'VA',
       'VT', 'WA', 'WI', 'WV', 'WY']
covid = covid[np.isin(covid.state, states)]
covid = covid.pivot(index="date", columns="state")
covid = covid.diff()[1:]
covid.columns = [x[1] for x in covid.columns]

Here are weekly COVID case counts across the US states plus DC, PR, and GU for 2020, a 48 x 53 matrix:

In [3]:
covid
Out[3]:
AK AL AR AZ CA CO CT DC DE FL ... SD TN TX UT VA VT WA WI WV WY
date
2020-02-03 0.0 0.0 0.0 0.0 6.0 0.0 0.0 0.0 1.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2020-02-10 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2020-02-17 0.0 0.0 0.0 0.0 2.0 0.0 0.0 0.0 2.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2020-02-24 0.0 0.0 0.0 0.0 17.0 0.0 0.0 0.0 2.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2020-03-02 0.0 0.0 0.0 0.0 14.0 1.0 0.0 0.0 0.0 2.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2020-03-09 0.0 0.0 0.0 5.0 94.0 11.0 1.0 2.0 0.0 16.0 ... 0.0 0.0 12.0 1.0 2.0 1.0 153.0 1.0 0.0 0.0
2020-03-16 1.0 36.0 17.0 12.0 259.0 135.0 40.0 20.0 7.0 138.0 ... 10.0 10.0 44.0 28.0 49.0 7.0 656.0 52.0 0.0 3.0
2020-03-23 37.0 188.0 180.0 235.0 1341.0 582.0 374.0 115.0 58.0 1065.0 ... 18.0 495.0 231.0 251.0 203.0 60.0 1212.0 363.0 20.0 23.0
2020-03-30 80.0 770.0 307.0 904.0 4030.0 1890.0 2156.0 358.0 245.0 4408.0 ... 73.0 1037.0 2587.0 526.0 766.0 168.0 2553.0 805.0 125.0 69.0
2020-04-06 73.0 1076.0 411.0 1299.0 8573.0 2572.0 4335.0 602.0 475.0 7585.0 ... 187.0 2117.0 4399.0 910.0 1858.0 274.0 3222.0 1219.0 200.0 117.0
2020-04-13 86.0 1789.0 495.0 1246.0 8012.0 2719.0 6475.0 858.0 921.0 6979.0 ... 580.0 1730.0 6633.0 676.0 2869.0 191.0 2303.0 988.0 281.0 161.0
2020-04-20 44.0 1288.0 554.0 1362.0 8630.0 2390.0 6434.0 972.0 1097.0 5966.0 ... 817.0 1721.0 5552.0 867.0 3243.0 55.0 1899.0 1071.0 276.0 55.0
2020-04-27 24.0 1457.0 1091.0 1652.0 12486.0 3779.0 6182.0 965.0 1448.0 5545.0 ... 560.0 2556.0 5839.0 1048.0 4545.0 31.0 1706.0 1582.0 175.0 92.0
2020-05-04 25.0 1581.0 414.0 2203.0 11473.0 3364.0 3976.0 1278.0 1310.0 4489.0 ... 423.0 3516.0 7035.0 1099.0 5957.0 42.0 1855.0 2155.0 147.0 76.0
2020-05-11 11.0 2053.0 574.0 2461.0 13002.0 2809.0 3792.0 1219.0 1498.0 3991.0 ... 946.0 1836.0 7537.0 989.0 5578.0 23.0 1681.0 2182.0 145.0 73.0
2020-05-18 20.0 2175.0 770.0 2790.0 12491.0 2394.0 4351.0 881.0 1104.0 5238.0 ... 413.0 2753.0 8824.0 1064.0 6070.0 12.0 1435.0 2269.0 133.0 97.0
2020-05-25 10.0 2870.0 1216.0 2391.0 14128.0 2207.0 2757.0 955.0 820.0 5146.0 ... 559.0 2451.0 7278.0 1125.0 6587.0 22.0 1455.0 2897.0 280.0 77.0
2020-06-01 59.0 3213.0 1414.0 3562.0 18448.0 2087.0 1867.0 632.0 638.0 5239.0 ... 448.0 3051.0 8909.0 1517.0 7671.0 21.0 1863.0 2959.0 246.0 67.0
2020-06-08 97.0 2739.0 2297.0 7555.0 18313.0 1578.0 1352.0 532.0 396.0 8321.0 ... 437.0 3228.0 10736.0 2328.0 5853.0 84.0 1997.0 2495.0 133.0 50.0
2020-06-15 101.0 5209.0 3177.0 9033.0 20133.0 1163.0 1143.0 410.0 388.0 13687.0 ... 457.0 3793.0 13492.0 2397.0 3635.0 50.0 2096.0 4442.0 161.0 119.0
2020-06-22 99.0 4266.0 3166.0 17888.0 26602.0 1524.0 547.0 259.0 397.0 23612.0 ... 398.0 4547.0 25773.0 3309.0 3579.0 36.0 2725.0 2358.0 249.0 151.0
2020-06-29 141.0 6797.0 4174.0 19954.0 38496.0 1788.0 580.0 234.0 666.0 48535.0 ... 390.0 5434.0 38130.0 4128.0 3724.0 31.0 3204.0 3195.0 299.0 220.0
2020-07-06 260.0 7295.0 3996.0 26915.0 55134.0 1982.0 614.0 223.0 851.0 59914.0 ... 389.0 10877.0 47546.0 3741.0 3913.0 41.0 4659.0 4197.0 572.0 225.0
2020-07-13 369.0 11142.0 4686.0 22390.0 57478.0 2988.0 534.0 391.0 891.0 76760.0 ... 419.0 11915.0 63756.0 4718.0 5540.0 45.0 4881.0 5152.0 871.0 228.0
2020-07-20 417.0 13038.0 4988.0 21366.0 62376.0 3350.0 545.0 433.0 692.0 76510.0 ... 419.0 14570.0 68121.0 4362.0 6733.0 54.0 5958.0 6293.0 829.0 284.0
2020-07-27 672.0 11976.0 5520.0 18651.0 69012.0 4095.0 928.0 519.0 767.0 70503.0 ... 501.0 16021.0 53489.0 3831.0 7697.0 40.0 5513.0 6606.0 912.0 333.0
2020-08-03 707.0 11206.0 5150.0 15677.0 54351.0 3680.0 1079.0 455.0 668.0 53915.0 ... 576.0 14060.0 56091.0 3238.0 7034.0 21.0 5670.0 6120.0 919.0 328.0
2020-08-10 455.0 9842.0 5431.0 8033.0 47010.0 2986.0 505.0 494.0 474.0 44277.0 ... 643.0 13014.0 48803.0 2853.0 7643.0 31.0 4872.0 5955.0 781.0 194.0
2020-08-17 527.0 6662.0 3049.0 6489.0 66120.0 2209.0 700.0 466.0 511.0 36154.0 ... 697.0 11112.0 52133.0 2481.0 6672.0 58.0 4193.0 5359.0 878.0 289.0
2020-08-24 517.0 6911.0 3817.0 4416.0 40584.0 2138.0 744.0 366.0 471.0 24946.0 ... 1065.0 10106.0 37434.0 2566.0 6209.0 36.0 3913.0 4904.0 680.0 272.0
2020-08-31 458.0 10366.0 4330.0 3428.0 35470.0 2091.0 868.0 353.0 651.0 24465.0 ... 2084.0 10075.0 32585.0 2739.0 6964.0 50.0 3344.0 4949.0 938.0 239.0
2020-09-07 520.0 5695.0 5056.0 4136.0 31150.0 2159.0 486.0 323.0 697.0 18770.0 ... 1791.0 10229.0 27401.0 2951.0 6977.0 31.0 3163.0 6360.0 1325.0 190.0
2020-09-14 541.0 6336.0 4347.0 2768.0 22543.0 2404.0 1530.0 307.0 692.0 18268.0 ... 1501.0 8427.0 23034.0 3540.0 7000.0 44.0 2717.0 8569.0 1245.0 360.0
2020-09-21 580.0 6609.0 5737.0 5533.0 23916.0 3760.0 1129.0 356.0 681.0 18630.0 ... 2068.0 10994.0 34983.0 5111.0 6567.0 24.0 2907.0 13091.0 1351.0 552.0
2020-09-28 733.0 6905.0 5685.0 3266.0 23569.0 4354.0 1123.0 286.0 732.0 16536.0 ... 2869.0 9406.0 40835.0 7048.0 5455.0 21.0 3947.0 15629.0 1341.0 810.0
2020-10-05 982.0 7010.0 5381.0 3567.0 21521.0 4319.0 1973.0 283.0 958.0 15795.0 ... 2860.0 8763.0 30081.0 7281.0 5964.0 64.0 3829.0 17438.0 1230.0 875.0
2020-10-12 1258.0 6911.0 6057.0 4987.0 23244.0 5803.0 2257.0 475.0 948.0 18209.0 ... 4327.0 14216.0 25823.0 8109.0 7013.0 54.0 4271.0 18676.0 1539.0 1173.0
2020-10-19 1322.0 7452.0 6110.0 5854.0 20763.0 8035.0 2644.0 373.0 909.0 21336.0 ... 4911.0 13443.0 33401.0 8730.0 7258.0 69.0 4557.0 22811.0 2012.0 1509.0
2020-10-26 2166.0 11592.0 7130.0 7074.0 30219.0 11308.0 4078.0 417.0 944.0 25682.0 ... 5905.0 18277.0 38548.0 10521.0 7447.0 139.0 5089.0 28786.0 1930.0 2166.0
2020-11-02 2716.0 9513.0 6914.0 9211.0 29618.0 15715.0 5759.0 626.0 1127.0 29906.0 ... 8109.0 13939.0 110710.0 11623.0 9143.0 134.0 6119.0 33000.0 3012.0 2690.0
2020-11-09 3239.0 10609.0 9170.0 11581.0 41223.0 25028.0 7605.0 649.0 1552.0 34832.0 ... 8461.0 21972.0 61179.0 17162.0 10059.0 216.0 9421.0 41452.0 3570.0 3843.0
2020-11-16 4071.0 14028.0 11537.0 17234.0 57384.0 34597.0 11821.0 977.0 2273.0 44342.0 ... 9967.0 28019.0 77266.0 20911.0 11160.0 620.0 13727.0 48182.0 5655.0 5183.0
2020-11-23 3932.0 15812.0 11842.0 25433.0 81135.0 37795.0 13456.0 1226.0 2999.0 54578.0 ... 7570.0 28054.0 85027.0 23641.0 16401.0 677.0 15029.0 45131.0 6654.0 6238.0
2020-11-30 4443.0 15860.0 11169.0 24514.0 102598.0 32762.0 10555.0 1262.0 3426.0 53444.0 ... 6616.0 27143.0 77175.0 16286.0 16797.0 460.0 18941.0 32037.0 6728.0 3874.0
2020-12-07 4903.0 23381.0 14683.0 39047.0 153467.0 34561.0 18549.0 1767.0 5055.0 64630.0 ... 6036.0 32250.0 106372.0 21932.0 21035.0 752.0 17031.0 31659.0 8286.0 3680.0
2020-12-14 3908.0 24904.0 15015.0 54426.0 218609.0 26619.0 18148.0 1719.0 5818.0 69233.0 ... 4854.0 55834.0 97337.0 18234.0 26279.0 713.0 23549.0 27896.0 8266.0 2790.0
2020-12-21 2383.0 27509.0 16050.0 41118.0 307304.0 18576.0 13385.0 1702.0 4735.0 77384.0 ... 3720.0 64380.0 109807.0 16911.0 25741.0 648.0 15917.0 23068.0 8943.0 2343.0
2020-12-28 1839.0 21699.0 13421.0 43099.0 263628.0 14576.0 13823.0 1602.0 4504.0 67592.0 ... 2583.0 40264.0 91478.0 13807.0 25285.0 544.0 12829.0 15784.0 8099.0 1586.0

48 rows × 53 columns

Question: Can we predict, say, Oregon's case counts using the other states?

In [4]:
from sklearn.linear_model import LinearRegression as lm
other_states = covid.loc[:,covid.columns != "OR"]
obs_OR = covid.loc[:,"OR"]
OR_fit = lm().fit(other_states, obs_OR)
est_OR = OR_fit.predict(other_states)
In [5]:
plt.plot(covid.index, obs_OR, label="observed")
plt.plot(covid.index, est_OR, label="estimated", linestyle="--")
plt.legend(); plt.xlabel("date"); plt.ylabel("OR case counts");
(Figure: observed and estimated OR case counts by date.)

Gee, we can predict the case counts perfectly? Does that seem very likely? What's going on?

Let's think about what we're trying to do here. We're trying to find coefficients $b_1, \ldots, b_k$ so that the linear combination of the columns of $X$, $$ \hat y = b_1 X_{\cdot 1} + \cdots + b_k X_{\cdot k} , $$ is as close to $y$ as possible.

Well, when is it possible to find $b$ so that $\hat y = y$? It is possible exactly when $y$ is in the column space of $X$.

Recall that $X$ is an $n \times k$ matrix. If $n \le k$ then the columns of $X$ span the entire space $\mathbb{R}^n$ (unless, for instance, some columns are identical), so any $y$ can be matched exactly.

Takeaway: if you have at least as many variables as observations, the problem is singular, and it is (almost) always possible to exactly predict the response.

However, these predictions are unlikely to be generalizable.
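To see this in action, here is a minimal sketch with made-up data (the names `demo_X` and `demo_y` are just for illustration): with more columns than rows, ordinary least squares reproduces the response exactly, even though the predictors carry no information about it.

# 10 observations but 20 purely random predictors: more columns than rows
demo_X = rng.normal(size=(10, 20))
demo_y = rng.normal(size=10)   # a response unrelated to demo_X
demo_fit = lm().fit(demo_X, demo_y)
# the in-sample "predictions" match the response essentially exactly (should print True)
print(np.allclose(demo_fit.predict(demo_X), demo_y))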

Crossvalidation¶

How to tell how good your model is?

See how well it predicts "new" data.

To do $k$-fold crossvalidation:

  1. Split your dataset into $k$ chunks (these should be independent!); then
  2. for each chunk in turn, set it aside for "testing", train your model on the remaining $k-1$ chunks, and measure how well that model predicts the held-out chunk.
  3. Compare the resulting "test error" to the "training error" (the error on the data used for fitting).

Predictions for data held out from fitting ("test error") should not be much further off than predictions for the data used to fit the model ("training error").

This can be used either as an indication of overfitting or to compare different models to each other.
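(An aside: scikit-learn has built-in helpers for this, e.g. cross_val_score; here is a minimal sketch with made-up data, assuming a reasonably recent scikit-learn that provides the "neg_root_mean_squared_error" scorer. Below we'll write our own $k$-fold routine instead, so we can see exactly what it's doing.)

from sklearn.model_selection import cross_val_score

# tiny made-up example: negative RMSE on each of 5 held-out folds
demo_X = rng.normal(size=(50, 2))
demo_y = demo_X @ np.array([1.0, -2.0]) + rng.normal(size=50)
scores = cross_val_score(lm(), demo_X, demo_y, cv=5,
                         scoring="neg_root_mean_squared_error")
print(-scores)  # RMSE for each test fold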

Crossvalidation set-up¶

Here's a pretty 'easy' prediction problem:

In [6]:
n = 100
x = np.array([
    rng.uniform(low=0, high=10, size=n),
    rng.normal(size=n),
]).T
y = 2.5 + 1.5 * x[:,0] - 4.2 * x[:,1] + rng.normal(size=n)

fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.scatter(x[:,0], y); ax0.set_xlabel("x0"); ax0.set_ylabel("y")
ax1.scatter(x[:,1], y); ax1.set_xlabel("x1"); ax1.set_ylabel("y");
(Figure: scatter plots of y against x0 (left) and x1 (right).)

Let's (a) fit a linear model and (b) do crossvalidation to look for evidence of overfitting.

In [7]:
def rmse(X, y, model):
    # root mean squared error, comparing y
    # to the values predicted by `model` from `X`
    yhat = model.predict(X)
    resids = y - yhat
    return np.sqrt(np.mean(resids ** 2))

def kfold(k, X, y, model):
    # X: matrix; y: vector with the same number of entries as X has rows
    # model is something with .fit() and .predict() methods
    n  = len(y)
    folds = np.repeat(np.arange(k), int(np.ceil(n / k)))[:n]
    rng.shuffle(folds)
    test_rmse = []
    train_rmse = []
    for ik in range(k):
        test_X = X[folds == ik]
        test_y = y[folds == ik]
        train_X = X[folds != ik]
        train_y = y[folds != ik]
        model.fit(train_X, train_y)
        test_rmse.append(rmse(test_X, test_y, model))
        train_rmse.append(rmse(train_X, train_y, model))
    return pd.DataFrame({
        "test" : test_rmse,
        "train" : train_rmse,
    })

crossval = kfold(5, x, y, lm())
crossval
Out[7]:
test train
0 1.002842 1.102181
1 1.280119 1.030386
2 1.342009 1.017593
3 1.042074 1.093432
4 0.939767 1.114522

Conclusion: test error was, on average, slightly higher than training error, but not by much; there is (unsurprisingly) no evidence of serious overfitting.

When you've got too many variables¶

We're going to add more variables. These will be independent of everything else, so they should not give us any real predictive power for $y$; however, by chance, each will be slightly correlated with $y$ in our sample.

In [8]:
crossval = pd.DataFrame()
for new_vars in np.linspace(0, 80, 9):
    new_x = rng.normal(size=(n, int(new_vars)))
    X = np.column_stack([x, new_x])
    xval = kfold(5, X, y, lm())
    xval["new_vars"] = int(new_vars)
    crossval = pd.concat([crossval, xval])
In [9]:
crossval.groupby("new_vars").agg(np.mean).plot();
(Figure: mean test and training RMSE against the number of added noise variables.)

Conclusion: as the number of variables increases, the training error decreases - even though we know there's no new information being added! The model is overfitting, which leads to increasing test (i.e., out-of-sample) error - the thing we generally care about.

Regularization¶

Recall that, somewhat mysteriously, scikit-learn's method to fit a Binomial GLM with a logistic link function has a "penalty" option. What's that?

Well, the method finds $b$ to maximize the likelihood under the following model: $$ Y_i \sim \text{Binomial}(N_i, p(X_i \cdot b)), $$ where $p(\cdot)$ is the logistic function. The terms in the log-likelihood that depend on $b$ are $$ \sum_{i=1}^n \left\{ Y_i \log(p(X_i \cdot b)) + (N_i - Y_i) \log(1 - p(X_i \cdot b)) \right\} . $$

The problem we had above was that the variables that didn't matter had small but nonzero estimated coefficients, and there were so many of them that together they added up to something big.

Solution: "encourage" them to be small.

So, with a "penalty" the method instead maximizes the log-likelihood minus a "regularization" term that does the "encouraging". Options: $$\begin{aligned} \sum_j |b_j| \qquad & \text{"L1" or "$\ell_1$" or "lasso"}\\ \sum_j b_j^2 \qquad & \text{"L2" or "$\ell_2$" or "ridge" or "Tikhonov"} \end{aligned}$$

sklearn.linear_model.LogisticRegression(penalty="l1")

therefore finds the $b$ that maximizes $$ \sum_{i=1}^n \left\{ Y_i \log(p(X_i \cdot b)) + (N_i - Y_i) \log(1 - p(X_i \cdot b)) \right\} - \sum_j |b_j| . $$
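That's the essential idea; in practice there is a bit more to the call. In current scikit-learn the L1 penalty needs a compatible solver (e.g. "liblinear" or "saga"), and the penalty strength is set through the inverse-strength parameter C. Here is a minimal sketch on made-up data (`demo_X` and `demo_y` are invented for illustration):

from sklearn.linear_model import LogisticRegression

# made-up binary responses: only the first predictor actually matters
demo_X = rng.normal(size=(200, 5))
demo_y = (demo_X[:, 0] + rng.normal(size=200) > 0).astype(int)

# L1-penalized logistic regression; smaller C means a stronger penalty
logit_l1 = LogisticRegression(penalty="l1", solver="liblinear", C=1.0)
logit_l1.fit(demo_X, demo_y)
print(logit_l1.coef_)  # the L1 penalty tends to push unhelpful coefficients to exactly zero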

The Lasso¶

Let's try out this scheme on our example data, using the Lasso, which fits a standard least-squares linear model but with an L1 penalty, minimizing $$ \sum_i (y_i - X_i \cdot b)^2 + \alpha \sum_j |b_j| . $$ (scikit-learn's Lasso actually divides the first term by $2n$, but that only rescales the meaning of $\alpha$.)

In [10]:
from sklearn.linear_model import Lasso

lasso_crossval = pd.DataFrame()
for new_vars in np.linspace(0, 80, 9):
    new_x = rng.normal(size=(len(y), int(new_vars)))
    X = np.column_stack([x, new_x])
    xval = kfold(5, X, y, Lasso(alpha=.3))
    xval["new_vars"] = int(new_vars)
    lasso_crossval = pd.concat([lasso_crossval, xval])
In [11]:
lasso_crossval.groupby("new_vars").mean().plot();
(Figure: mean test and training RMSE against the number of added noise variables, with the Lasso.)

Test error is still consistently higher than training error - as we'd expect - but only slightly.

Exercise¶

Let's use crossvalidation to choose the strength of regularization (i.e., the $\alpha$ parameter in the lasso).

We'll apply it to the covid data above.
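Here is one possible starting point: a sketch that reuses the kfold function above with a hypothetical grid of $\alpha$ values (the grid and the choice of 5 folds are arbitrary, not part of the original exercise).

alphas = [0.1, 1, 10, 100, 1000]   # hypothetical grid; adjust as needed
results = pd.DataFrame()
for a in alphas:
    xval = kfold(5, other_states.to_numpy(), obs_OR.to_numpy(), Lasso(alpha=a))
    xval["alpha"] = a
    results = pd.concat([results, xval])
# mean test and training RMSE for each alpha; pick the alpha with the smallest test RMSE
results.groupby("alpha").mean()

In practice the Lasso may complain about convergence on counts of this scale; increasing max_iter or standardizing the predictors first would be reasonable tweaks.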