Here are weekly COVID case counts across the US states plus DC, PR, and GU for 2020, a 48 x 53 matrix:
covid
date | AK | AL | AR | AZ | CA | CO | CT | DC | DE | FL | ... | SD | TN | TX | UT | VA | VT | WA | WI | WV | WY |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2020-02-03 | 0.0 | 0.0 | 0.0 | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-02-10 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-02-17 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-02-24 | 0.0 | 0.0 | 0.0 | 0.0 | 17.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-03-02 | 0.0 | 0.0 | 0.0 | 0.0 | 14.0 | 1.0 | 0.0 | 0.0 | 0.0 | 2.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2020-03-09 | 0.0 | 0.0 | 0.0 | 5.0 | 94.0 | 11.0 | 1.0 | 2.0 | 0.0 | 16.0 | ... | 0.0 | 0.0 | 12.0 | 1.0 | 2.0 | 1.0 | 153.0 | 1.0 | 0.0 | 0.0 |
2020-03-16 | 1.0 | 36.0 | 17.0 | 12.0 | 259.0 | 135.0 | 40.0 | 20.0 | 7.0 | 138.0 | ... | 10.0 | 10.0 | 44.0 | 28.0 | 49.0 | 7.0 | 656.0 | 52.0 | 0.0 | 3.0 |
2020-03-23 | 37.0 | 188.0 | 180.0 | 235.0 | 1341.0 | 582.0 | 374.0 | 115.0 | 58.0 | 1065.0 | ... | 18.0 | 495.0 | 231.0 | 251.0 | 203.0 | 60.0 | 1212.0 | 363.0 | 20.0 | 23.0 |
2020-03-30 | 80.0 | 770.0 | 307.0 | 904.0 | 4030.0 | 1890.0 | 2156.0 | 358.0 | 245.0 | 4408.0 | ... | 73.0 | 1037.0 | 2587.0 | 526.0 | 766.0 | 168.0 | 2553.0 | 805.0 | 125.0 | 69.0 |
2020-04-06 | 73.0 | 1076.0 | 411.0 | 1299.0 | 8573.0 | 2572.0 | 4335.0 | 602.0 | 475.0 | 7585.0 | ... | 187.0 | 2117.0 | 4399.0 | 910.0 | 1858.0 | 274.0 | 3222.0 | 1219.0 | 200.0 | 117.0 |
2020-04-13 | 86.0 | 1789.0 | 495.0 | 1246.0 | 8012.0 | 2719.0 | 6475.0 | 858.0 | 921.0 | 6979.0 | ... | 580.0 | 1730.0 | 6633.0 | 676.0 | 2869.0 | 191.0 | 2303.0 | 988.0 | 281.0 | 161.0 |
2020-04-20 | 44.0 | 1288.0 | 554.0 | 1362.0 | 8630.0 | 2390.0 | 6434.0 | 972.0 | 1097.0 | 5966.0 | ... | 817.0 | 1721.0 | 5552.0 | 867.0 | 3243.0 | 55.0 | 1899.0 | 1071.0 | 276.0 | 55.0 |
2020-04-27 | 24.0 | 1457.0 | 1091.0 | 1652.0 | 12486.0 | 3779.0 | 6182.0 | 965.0 | 1448.0 | 5545.0 | ... | 560.0 | 2556.0 | 5839.0 | 1048.0 | 4545.0 | 31.0 | 1706.0 | 1582.0 | 175.0 | 92.0 |
2020-05-04 | 25.0 | 1581.0 | 414.0 | 2203.0 | 11473.0 | 3364.0 | 3976.0 | 1278.0 | 1310.0 | 4489.0 | ... | 423.0 | 3516.0 | 7035.0 | 1099.0 | 5957.0 | 42.0 | 1855.0 | 2155.0 | 147.0 | 76.0 |
2020-05-11 | 11.0 | 2053.0 | 574.0 | 2461.0 | 13002.0 | 2809.0 | 3792.0 | 1219.0 | 1498.0 | 3991.0 | ... | 946.0 | 1836.0 | 7537.0 | 989.0 | 5578.0 | 23.0 | 1681.0 | 2182.0 | 145.0 | 73.0 |
2020-05-18 | 20.0 | 2175.0 | 770.0 | 2790.0 | 12491.0 | 2394.0 | 4351.0 | 881.0 | 1104.0 | 5238.0 | ... | 413.0 | 2753.0 | 8824.0 | 1064.0 | 6070.0 | 12.0 | 1435.0 | 2269.0 | 133.0 | 97.0 |
2020-05-25 | 10.0 | 2870.0 | 1216.0 | 2391.0 | 14128.0 | 2207.0 | 2757.0 | 955.0 | 820.0 | 5146.0 | ... | 559.0 | 2451.0 | 7278.0 | 1125.0 | 6587.0 | 22.0 | 1455.0 | 2897.0 | 280.0 | 77.0 |
2020-06-01 | 59.0 | 3213.0 | 1414.0 | 3562.0 | 18448.0 | 2087.0 | 1867.0 | 632.0 | 638.0 | 5239.0 | ... | 448.0 | 3051.0 | 8909.0 | 1517.0 | 7671.0 | 21.0 | 1863.0 | 2959.0 | 246.0 | 67.0 |
2020-06-08 | 97.0 | 2739.0 | 2297.0 | 7555.0 | 18313.0 | 1578.0 | 1352.0 | 532.0 | 396.0 | 8321.0 | ... | 437.0 | 3228.0 | 10736.0 | 2328.0 | 5853.0 | 84.0 | 1997.0 | 2495.0 | 133.0 | 50.0 |
2020-06-15 | 101.0 | 5209.0 | 3177.0 | 9033.0 | 20133.0 | 1163.0 | 1143.0 | 410.0 | 388.0 | 13687.0 | ... | 457.0 | 3793.0 | 13492.0 | 2397.0 | 3635.0 | 50.0 | 2096.0 | 4442.0 | 161.0 | 119.0 |
2020-06-22 | 99.0 | 4266.0 | 3166.0 | 17888.0 | 26602.0 | 1524.0 | 547.0 | 259.0 | 397.0 | 23612.0 | ... | 398.0 | 4547.0 | 25773.0 | 3309.0 | 3579.0 | 36.0 | 2725.0 | 2358.0 | 249.0 | 151.0 |
2020-06-29 | 141.0 | 6797.0 | 4174.0 | 19954.0 | 38496.0 | 1788.0 | 580.0 | 234.0 | 666.0 | 48535.0 | ... | 390.0 | 5434.0 | 38130.0 | 4128.0 | 3724.0 | 31.0 | 3204.0 | 3195.0 | 299.0 | 220.0 |
2020-07-06 | 260.0 | 7295.0 | 3996.0 | 26915.0 | 55134.0 | 1982.0 | 614.0 | 223.0 | 851.0 | 59914.0 | ... | 389.0 | 10877.0 | 47546.0 | 3741.0 | 3913.0 | 41.0 | 4659.0 | 4197.0 | 572.0 | 225.0 |
2020-07-13 | 369.0 | 11142.0 | 4686.0 | 22390.0 | 57478.0 | 2988.0 | 534.0 | 391.0 | 891.0 | 76760.0 | ... | 419.0 | 11915.0 | 63756.0 | 4718.0 | 5540.0 | 45.0 | 4881.0 | 5152.0 | 871.0 | 228.0 |
2020-07-20 | 417.0 | 13038.0 | 4988.0 | 21366.0 | 62376.0 | 3350.0 | 545.0 | 433.0 | 692.0 | 76510.0 | ... | 419.0 | 14570.0 | 68121.0 | 4362.0 | 6733.0 | 54.0 | 5958.0 | 6293.0 | 829.0 | 284.0 |
2020-07-27 | 672.0 | 11976.0 | 5520.0 | 18651.0 | 69012.0 | 4095.0 | 928.0 | 519.0 | 767.0 | 70503.0 | ... | 501.0 | 16021.0 | 53489.0 | 3831.0 | 7697.0 | 40.0 | 5513.0 | 6606.0 | 912.0 | 333.0 |
2020-08-03 | 707.0 | 11206.0 | 5150.0 | 15677.0 | 54351.0 | 3680.0 | 1079.0 | 455.0 | 668.0 | 53915.0 | ... | 576.0 | 14060.0 | 56091.0 | 3238.0 | 7034.0 | 21.0 | 5670.0 | 6120.0 | 919.0 | 328.0 |
2020-08-10 | 455.0 | 9842.0 | 5431.0 | 8033.0 | 47010.0 | 2986.0 | 505.0 | 494.0 | 474.0 | 44277.0 | ... | 643.0 | 13014.0 | 48803.0 | 2853.0 | 7643.0 | 31.0 | 4872.0 | 5955.0 | 781.0 | 194.0 |
2020-08-17 | 527.0 | 6662.0 | 3049.0 | 6489.0 | 66120.0 | 2209.0 | 700.0 | 466.0 | 511.0 | 36154.0 | ... | 697.0 | 11112.0 | 52133.0 | 2481.0 | 6672.0 | 58.0 | 4193.0 | 5359.0 | 878.0 | 289.0 |
2020-08-24 | 517.0 | 6911.0 | 3817.0 | 4416.0 | 40584.0 | 2138.0 | 744.0 | 366.0 | 471.0 | 24946.0 | ... | 1065.0 | 10106.0 | 37434.0 | 2566.0 | 6209.0 | 36.0 | 3913.0 | 4904.0 | 680.0 | 272.0 |
2020-08-31 | 458.0 | 10366.0 | 4330.0 | 3428.0 | 35470.0 | 2091.0 | 868.0 | 353.0 | 651.0 | 24465.0 | ... | 2084.0 | 10075.0 | 32585.0 | 2739.0 | 6964.0 | 50.0 | 3344.0 | 4949.0 | 938.0 | 239.0 |
2020-09-07 | 520.0 | 5695.0 | 5056.0 | 4136.0 | 31150.0 | 2159.0 | 486.0 | 323.0 | 697.0 | 18770.0 | ... | 1791.0 | 10229.0 | 27401.0 | 2951.0 | 6977.0 | 31.0 | 3163.0 | 6360.0 | 1325.0 | 190.0 |
2020-09-14 | 541.0 | 6336.0 | 4347.0 | 2768.0 | 22543.0 | 2404.0 | 1530.0 | 307.0 | 692.0 | 18268.0 | ... | 1501.0 | 8427.0 | 23034.0 | 3540.0 | 7000.0 | 44.0 | 2717.0 | 8569.0 | 1245.0 | 360.0 |
2020-09-21 | 580.0 | 6609.0 | 5737.0 | 5533.0 | 23916.0 | 3760.0 | 1129.0 | 356.0 | 681.0 | 18630.0 | ... | 2068.0 | 10994.0 | 34983.0 | 5111.0 | 6567.0 | 24.0 | 2907.0 | 13091.0 | 1351.0 | 552.0 |
2020-09-28 | 733.0 | 6905.0 | 5685.0 | 3266.0 | 23569.0 | 4354.0 | 1123.0 | 286.0 | 732.0 | 16536.0 | ... | 2869.0 | 9406.0 | 40835.0 | 7048.0 | 5455.0 | 21.0 | 3947.0 | 15629.0 | 1341.0 | 810.0 |
2020-10-05 | 982.0 | 7010.0 | 5381.0 | 3567.0 | 21521.0 | 4319.0 | 1973.0 | 283.0 | 958.0 | 15795.0 | ... | 2860.0 | 8763.0 | 30081.0 | 7281.0 | 5964.0 | 64.0 | 3829.0 | 17438.0 | 1230.0 | 875.0 |
2020-10-12 | 1258.0 | 6911.0 | 6057.0 | 4987.0 | 23244.0 | 5803.0 | 2257.0 | 475.0 | 948.0 | 18209.0 | ... | 4327.0 | 14216.0 | 25823.0 | 8109.0 | 7013.0 | 54.0 | 4271.0 | 18676.0 | 1539.0 | 1173.0 |
2020-10-19 | 1322.0 | 7452.0 | 6110.0 | 5854.0 | 20763.0 | 8035.0 | 2644.0 | 373.0 | 909.0 | 21336.0 | ... | 4911.0 | 13443.0 | 33401.0 | 8730.0 | 7258.0 | 69.0 | 4557.0 | 22811.0 | 2012.0 | 1509.0 |
2020-10-26 | 2166.0 | 11592.0 | 7130.0 | 7074.0 | 30219.0 | 11308.0 | 4078.0 | 417.0 | 944.0 | 25682.0 | ... | 5905.0 | 18277.0 | 38548.0 | 10521.0 | 7447.0 | 139.0 | 5089.0 | 28786.0 | 1930.0 | 2166.0 |
2020-11-02 | 2716.0 | 9513.0 | 6914.0 | 9211.0 | 29618.0 | 15715.0 | 5759.0 | 626.0 | 1127.0 | 29906.0 | ... | 8109.0 | 13939.0 | 110710.0 | 11623.0 | 9143.0 | 134.0 | 6119.0 | 33000.0 | 3012.0 | 2690.0 |
2020-11-09 | 3239.0 | 10609.0 | 9170.0 | 11581.0 | 41223.0 | 25028.0 | 7605.0 | 649.0 | 1552.0 | 34832.0 | ... | 8461.0 | 21972.0 | 61179.0 | 17162.0 | 10059.0 | 216.0 | 9421.0 | 41452.0 | 3570.0 | 3843.0 |
2020-11-16 | 4071.0 | 14028.0 | 11537.0 | 17234.0 | 57384.0 | 34597.0 | 11821.0 | 977.0 | 2273.0 | 44342.0 | ... | 9967.0 | 28019.0 | 77266.0 | 20911.0 | 11160.0 | 620.0 | 13727.0 | 48182.0 | 5655.0 | 5183.0 |
2020-11-23 | 3932.0 | 15812.0 | 11842.0 | 25433.0 | 81135.0 | 37795.0 | 13456.0 | 1226.0 | 2999.0 | 54578.0 | ... | 7570.0 | 28054.0 | 85027.0 | 23641.0 | 16401.0 | 677.0 | 15029.0 | 45131.0 | 6654.0 | 6238.0 |
2020-11-30 | 4443.0 | 15860.0 | 11169.0 | 24514.0 | 102598.0 | 32762.0 | 10555.0 | 1262.0 | 3426.0 | 53444.0 | ... | 6616.0 | 27143.0 | 77175.0 | 16286.0 | 16797.0 | 460.0 | 18941.0 | 32037.0 | 6728.0 | 3874.0 |
2020-12-07 | 4903.0 | 23381.0 | 14683.0 | 39047.0 | 153467.0 | 34561.0 | 18549.0 | 1767.0 | 5055.0 | 64630.0 | ... | 6036.0 | 32250.0 | 106372.0 | 21932.0 | 21035.0 | 752.0 | 17031.0 | 31659.0 | 8286.0 | 3680.0 |
2020-12-14 | 3908.0 | 24904.0 | 15015.0 | 54426.0 | 218609.0 | 26619.0 | 18148.0 | 1719.0 | 5818.0 | 69233.0 | ... | 4854.0 | 55834.0 | 97337.0 | 18234.0 | 26279.0 | 713.0 | 23549.0 | 27896.0 | 8266.0 | 2790.0 |
2020-12-21 | 2383.0 | 27509.0 | 16050.0 | 41118.0 | 307304.0 | 18576.0 | 13385.0 | 1702.0 | 4735.0 | 77384.0 | ... | 3720.0 | 64380.0 | 109807.0 | 16911.0 | 25741.0 | 648.0 | 15917.0 | 23068.0 | 8943.0 | 2343.0 |
2020-12-28 | 1839.0 | 21699.0 | 13421.0 | 43099.0 | 263628.0 | 14576.0 | 13823.0 | 1602.0 | 4504.0 | 67592.0 | ... | 2583.0 | 40264.0 | 91478.0 | 13807.0 | 25285.0 | 544.0 | 12829.0 | 15784.0 | 8099.0 | 1586.0 |
48 rows × 53 columns
Question: Can we predict, say, Oregon's case counts using the other states?
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression as lm

# regress Oregon's weekly counts on all the other columns
other_states = covid.loc[:, covid.columns != "OR"]
obs_OR = covid.loc[:, "OR"]
OR_fit = lm().fit(other_states, obs_OR)
est_OR = OR_fit.predict(other_states)

plt.plot(covid.index, obs_OR, label="observed")
plt.plot(covid.index, est_OR, label="estimated", linestyle="--")
plt.legend(); plt.xlabel("date"); plt.ylabel("OR case counts");
Gee, we can predict the case counts perfectly? Does that seem very likely? What's going on?
Let's think about what we're trying to do here. We're trying to find coefficients $b_1, \ldots, b_k$ so that the linear combination of the columns of $X$, $$ \hat y = b_1 X_{\cdot 1} + \cdots + b_k X_{\cdot k}, $$ is as close to $y$ as possible.
Well, when is it possible to find $b$ so that $\hat y = y$? It is possible if $y$ is in the column space of $X$.
Recall that $X$ is an $n \times k$ matrix. If $n \le k$ then the columns of $X$ span the entire space $\mathbb{R}^n$ (unless $X$ has rank less than $n$, for instance because some columns are identical).
Takeaway: if you have at least as many variables as observations, it is (essentially) always possible to exactly predict the response; and with strictly more variables than observations the problem is singular, so infinitely many coefficient vectors fit perfectly.
However, these predictions are unlikely to be generalizable.
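To see this concretely, here is a minimal sketch (not part of the analysis above; `X_demo` and `y_demo` are made-up names): with 10 observations and 20 purely random predictors, ordinary least squares reproduces $y$ exactly.
import numpy as np
rng_demo = np.random.default_rng(0)
n_obs, n_vars = 10, 20                   # fewer observations than variables
X_demo = rng_demo.normal(size=(n_obs, n_vars))
y_demo = rng_demo.normal(size=n_obs)     # pure noise, unrelated to X_demo
demo_fit = lm().fit(X_demo, y_demo)
resids = y_demo - demo_fit.predict(X_demo)
print(np.max(np.abs(resids)))            # essentially zero: a "perfect" fit to noise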
How to tell how good your model is?
See how well it predicts "new" data.
To do $k$-fold crossvalidation: randomly split the observations into $k$ roughly equal "folds"; then, for each fold in turn, fit the model on the other $k-1$ folds and measure how well it predicts the held-out fold.
Predictions for data used to fit the model ("training error") should not be much further off than for data held out ("test error").
This can be used either as an indication of overfitting or to compare different models to each other.
Here's a pretty 'easy' prediction problem:
n = 100
# rng is a numpy random generator, assumed created earlier, e.g. rng = np.random.default_rng()
x = np.array([
    rng.uniform(low=0, high=10, size=n),
    rng.normal(size=n),
]).T
# y is a linear function of both columns of x, plus independent noise
y = 2.5 + 1.5 * x[:, 0] - 4.2 * x[:, 1] + rng.normal(size=n)

fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.scatter(x[:, 0], y); ax0.set_xlabel("x0"); ax0.set_ylabel("y")
ax1.scatter(x[:, 1], y); ax1.set_xlabel("x1"); ax1.set_ylabel("y");
Let's (a) fit a linear model and (b) do crossvalidation to look for evidence of overfitting.
import pandas as pd

def rmse(X, y, model):
    # root mean squared prediction error of the model on (X, y)
    yhat = model.predict(X)
    resids = y - yhat
    return np.sqrt(np.mean(resids ** 2))

def kfold(k, X, y, model):
    n = len(y)
    # randomly assign each observation to one of k folds
    folds = np.repeat(np.arange(k), int(np.ceil(n / k)))[:n]
    rng.shuffle(folds)
    test_rmse = []
    train_rmse = []
    for ik in range(k):
        test_X = X[folds == ik]
        test_y = y[folds == ik]
        train_X = X[folds != ik]
        train_y = y[folds != ik]
        # fit on the training folds, then evaluate on both training and held-out data
        model.fit(train_X, train_y)
        test_rmse.append(rmse(test_X, test_y, model))
        train_rmse.append(rmse(train_X, train_y, model))
    return pd.DataFrame({
        "test": test_rmse,
        "train": train_rmse,
    })
crossval = kfold(5, x, y, lm())
crossval.mean()
test     1.121362
train    1.071623
dtype: float64
crossval
 | test | train |
---|---|---|
0 | 1.002842 | 1.102181 |
1 | 1.280119 | 1.030386 |
2 | 1.342009 | 1.017593 |
3 | 1.042074 | 1.093432 |
4 | 0.939767 | 1.114522 |
We're going to add more variables - these will be independent of everything else, so they should not give us meaningful predictive power for $y$. However, by chance each is a little correlated with $y$.
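Just to check that "a little correlated by chance" claim, here is a quick sketch (using `n`, `y`, and `rng` from above): generate 80 independent noise columns and look at their sample correlations with $y$.
noise = rng.normal(size=(n, 80))
# sample correlation of each noise column with y
noise_cors = np.array([np.corrcoef(noise[:, j], y)[0, 1] for j in range(noise.shape[1])])
print(np.round([noise_cors.min(), noise_cors.max()], 2))   # small, but not exactly zero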
crossval = pd.DataFrame()
for new_vars in np.linspace(0, 80, 9):
    # append int(new_vars) pure-noise columns to the design matrix
    new_x = rng.normal(size=(n, int(new_vars)))
    X = np.column_stack([x, new_x])
    xval = kfold(5, X, y, lm())
    xval["new_vars"] = int(new_vars)
    crossval = pd.concat([crossval, xval], ignore_index=True)

crossval.groupby("new_vars").mean().plot();
Recall that, somewhat mysteriously, scikit-learn's LogisticRegression (its method to fit a Binomial GLM with a logistic link function) has a "penalty" option. What's that?
Well, it's finding $b$ to maximize the likelihood under the following model: $$ Y_i \sim \text{Binomial}(N_i, p(X_i \cdot b)), $$ where $p(\cdot)$ is the logistic function. The terms in the log-likelihood that depend on $b$ are $$ \sum_{i=1}^n \left\{ Y_i \log(p(X_i \cdot b)) + (N_i - Y_i) \log(1 - p(X_i \cdot b)) \right\} . $$
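Written as code, that log-likelihood looks like this (just a sketch; `binom_loglik` and its arguments are made-up names, with `X` an $n \times k$ matrix, `Y` the counts, and `N` the numbers of trials):
def logistic(u):
    return 1 / (1 + np.exp(-u))

def binom_loglik(b, X, Y, N):
    # terms of the Binomial log-likelihood that depend on b
    p = logistic(X @ b)
    return np.sum(Y * np.log(p) + (N - Y) * np.log(1 - p))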
The problem we had above was that the variables that didn't matter had small but nonzero estimated parameters; and there were so many of them, that together they added up to something big.
Solution: "encourage" them to be small.
So, from our log-likelihood $$ \sum_{i=1}^n \left\{ Y_i \log(p(X_i \cdot b)) + (N_i - Y_i) \log(1 - p(X_i \cdot b)) \right\} $$ we subtract a "regularization" term that does the "encouraging". Options: $$\begin{aligned} \sum_j |b_j| \qquad & \text{"L1" or "$\ell_1$" or "lasso"}\\ \sum_j b_j^2 \qquad & \text{"L2" or "$\ell_2$" or "ridge" or "Tikhonov"} \end{aligned}$$
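In scikit-learn's LogisticRegression, these correspond to the "penalty" option (a sketch, not code from above; note that its C parameter is the *inverse* of the penalty strength, and an L1 penalty needs a solver that supports it):
from sklearn.linear_model import LogisticRegression

l2_model = LogisticRegression(penalty="l2", C=1.0)                      # ridge (the default)
l1_model = LogisticRegression(penalty="l1", solver="liblinear", C=1.0)  # lasso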
Let's do it, with the Lasso, which (since this is a standard least-squares linear model) minimizes: $$ \sum_i (y_i - X_i \cdot b)^2 + \alpha \sum_j |b_j| . $$ (Scikit-learn's Lasso scales the squared-error term by $1/(2n)$, but the idea is the same.)
from sklearn.linear_model import Lasso

lasso_crossval = pd.DataFrame()
for new_vars in np.linspace(0, 80, 9):
    new_x = rng.normal(size=(len(y), int(new_vars)))
    X = np.column_stack([x, new_x])
    xval = kfold(5, X, y, Lasso(alpha=.3))
    xval["new_vars"] = int(new_vars)
    lasso_crossval = pd.concat([lasso_crossval, xval], ignore_index=True)

lasso_crossval.groupby("new_vars").mean().plot();
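We picked $\alpha = 0.3$ by hand above; a common alternative is to choose it by cross-validation. Here's a sketch using scikit-learn's LassoCV (not part of the analysis above; `X` and `y` are the arrays from the last loop iteration), which also lets us check that the lasso sets most of the junk coefficients exactly to zero:
from sklearn.linear_model import LassoCV

cv_lasso = LassoCV(cv=5).fit(X, y)
print(cv_lasso.alpha_)               # the penalty strength chosen by 5-fold crossvalidation
print(np.sum(cv_lasso.coef_ == 0))   # number of coefficients set exactly to zero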
Let's see
(a) how well coefficients are estimated with the lasso, and
(b) how well predictions are made with the lasso.
To do this, we'll
(IN CLASS)
true_params = {"slope": -.15, "intercept": 1.2}

def mean_cups(sleep):
    # expected number of cups of coffee given hours of sleep
    return np.exp(true_params['intercept'] + true_params['slope'] * sleep)

def sim_coffee(n):
    sleep = rng.uniform(low=1, high=10, size=n)  # hours of sleep (real sleep isn't uniform, but this will do)
    coffee = rng.poisson(lam=mean_cups(sleep))
    return sleep, coffee
sleep, coffee = sim_coffee(100)
plt.scatter(sleep, coffee)
plt.plot(np.linspace(1, 10, 11), mean_cups(np.linspace(1, 10, 11)), label='expected number')
plt.xlabel("hours of sleep"); plt.ylabel("cups of coffee"); plt.legend();
from sklearn.linear_model import PoissonRegressor as poisson_glm

def fit_model(sleep, coffee, alpha):
    # alpha is the strength of PoissonRegressor's (L2) penalty on the coefficient
    model = poisson_glm(alpha=alpha)
    model.fit(sleep.reshape(-1, 1), coffee)  # sklearn expects a 2D design matrix
    return model
model = fit_model(sleep, coffee, alpha=1)
model.coef_[0], model.intercept_
(-0.11482714727544707, 1.0905232611565594)
plt.scatter(sleep, coffee)
plt.xlabel("hours of sleep"); plt.ylabel("cups of coffee");
coffee_hat = np.exp(model.intercept_ + model.coef_[0] * sleep)
plt.scatter(sleep, coffee_hat, label='predicted')
mean_coffee = mean_cups(sleep)
plt.scatter(sleep, mean_coffee, label='true mean')
plt.legend();
def experiment(n, alpha):
# n times, generate data, fit a model, and return an array of the estimated coefficients
out = np.zeros((n, 2))
for j in range(n):
sleep, coffee = sim_coffee(100)
model = fit_model(sleep, coffee, alpha)
out[j, :] = model.coef_[0], model.intercept_
return out
experiment(10, 1)
array([[-0.10180883,  0.89003322],
       [-0.1223202 ,  1.10556814],
       [-0.12684918,  1.10732741],
       [-0.13122881,  1.04476602],
       [-0.12383794,  1.17677132],
       [-0.16069652,  1.16739785],
       [-0.16488536,  1.33573878],
       [-0.14585712,  1.26087712],
       [-0.129308  ,  1.04148244],
       [-0.11134622,  0.92374059]])
exp_results = experiment(300, alpha=1)
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.hist(exp_results[:, 0])
ax0.set_title("estimated slopes")
ax0.axvline(true_params['slope'], color='red')
ax1.hist(exp_results[:, 1])
ax1.axvline(true_params['intercept'], color='red')
ax1.set_title("estimated intercepts");
Now let's look at the effect of changing alpha:
alpha_vals = np.linspace(0, 30, 11)
many_experiments = np.array([experiment(100, alpha=a) for a in alpha_vals])
min_estimates = np.min(many_experiments[:,:,0], axis=1)
max_estimates = np.max(many_experiments[:,:,0], axis=1)
mean_estimates = np.mean(many_experiments[:,:,0], axis=1)
for a, x, y, z in zip(alpha_vals, min_estimates, mean_estimates, max_estimates):
plt.plot([a, a], [x, z])
plt.scatter([a], [y])
plt.axhline(true_params['slope'], linestyle=":", label='true value')
plt.xlabel("alpha"); plt.ylabel("slope"); plt.legend();