import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = (15, 8)
import numpy as np
import pandas as pd
from dsci345 import pretty
rng = np.random.default_rng(123)
$$\renewcommand{\P}{\mathbb{P}} \newcommand{\E}{\mathbb{E}} \newcommand{\var}{\text{var}} \newcommand{\sd}{\text{sd}} \newcommand{\cov}{\text{cov}} \newcommand{\cor}{\text{cor}}$$
This is here so we can use $\P$ and $\E$ and $\var$ and $\cov$ and $\cor$ and $\sd$ in LaTeX below.
A cautionary tale¶
import datetime
# read in state-level cumulative case counts, keeping only the columns we need
full_covid = pd.read_csv("data/United_States_COVID-19_Cases_and_Deaths_by_State_over_Time.csv")[["state", "tot_cases", "submission_date"]]
covid = full_covid.rename(columns={'submission_date': 'date', "tot_cases": "cases"})
covid.date = pd.to_datetime(covid.date)
# keep only 2020, and only one report per week (Mondays)
covid = covid[covid.date < np.datetime64("2021-01-01")]
covid = covid[covid.date.dt.dayofweek == 0]
states = ['AK', 'AL', 'AR', 'AZ', 'CA', 'CO', 'CT', 'DC', 'DE', 'FL',
          'GA', 'GU', 'HI', 'IA', 'ID', 'IL', 'IN', 'KS', 'KY', 'LA',
          'MA', 'MD', 'ME', 'MI', 'MN', 'MO', 'MS', 'MT', 'NC', 'ND',
          'NE', 'NH', 'NJ', 'NM', 'NV', 'NY', 'OH', 'OK', 'OR', 'PA',
          'PR', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT', 'VA',
          'VT', 'WA', 'WI', 'WV', 'WY']
covid = covid[np.isin(covid.state, states)]
# one column per state, then difference the cumulative totals to get weekly new cases
covid = covid.pivot(index="date", columns="state")
covid = covid.diff()[1:]
covid.columns = [x[1] for x in covid.columns]
Here are weekly COVID case counts across the US states plus DC, PR, and GU for 2020, a 48 x 53 matrix:
covid
date | AK | AL | AR | AZ | CA | CO | CT | DC | DE | FL | ... | SD | TN | TX | UT | VA | VT | WA | WI | WV | WY
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
2020-02-03 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-02-10 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-02-17 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-02-24 | 0.0 | 0.0 | 0.0 | 0.0 | 17.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-03-02 | 0.0 | 0.0 | 0.0 | 0.0 | 14.0 | 1.0 | 0.0 | 0.0 | 0.0 | 2.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-03-09 | 0.0 | 0.0 | 0.0 | 5.0 | 94.0 | 11.0 | 1.0 | 2.0 | 0.0 | 16.0 | ... | 0.0 | 0.0 | 12.0 | 1.0 | 2.0 | 1.0 | 153.0 | 1.0 | 0.0 | 0.0 |
2020-03-16 | 1.0 | 36.0 | 17.0 | 12.0 | 259.0 | 135.0 | 40.0 | 20.0 | 7.0 | 138.0 | ... | 10.0 | 10.0 | 44.0 | 28.0 | 49.0 | 7.0 | 656.0 | 52.0 | 0.0 | 3.0 |
2020-03-23 | 37.0 | 188.0 | 180.0 | 235.0 | 1341.0 | 582.0 | 374.0 | 115.0 | 58.0 | 1065.0 | ... | 18.0 | 495.0 | 231.0 | 251.0 | 203.0 | 60.0 | 1212.0 | 363.0 | 20.0 | 23.0 |
2020-03-30 | 80.0 | 770.0 | 307.0 | 904.0 | 4030.0 | 1890.0 | 2156.0 | 358.0 | 245.0 | 4408.0 | ... | 73.0 | 1037.0 | 2587.0 | 526.0 | 766.0 | 168.0 | 2553.0 | 805.0 | 125.0 | 69.0 |
2020-04-06 | 73.0 | 1076.0 | 411.0 | 1299.0 | 8573.0 | 2572.0 | 4335.0 | 602.0 | 475.0 | 7585.0 | ... | 187.0 | 2117.0 | 4399.0 | 910.0 | 1858.0 | 274.0 | 3222.0 | 1219.0 | 200.0 | 117.0 |
2020-04-13 | 86.0 | 1789.0 | 495.0 | 1246.0 | 8012.0 | 2719.0 | 6475.0 | 858.0 | 921.0 | 6979.0 | ... | 580.0 | 1730.0 | 6633.0 | 676.0 | 2869.0 | 191.0 | 2303.0 | 988.0 | 281.0 | 161.0 |
2020-04-20 | 44.0 | 1288.0 | 554.0 | 1362.0 | 8630.0 | 2390.0 | 6434.0 | 972.0 | 1097.0 | 5966.0 | ... | 817.0 | 1721.0 | 5552.0 | 867.0 | 3243.0 | 55.0 | 1899.0 | 1071.0 | 276.0 | 55.0 |
2020-04-27 | 24.0 | 1457.0 | 1091.0 | 1652.0 | 12486.0 | 3779.0 | 6182.0 | 965.0 | 1448.0 | 5545.0 | ... | 560.0 | 2556.0 | 5839.0 | 1048.0 | 4545.0 | 31.0 | 1706.0 | 1582.0 | 175.0 | 92.0 |
2020-05-04 | 25.0 | 1581.0 | 414.0 | 2203.0 | 11473.0 | 3364.0 | 3976.0 | 1278.0 | 1310.0 | 4489.0 | ... | 423.0 | 3516.0 | 7035.0 | 1099.0 | 5957.0 | 42.0 | 1855.0 | 2155.0 | 147.0 | 76.0 |
2020-05-11 | 11.0 | 2053.0 | 574.0 | 2461.0 | 13002.0 | 2809.0 | 3792.0 | 1219.0 | 1498.0 | 3991.0 | ... | 946.0 | 1836.0 | 7537.0 | 989.0 | 5578.0 | 23.0 | 1681.0 | 2182.0 | 145.0 | 73.0 |
2020-05-18 | 20.0 | 2175.0 | 770.0 | 2790.0 | 12491.0 | 2394.0 | 4351.0 | 881.0 | 1104.0 | 5238.0 | ... | 413.0 | 2753.0 | 8824.0 | 1064.0 | 6070.0 | 12.0 | 1435.0 | 2269.0 | 133.0 | 97.0 |
2020-05-25 | 10.0 | 2870.0 | 1216.0 | 2391.0 | 14128.0 | 2207.0 | 2757.0 | 955.0 | 820.0 | 5146.0 | ... | 559.0 | 2451.0 | 7278.0 | 1125.0 | 6587.0 | 22.0 | 1455.0 | 2897.0 | 280.0 | 77.0 |
2020-06-01 | 59.0 | 3213.0 | 1414.0 | 3562.0 | 18448.0 | 2087.0 | 1867.0 | 632.0 | 638.0 | 5239.0 | ... | 448.0 | 3051.0 | 8909.0 | 1517.0 | 7671.0 | 21.0 | 1863.0 | 2959.0 | 246.0 | 67.0 |
2020-06-08 | 97.0 | 2739.0 | 2297.0 | 7555.0 | 18313.0 | 1578.0 | 1352.0 | 532.0 | 396.0 | 8321.0 | ... | 437.0 | 3228.0 | 10736.0 | 2328.0 | 5853.0 | 84.0 | 1997.0 | 2495.0 | 133.0 | 50.0 |
2020-06-15 | 101.0 | 5209.0 | 3177.0 | 9033.0 | 20133.0 | 1163.0 | 1143.0 | 410.0 | 388.0 | 13687.0 | ... | 457.0 | 3793.0 | 13492.0 | 2397.0 | 3635.0 | 50.0 | 2096.0 | 4442.0 | 161.0 | 119.0 |
2020-06-22 | 99.0 | 4266.0 | 3166.0 | 17888.0 | 26602.0 | 1524.0 | 547.0 | 259.0 | 397.0 | 23612.0 | ... | 398.0 | 4547.0 | 25773.0 | 3309.0 | 3579.0 | 36.0 | 2725.0 | 2358.0 | 249.0 | 151.0 |
2020-06-29 | 141.0 | 6797.0 | 4174.0 | 19954.0 | 38496.0 | 1788.0 | 580.0 | 234.0 | 666.0 | 48535.0 | ... | 390.0 | 5434.0 | 38130.0 | 4128.0 | 3724.0 | 31.0 | 3204.0 | 3195.0 | 299.0 | 220.0 |
2020-07-06 | 260.0 | 7295.0 | 3996.0 | 26915.0 | 55134.0 | 1982.0 | 614.0 | 223.0 | 851.0 | 59914.0 | ... | 389.0 | 10877.0 | 47546.0 | 3741.0 | 3913.0 | 41.0 | 4659.0 | 4197.0 | 572.0 | 225.0 |
2020-07-13 | 369.0 | 11142.0 | 4686.0 | 22390.0 | 57478.0 | 2988.0 | 534.0 | 391.0 | 891.0 | 76760.0 | ... | 419.0 | 11915.0 | 63756.0 | 4718.0 | 5540.0 | 45.0 | 4881.0 | 5152.0 | 871.0 | 228.0 |
2020-07-20 | 417.0 | 13038.0 | 4988.0 | 21366.0 | 62376.0 | 3350.0 | 545.0 | 433.0 | 692.0 | 76510.0 | ... | 419.0 | 14570.0 | 68121.0 | 4362.0 | 6733.0 | 54.0 | 5958.0 | 6293.0 | 829.0 | 284.0 |
2020-07-27 | 672.0 | 11976.0 | 5520.0 | 18651.0 | 69012.0 | 4095.0 | 928.0 | 519.0 | 767.0 | 70503.0 | ... | 501.0 | 16021.0 | 53489.0 | 3831.0 | 7697.0 | 40.0 | 5513.0 | 6606.0 | 912.0 | 333.0 |
2020-08-03 | 707.0 | 11206.0 | 5150.0 | 15677.0 | 54351.0 | 3680.0 | 1079.0 | 455.0 | 668.0 | 53915.0 | ... | 576.0 | 14060.0 | 56091.0 | 3238.0 | 7034.0 | 21.0 | 5670.0 | 6120.0 | 919.0 | 328.0 |
2020-08-10 | 455.0 | 9842.0 | 5431.0 | 8033.0 | 47010.0 | 2986.0 | 505.0 | 494.0 | 474.0 | 44277.0 | ... | 643.0 | 13014.0 | 48803.0 | 2853.0 | 7643.0 | 31.0 | 4872.0 | 5955.0 | 781.0 | 194.0 |
2020-08-17 | 527.0 | 6662.0 | 3049.0 | 6489.0 | 66120.0 | 2209.0 | 700.0 | 466.0 | 511.0 | 36154.0 | ... | 697.0 | 11112.0 | 52133.0 | 2481.0 | 6672.0 | 58.0 | 4193.0 | 5359.0 | 878.0 | 289.0 |
2020-08-24 | 517.0 | 6911.0 | 3817.0 | 4416.0 | 40584.0 | 2138.0 | 744.0 | 366.0 | 471.0 | 24946.0 | ... | 1065.0 | 10106.0 | 37434.0 | 2566.0 | 6209.0 | 36.0 | 3913.0 | 4904.0 | 680.0 | 272.0 |
2020-08-31 | 458.0 | 10366.0 | 4330.0 | 3428.0 | 35470.0 | 2091.0 | 868.0 | 353.0 | 651.0 | 24465.0 | ... | 2084.0 | 10075.0 | 32585.0 | 2739.0 | 6964.0 | 50.0 | 3344.0 | 4949.0 | 938.0 | 239.0 |
2020-09-07 | 520.0 | 5695.0 | 5056.0 | 4136.0 | 31150.0 | 2159.0 | 486.0 | 323.0 | 697.0 | 18770.0 | ... | 1791.0 | 10229.0 | 27401.0 | 2951.0 | 6977.0 | 31.0 | 3163.0 | 6360.0 | 1325.0 | 190.0 |
2020-09-14 | 541.0 | 6336.0 | 4347.0 | 2768.0 | 22543.0 | 2404.0 | 1530.0 | 307.0 | 692.0 | 18268.0 | ... | 1501.0 | 8427.0 | 23034.0 | 3540.0 | 7000.0 | 44.0 | 2717.0 | 8569.0 | 1245.0 | 360.0 |
2020-09-21 | 580.0 | 6609.0 | 5737.0 | 5533.0 | 23916.0 | 3760.0 | 1129.0 | 356.0 | 681.0 | 18630.0 | ... | 2068.0 | 10994.0 | 34983.0 | 5111.0 | 6567.0 | 24.0 | 2907.0 | 13091.0 | 1351.0 | 552.0 |
2020-09-28 | 733.0 | 6905.0 | 5685.0 | 3266.0 | 23569.0 | 4354.0 | 1123.0 | 286.0 | 732.0 | 16536.0 | ... | 2869.0 | 9406.0 | 40835.0 | 7048.0 | 5455.0 | 21.0 | 3947.0 | 15629.0 | 1341.0 | 810.0 |
2020-10-05 | 982.0 | 7010.0 | 5381.0 | 3567.0 | 21521.0 | 4319.0 | 1973.0 | 283.0 | 958.0 | 15795.0 | ... | 2860.0 | 8763.0 | 30081.0 | 7281.0 | 5964.0 | 64.0 | 3829.0 | 17438.0 | 1230.0 | 875.0 |
2020-10-12 | 1258.0 | 6911.0 | 6057.0 | 4987.0 | 23244.0 | 5803.0 | 2257.0 | 475.0 | 948.0 | 18209.0 | ... | 4327.0 | 14216.0 | 25823.0 | 8109.0 | 7013.0 | 54.0 | 4271.0 | 18676.0 | 1539.0 | 1173.0 |
2020-10-19 | 1322.0 | 7452.0 | 6110.0 | 5854.0 | 20763.0 | 8035.0 | 2644.0 | 373.0 | 909.0 | 21336.0 | ... | 4911.0 | 13443.0 | 33401.0 | 8730.0 | 7258.0 | 69.0 | 4557.0 | 22811.0 | 2012.0 | 1509.0 |
2020-10-26 | 2166.0 | 11592.0 | 7130.0 | 7074.0 | 30219.0 | 11308.0 | 4078.0 | 417.0 | 944.0 | 25682.0 | ... | 5905.0 | 18277.0 | 38548.0 | 10521.0 | 7447.0 | 139.0 | 5089.0 | 28786.0 | 1930.0 | 2166.0 |
2020-11-02 | 2716.0 | 9513.0 | 6914.0 | 9211.0 | 29618.0 | 15715.0 | 5759.0 | 626.0 | 1127.0 | 29906.0 | ... | 8109.0 | 13939.0 | 110710.0 | 11623.0 | 9143.0 | 134.0 | 6119.0 | 33000.0 | 3012.0 | 2690.0 |
2020-11-09 | 3239.0 | 10609.0 | 9170.0 | 11581.0 | 41223.0 | 25028.0 | 7605.0 | 649.0 | 1552.0 | 34832.0 | ... | 8461.0 | 21972.0 | 61179.0 | 17162.0 | 10059.0 | 216.0 | 9421.0 | 41452.0 | 3570.0 | 3843.0 |
2020-11-16 | 4071.0 | 14028.0 | 11537.0 | 17234.0 | 57384.0 | 34597.0 | 11821.0 | 977.0 | 2273.0 | 44342.0 | ... | 9967.0 | 28019.0 | 77266.0 | 20911.0 | 11160.0 | 620.0 | 13727.0 | 48182.0 | 5655.0 | 5183.0 |
2020-11-23 | 3932.0 | 15812.0 | 11842.0 | 25433.0 | 81135.0 | 37795.0 | 13456.0 | 1226.0 | 2999.0 | 54578.0 | ... | 7570.0 | 28054.0 | 85027.0 | 23641.0 | 16401.0 | 677.0 | 15029.0 | 45131.0 | 6654.0 | 6238.0 |
2020-11-30 | 4443.0 | 15860.0 | 11169.0 | 24514.0 | 102598.0 | 32762.0 | 10555.0 | 1262.0 | 3426.0 | 53444.0 | ... | 6616.0 | 27143.0 | 77175.0 | 16286.0 | 16797.0 | 460.0 | 18941.0 | 32037.0 | 6728.0 | 3874.0 |
2020-12-07 | 4903.0 | 23381.0 | 14683.0 | 39047.0 | 153467.0 | 34561.0 | 18549.0 | 1767.0 | 5055.0 | 64630.0 | ... | 6036.0 | 32250.0 | 106372.0 | 21932.0 | 21035.0 | 752.0 | 17031.0 | 31659.0 | 8286.0 | 3680.0 |
2020-12-14 | 3908.0 | 24904.0 | 15015.0 | 54426.0 | 218609.0 | 26619.0 | 18148.0 | 1719.0 | 5818.0 | 69233.0 | ... | 4854.0 | 55834.0 | 97337.0 | 18234.0 | 26279.0 | 713.0 | 23549.0 | 27896.0 | 8266.0 | 2790.0 |
2020-12-21 | 2383.0 | 27509.0 | 16050.0 | 41118.0 | 307304.0 | 18576.0 | 13385.0 | 1702.0 | 4735.0 | 77384.0 | ... | 3720.0 | 64380.0 | 109807.0 | 16911.0 | 25741.0 | 648.0 | 15917.0 | 23068.0 | 8943.0 | 2343.0 |
2020-12-28 | 1839.0 | 21699.0 | 13421.0 | 43099.0 | 263628.0 | 14576.0 | 13823.0 | 1602.0 | 4504.0 | 67592.0 | ... | 2583.0 | 40264.0 | 91478.0 | 13807.0 | 25285.0 | 544.0 | 12829.0 | 15784.0 | 8099.0 | 1586.0 |
48 rows × 53 columns
Question: Can we predict, say, Oregon's case counts using the other states?
from sklearn.linear_model import LinearRegression as lm
# predict Oregon's weekly counts from all the other states' counts
other_states = covid.loc[:, covid.columns != "OR"]
obs_OR = covid.loc[:, "OR"]
OR_fit = lm().fit(other_states, obs_OR)
est_OR = OR_fit.predict(other_states)
plt.plot(covid.index, obs_OR, label="observed")
plt.plot(covid.index, est_OR, label="estimated", linestyle="--")
plt.legend(); plt.xlabel("date"); plt.ylabel("OR case counts");
Gee, we can predict the case counts perfectly? Does that seem very likely? What's going on?
Let's think about what we're trying to do here. We're trying to find coefficients $b_1, \ldots, b_k$ so that the linear combination of the columns of $X$, $$ \hat y = b_1 X_{\cdot 1} + \cdots + b_k X_{\cdot k}, $$ is as close to $y$ as possible.
Well, when is it possible to find $b$ so that $\hat y = y$? It is possible if $y$ is in the column space of $X$.
Recall that $X$ is an $n \times k$ matrix. If $n \le k$ then the columns of $X$ typically span the entire space $\mathbb{R}^n$ (unless, for instance, fewer than $n$ of the columns are linearly independent).
Takeaway: if you have more variables than observations, the problem is singular (many different coefficient vectors fit equally well), and it is essentially always possible to exactly predict the response.
However, these predictions are unlikely to be generalizable.
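To see this in action, here's a quick demonstration (a sketch with made-up variable names, reusing `rng` and `lm` from above): even when the response is pure noise, unrelated to the predictors, a linear model with more columns than rows fits it exactly.
# a sketch: with more predictor columns than observations, even pure noise
# can be "predicted" perfectly in-sample
n_obs, n_vars = 20, 30
X_noise = rng.normal(size=(n_obs, n_vars))   # predictors: pure noise
y_noise = rng.normal(size=n_obs)             # response: unrelated noise
noise_fit = lm().fit(X_noise, y_noise)
print(np.max(np.abs(y_noise - noise_fit.predict(X_noise))))   # essentially zero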
Crossvalidation¶
How to tell how good your model is?
See how well it predicts "new" data.
To do $k$-fold crossvalidation:
- Split your dataset into $k$ chunks (these should be independent!), and
- for each chunk in turn, put it aside for "testing", train your model on the remaining $k-1$ chunks, and measure how well that model predicts the held-out chunk.
- Compare "test error" to "training error".
Predictions for data held out from fitting ("test error") should not be much further off than predictions for the data used to fit the model ("training error").
This can be used either as an indication of overfitting or to compare different models to each other.
Crossvalidation set-up¶
Here's a pretty 'easy' prediction problem:
n = 100
x = np.array([
rng.uniform(low=0, high=10, size=n),
rng.normal(size=n),
]).T
y = 2.5 + 1.5 * x[:,0] - 4.2 * x[:,1] + rng.normal(size=n)
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.scatter(x[:,0], y); ax0.set_xlabel("x0"); ax0.set_ylabel("y")
ax1.scatter(x[:,1], y); ax1.set_xlabel("x1"); ax1.set_ylabel("y");
Let's (a) fit a linear model and (b) do crossvalidation to look for evidence of overfitting.
def rmse(X, y, model):
    # root mean squared error, comparing y
    # to the values predicted by `model` from `X`
    yhat = model.predict(X)
    resids = y - yhat
    return np.sqrt(np.mean(resids ** 2))
def kfold(k, X, y, model):
    # X: matrix; y: vector with the same number of entries as rows of X
    # model is something with .fit() and .predict() methods
    n = len(y)
    # randomly assign each observation to one of k folds
    folds = np.repeat(np.arange(k), int(np.ceil(n / k)))[:n]
    rng.shuffle(folds)
    test_rmse = []
    train_rmse = []
    for ik in range(k):
        test_X = X[folds == ik]
        test_y = y[folds == ik]
        train_X = X[folds != ik]
        train_y = y[folds != ik]
        model.fit(train_X, train_y)
        test_rmse.append(rmse(test_X, test_y, model))
        train_rmse.append(rmse(train_X, train_y, model))
    return pd.DataFrame({
        "test" : test_rmse,
        "train" : train_rmse,
    })
crossval = kfold(5, x, y, lm())
crossval
  | test | train
---|---|---
0 | 1.002842 | 1.102181
1 | 1.280119 | 1.030386
2 | 1.342009 | 1.017593
3 | 1.042074 | 1.093432
4 | 0.939767 | 1.114522
Conclusion: test error was on average a bit higher than training error, but not by much; there is (unsurprisingly) no evidence of (serious) overfitting.
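(As an aside, scikit-learn has built-in crossvalidation utilities that produce essentially the same comparison; here's a rough equivalent of the table above, just for reference - we'll keep using our own `kfold` below.)
# optional: roughly the same crossvalidation using scikit-learn's built-in tools
from sklearn.model_selection import cross_validate
cv = cross_validate(lm(), x, y, cv=5,
                    scoring="neg_root_mean_squared_error",
                    return_train_score=True)
pd.DataFrame({"test": -cv["test_score"], "train": -cv["train_score"]})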
When you've got too many variables¶
We're going to add more variables - these will be independent of everything else, so they should not give us meaningful predictive power for $y$. However, by chance each is a little correlated with $y$.
crossval = pd.DataFrame()
for new_vars in np.linspace(0, 80, 9):
    new_x = rng.normal(size=(n, int(new_vars)))
    X = np.column_stack([x, new_x])
    xval = kfold(5, X, y, lm())
    xval["new_vars"] = int(new_vars)
    crossval = pd.concat([crossval, xval])
crossval.groupby("new_vars").mean().plot();
Conclusion: as the number of variables increases, the training error decreases - even though we know there's no new information being added! The model is overfitting, which leads to increasing test (i.e., out-of-sample) error - the thing we generally care about.
Regularization¶
Recall that, somewhat mysteriously, scikit-learn's method to fit a Binomial GLM with a logistic link function has a "penalty" option. What's that?
Well, the method finds $b$ to maximize the likelihood under the following model: $$ Y_i \sim \text{Binomial}(N_i, p(X_i \cdot b)), $$ where $p(u) = 1/(1 + e^{-u})$ is the logistic function. The terms in the log-likelihood that depend on $b$ are $$ \sum_{i=1}^n \left\{ Y_i \log(p(X_i \cdot b)) + (N_i - Y_i) \log(1 - p(X_i \cdot b)) \right\} . $$
The problem we had above was that the variables that didn't matter had small but nonzero estimated coefficients, and there were so many of them that together they added up to something big.
Solution: "encourage" them to be small.
So, with a "penalty" the method instead maximizes the log-likelihood minus a "regularization" term that does the "encouraging". Options: $$\begin{aligned} \sum_j |b_j| \qquad & \text{"L1" or "$\ell_1$" or "lasso"}\\ \sum_j b_j^2 \qquad & \text{"L2" or "$\ell_2$" or "ridge" or "Tikhonov"} \end{aligned}$$
sklearn.linear_model.LogisticRegression(penalty="l1")
therefore finds the $b$ that maximizes $$ \sum_{i=1}^n \left\{ Y_i \log(p(X_i \cdot b)) + (N_i - Y_i) \log(1 - p(X_i \cdot b)) \right\} - \sum_j |b_j| . $$
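(A practical note, as an aside: in scikit-learn the L1 penalty is only available with certain solvers, and its strength is controlled by the C parameter, which acts roughly as the inverse of the penalty weight. Here is a minimal sketch with made-up data, just to show the call:)
# a minimal sketch with made-up data: an L1-penalized logistic regression needs
# a solver that supports it (e.g., "liblinear" or "saga"); smaller C = stronger penalty
from sklearn.linear_model import LogisticRegression
X_demo = rng.normal(size=(200, 5))
y_demo = rng.binomial(n=1, p=1 / (1 + np.exp(-X_demo[:, 0])))  # depends only on column 0
logit_l1 = LogisticRegression(penalty="l1", solver="liblinear", C=0.5).fit(X_demo, y_demo)
logit_l1.coef_   # coefficients of the irrelevant columns tend to be exactly zero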
The Lasso¶
Let's try out this scheme on our example data, using the Lasso, which fits a standard least-squares linear model but with an L1 penalty, minimizing $$ \sum_i (y_i - X_i \cdot b)^2 + \alpha \sum_j |b_j| . $$ (Scikit-learn's Lasso actually scales the sum of squared errors by $1/(2n)$, but the idea is the same.)
from sklearn.linear_model import Lasso
lasso_crossval = pd.DataFrame()
for new_vars in np.linspace(0, 80, 9):
    new_x = rng.normal(size=(len(y), int(new_vars)))
    X = np.column_stack([x, new_x])
    xval = kfold(5, X, y, Lasso(alpha=.3))
    xval["new_vars"] = int(new_vars)
    lasso_crossval = pd.concat([lasso_crossval, xval])
lasso_crossval.groupby("new_vars").mean().plot();
Test error is still consistently higher than training error - as we'd expect - but only slightly.
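Why does this help? The L1 penalty pushes the coefficients of the uninformative columns all the way to zero, so they can't add up to anything big. Here's a quick check (a sketch that refits the Lasso once to the last X from the loop above, so exact numbers will vary):
# a sketch: fit the Lasso to the full data including all the extra noise columns,
# and count how many coefficients are exactly zero
lasso_fit = Lasso(alpha=.3).fit(X, y)
print("nonzero coefficients:", np.sum(lasso_fit.coef_ != 0), "out of", X.shape[1])
print("coefficients of the two real predictors:", lasso_fit.coef_[:2])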
Exercise¶
Let's use crossvalidation to choose the strength of regularization (i.e., the $\alpha$ parameter in the lasso).
We'll apply it to the `covid` data above.
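Here's one way you might set this up (a sketch, not a full solution, reusing other_states and obs_OR from above): for each candidate $\alpha$, run kfold on the covid data predicting Oregon from the other states, and keep the $\alpha$ with the smallest average test error.
# a sketch of one possible approach: crossvalidate the Lasso at several penalty
# strengths, predicting OR's weekly counts from the other states' counts
alphas = [1, 10, 100, 1000, 10000]   # candidate penalty strengths (a guess at a useful range)
results = []
for a in alphas:
    xval = kfold(5, other_states, obs_OR, Lasso(alpha=a))
    results.append({"alpha": a, "test": xval["test"].mean(), "train": xval["train"].mean()})
pd.DataFrame(results)   # choose the alpha with the smallest average test error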