Exercise: Estimating Elephants
$$\renewcommand{\P}{\mathbb{P}} \newcommand{\E}{\mathbb{E}} \newcommand{\var}{\text{var}} \newcommand{\sd}{\text{sd}} \newcommand{\cov}{\text{cov}} \newcommand{\cor}{\text{cor}}$$
This is here so we can use \P and \E and \var and \cov and \cor and \sd in LaTeX below.
Elephants are about 1 m long at birth, measured along their backs, and grow about 10 cm/year for the first 20 years, although elephants of the same age differ in length by 10-20% or so (see Trimble et al). Their rate of growth is also affected by health (e.g., food availability and parasite load). How well can we estimate the age of juvenile elephants (between 10 and 20 years old) based on their lengths in aerial photographs? Does it help much to take into account food availability?
To see how well we expect this to work, let's simulate some data.
import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(seed=123)
n = 100
age = rng.uniform(low=10, high=20, size=n) # in years
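What about food availability? Let's measure food availability as a percentage of 'optimal', and suppose that for each 10 percentage points that food availability falls below 100%, mean elephant length goes down by .15 * .25 m. That is 15% of a standard deviation, taking the standard deviation of length among elephants of the same age and food availability to be 0.25 m.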
food = rng.gamma(shape=10, scale=8, size=n) # in percent
mean_length = 1 + .1 * age - .15 *.25 * (100 - food)/10 # in m
length = rng.normal(loc=mean_length, scale=0.25, size=n)
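As a rough sanity check on these choices, the arithmetic can be worked out by hand. A gamma distribution with shape $k$ and scale $\theta$ has mean $k\theta$ and standard deviation $\sqrt{k}\,\theta$, so with $k = 10$ and $\theta = 8$ as above (the symbols $k$ and $\theta$ are notation introduced here, not the notebook's):

$$\E[\text{food}] = 10 \times 8 = 80, \qquad \sd(\text{food}) = \sqrt{10} \times 8 \approx 25.3 .$$

At the typical food availability of 80%, the mean length is therefore shifted by

$$-0.15 \times 0.25 \times \frac{100 - 80}{10} = -0.075 \text{ m}$$

relative to the 100%-food line, which is small compared to the 0.25 m of noise in length.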
plt.hist(food);
plt.scatter(age, mean_length)
plt.scatter(age, length, c=food)
plt.xlabel("age (years)"); plt.ylabel("length (m)")
plt.axline((10, 2), slope=.1) # mean size at 100% food
plt.colorbar(label='food avail');
The inference problem
We'd like to infer age based on length. For predicting $Y$ (here, age) from $X$ (here, length), the line that minimizes mean squared error has slope $\sd(Y)/\sd(X) \times \cor[X, Y]$ and passes through the point of means, $(\E[X], \E[Y])$.
Plot age against length and add this line.
a = np.corrcoef(age, length)[0,1] * np.std(age, ddof=1) / np.std(length, ddof=1)
b = np.mean(age) - a * np.mean(length)
print(a, b)
4.704069567707924 3.464646893346048
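One way to gain confidence in the slope formula is to compare it with a library least-squares fit. The sketch below re-simulates data under the same model so that it runs on its own, then checks the formula against `np.polyfit` with `deg=1`. The names `slope_fit` and `intercept_fit` are introduced here for illustration.

```python
import numpy as np

# Re-simulate data under the same model as above, so this block is self-contained.
rng = np.random.default_rng(seed=123)
n = 100
age = rng.uniform(low=10, high=20, size=n)                 # years
food = rng.gamma(shape=10, scale=8, size=n)                # percent of optimal
mean_length = 1 + 0.1 * age - 0.15 * 0.25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)   # metres

# Slope and intercept from the correlation formula (predicting age from length).
a = np.corrcoef(age, length)[0, 1] * np.std(age, ddof=1) / np.std(length, ddof=1)
b = np.mean(age) - a * np.mean(length)

# np.polyfit with deg=1 returns the least-squares [slope, intercept].
slope_fit, intercept_fit = np.polyfit(length, age, deg=1)

print(a, slope_fit)
print(b, intercept_fit)
```

The two pairs should agree up to floating-point error, since both describe the same least-squares line.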
age_hat = b + a * length
fig, ax = plt.subplots()
ax.scatter(length, age, label='observed')
ax.scatter(length, age_hat, label='predicted')
ax.legend();
ax.set_xlabel("length"); ax.set_ylabel("age");
How can we answer the question "how well can we estimate age, based on length"? One answer to this is the root mean squared error. Compute this.
np.sqrt(
np.mean(
(age_hat - age)**2
)
)
np.float64(2.1087153735840225)
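To judge whether this error is large or small, it helps to compare it with the error of ignoring length entirely and always guessing the mean age, which is just the standard deviation of age (about 2.89 years for a uniform distribution on 10 to 20). The sketch below, on re-simulated data from the same model, also checks the identity that the in-sample RMSE of the least-squares line equals $\sd(Y) \sqrt{1 - \cor[X, Y]^2}$. The names `baseline` and `rmse_from_r` are introduced here for illustration.

```python
import numpy as np

# Re-simulate data under the same model as above, so this block is self-contained.
rng = np.random.default_rng(seed=123)
n = 100
age = rng.uniform(low=10, high=20, size=n)
food = rng.gamma(shape=10, scale=8, size=n)
mean_length = 1 + 0.1 * age - 0.15 * 0.25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Least-squares line for predicting age from length.
r = np.corrcoef(age, length)[0, 1]
a = r * np.std(age, ddof=1) / np.std(length, ddof=1)
b = np.mean(age) - a * np.mean(length)
age_hat = b + a * length

rmse = np.sqrt(np.mean((age_hat - age) ** 2))

# RMSE of always guessing the mean age: the (ddof=0) standard deviation of age.
baseline = np.std(age)

# The same RMSE, computed from the correlation alone.
rmse_from_r = baseline * np.sqrt(1 - r ** 2)

print(rmse, baseline, rmse_from_r)
```

The ratio `rmse / baseline` gives the fraction of the spread in age that remains after using length.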
Multivariate inference
Now let's use food availability also!
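Recall that the vector of least-squares coefficients $a$ solves $$ (x^T x) a = x^T y , $$ where $y$ is the response (here, age) and the columns of $x$ are the predictors (here, length and food) plus a column of ones for the intercept (don't forget it!).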
X = np.column_stack([
np.ones(n),
length,
food
])
X[:5,:]
array([[ 1. , 2.41047531, 51.65953572],
[ 1. , 2.20074333, 100.47012885],
[ 1. , 2.06274541, 126.34161723],
[ 1. , 1.66858873, 47.77556973],
[ 1. , 2.10190759, 86.63821377]])
intercept, a_length, a_food = np.linalg.solve(
X.T.dot(X),
X.T.dot(age)
)
intercept, a_length, a_food
(np.float64(4.550814802710699), np.float64(5.124045892831306), np.float64(-0.025940018774039154))
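Solving the normal equations directly can be cross-checked against `np.linalg.lstsq`, which solves the same least-squares problem without forming $x^T x$. This is a sketch on re-simulated data from the same model. The names `coef_normal` and `coef_lstsq` are introduced here for illustration.

```python
import numpy as np

# Re-simulate data under the same model as above, so this block is self-contained.
rng = np.random.default_rng(seed=123)
n = 100
age = rng.uniform(low=10, high=20, size=n)
food = rng.gamma(shape=10, scale=8, size=n)
mean_length = 1 + 0.1 * age - 0.15 * 0.25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Design matrix: intercept column, then length, then food.
X = np.column_stack([np.ones(n), length, food])

# Coefficients from the normal equations (X^T X) a = X^T y.
coef_normal = np.linalg.solve(X.T @ X, X.T @ age)

# Coefficients from NumPy's least-squares solver; rcond=None uses the default cutoff.
coef_lstsq, residuals, rank, singular_values = np.linalg.lstsq(X, age, rcond=None)

print(coef_normal)
print(coef_lstsq)
```

For a small, well-conditioned problem like this one the two approaches should agree closely. `lstsq` is generally the safer choice when predictors are nearly collinear.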
new_age_hat = intercept + a_length * length + a_food * food
# equivalently, as a matrix-vector product with all three fitted coefficients
new_age_hat = X.dot(np.array([intercept, a_length, a_food]))
fig, ax = plt.subplots()
ax.scatter(age, new_age_hat)
ax.set_xlabel("age"); ax.set_ylabel("estimated age");
np.sqrt(
np.mean(
(new_age_hat - age)**2
)
)
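To see directly how much food availability helps, the two predictors can be compared side by side. The sketch below re-simulates data under the same model and computes the in-sample RMSE using length alone and using length plus food. The helper `rmse` and the names `X1`, `X2`, `rmse_length`, and `rmse_both` are introduced here for illustration.

```python
import numpy as np

def rmse(pred, truth):
    """Root mean squared error between predictions and true values."""
    return np.sqrt(np.mean((pred - truth) ** 2))

# Re-simulate data under the same model as above, so this block is self-contained.
rng = np.random.default_rng(seed=123)
n = 100
age = rng.uniform(low=10, high=20, size=n)
food = rng.gamma(shape=10, scale=8, size=n)
mean_length = 1 + 0.1 * age - 0.15 * 0.25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Design matrices: intercept + length, and intercept + length + food.
X1 = np.column_stack([np.ones(n), length])
X2 = np.column_stack([np.ones(n), length, food])

# Least-squares coefficients from the normal equations.
coef1 = np.linalg.solve(X1.T @ X1, X1.T @ age)
coef2 = np.linalg.solve(X2.T @ X2, X2.T @ age)

rmse_length = rmse(X1 @ coef1, age)
rmse_both = rmse(X2 @ coef2, age)

print(rmse_length, rmse_both)
```

Adding a predictor can never increase the in-sample RMSE of a least-squares fit, so the question of interest is how large the drop is, not whether there is one. A fairer comparison would evaluate both predictors on freshly simulated data.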