Exercise: Estimating Elephants

$$\renewcommand{\P}{\mathbb{P}} \newcommand{\E}{\mathbb{E}} \newcommand{\var}{\text{var}} \newcommand{\sd}{\text{sd}} \newcommand{\cov}{\text{cov}} \newcommand{\cor}{\text{cor}}$$ This is here so we can use \P and \E and \var and \cov and \cor and \sd in LaTeX below.

Elephants, at birth, are about 1 m long measured along their backs, and grow about 10 cm/year for the first 20 years, although elephants of the same age differ in length by 10-20% or so (see Trimble et al.). Their rate of growth is also affected by health (e.g., food availability and parasite load). How well can we estimate the age of juvenile elephants (between 10 and 20 years old) based on their lengths in aerial photographs? Does it help much to take into account food availability?

To see how well we expect this to work, let's simulate some data.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(seed=123)
In [2]:
n = 100
age = rng.uniform(low=10, high=20, size=n) # in years

What about food availability? Let's measure food availability as a percentage of 'optimal', and suppose that for each 10 percentage points that food drops below optimal, mean elephant length goes down by .15 * .25 m (15% of the 0.25 m standard deviation in length among elephants of the same age and food availability).

In [3]:
food = rng.gamma(shape=10, scale=8, size=n) # in percent
mean_length = 1 + .1 * age  - .15 *.25 * (100 - food)/10 # in m
length = rng.normal(loc=mean_length, scale=0.25, size=n)
In [4]:
plt.hist(food);
In [5]:
plt.scatter(age, mean_length)
plt.scatter(age, length, c=food)
plt.xlabel("age (years)"); plt.ylabel("length (m)")
plt.axline((10, 2), slope=.1) # mean size at 100% food
plt.colorbar(label='food avail');

The inference problem

We'd like to infer age ($Y$) based on length ($X$). The line that minimizes mean squared error for predicting $Y$ from $X$ has slope $\sd(Y)/\sd(X) \times \cor[X, Y]$, and passes through the point of means, $(\E[X], \E[Y])$.
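For reference, here is a short sketch of where that slope comes from, writing the line as $aX + b$ (the symbols $a$ and $b$ are chosen to match the code below). Setting the derivative of the mean squared error with respect to $b$ to zero fixes the intercept, and substituting it back leaves a quadratic in $a$:

$$\begin{aligned}
\E[(Y - aX - b)^2] \text{ is minimized over } b \text{ at } \quad b &= \E[Y] - a \E[X] , \\
\text{which leaves} \quad \var(Y) - 2 a \cov(X, Y) + a^2 \var(X) , \quad \text{minimized at} \quad a &= \frac{\cov(X, Y)}{\var(X)} = \cor[X, Y] \, \frac{\sd(Y)}{\sd(X)} .
\end{aligned}$$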

Plot age against length and add this line.

In [6]:
a = np.corrcoef(age, length)[0,1] * np.std(age, ddof=1) / np.std(length, ddof=1)
b = np.mean(age) - a * np.mean(length)
print(a, b)
4.704069567707924 3.464646893346048
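As a sanity check, the correlation-based slope and intercept should agree with an ordinary least-squares fit of a degree-one polynomial. The sketch below re-simulates data with the same model but an arbitrary seed (so the numbers will differ from those above) and compares against `np.polyfit`; the names `slope` and `icpt` are just for illustration.

```python
import numpy as np

# Re-simulate with the same model as above; the seed here is arbitrary.
rng = np.random.default_rng(seed=1)
n = 100
age = rng.uniform(low=10, high=20, size=n)           # in years
food = rng.gamma(shape=10, scale=8, size=n)          # in percent
mean_length = 1 + .1 * age - .15 * .25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Slope and intercept from the correlation formula.
a = np.corrcoef(age, length)[0, 1] * np.std(age, ddof=1) / np.std(length, ddof=1)
b = np.mean(age) - a * np.mean(length)

# np.polyfit with deg=1 returns [slope, intercept] of the least-squares line.
slope, icpt = np.polyfit(length, age, deg=1)

print(np.allclose([a, b], [slope, icpt]))
```

Note that the `ddof` choice does not matter for the slope, as long as it is the same in both standard deviations, since the factor cancels in the ratio.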
In [7]:
age_hat = b + a * length
fig, ax = plt.subplots()
ax.scatter(length, age, label='observed')
ax.scatter(length, age_hat, label='predicted')
ax.legend();
ax.set_xlabel("length"); ax.set_ylabel("age");

How can we answer the question "how well can we estimate age, based on length"? One answer to this is the root mean squared error. Compute this.

In [8]:
np.sqrt(
    np.mean(
        (age_hat - age)**2
    )
)
Out[8]:
np.float64(2.1087153735840225)
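To put an RMSE in context, it helps to compare it to the error of ignoring length entirely and always guessing the mean age, which is just the standard deviation of age (for ages uniform on 10 to 20 years, about $10/\sqrt{12} \approx 2.89$ years). For the least-squares line the two are related by RMSE $= \sd(Y) \sqrt{1 - \cor[X, Y]^2}$. The sketch below checks this on freshly simulated data (arbitrary seed, so values differ from those above); `baseline` is a name introduced here for illustration.

```python
import numpy as np

# Re-simulate with the same model as above; the seed here is arbitrary.
rng = np.random.default_rng(seed=1)
n = 100
age = rng.uniform(low=10, high=20, size=n)
food = rng.gamma(shape=10, scale=8, size=n)
mean_length = 1 + .1 * age - .15 * .25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Least-squares line for predicting age from length.
r = np.corrcoef(age, length)[0, 1]
a = r * np.std(age) / np.std(length)
b = np.mean(age) - a * np.mean(length)
rmse = np.sqrt(np.mean((b + a * length - age) ** 2))

# RMSE of always guessing the mean age (note ddof=0, to match np.mean above).
baseline = np.std(age)

print(rmse, baseline * np.sqrt(1 - r**2), baseline)
```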

Multivariate inference

Now let's use food availability also!

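Recall that the vector of coefficients $a$ that minimizes the squared error solves the normal equations $$ (x^T x) a = x^T y ,$$ where each row of $x$ is one observation and $y$ is the vector of values to predict (and don't forget the intercept: include a column of ones in $x$!)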

In [9]:
X = np.column_stack([
    np.ones(n),
    length,
    food
])
X[:5,:]
Out[9]:
array([[  1.        ,   2.41047531,  51.65953572],
       [  1.        ,   2.20074333, 100.47012885],
       [  1.        ,   2.06274541, 126.34161723],
       [  1.        ,   1.66858873,  47.77556973],
       [  1.        ,   2.10190759,  86.63821377]])
In [10]:
intercept, a_length, a_food = np.linalg.solve(
    X.T.dot(X),
    X.T.dot(age)
)
intercept, a_length, a_food
Out[10]:
(np.float64(4.550814802710699),
 np.float64(5.124045892831306),
 np.float64(-0.025940018774039154))
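Solving the normal equations directly is fine for a small, well-conditioned problem like this one, but a common alternative is `np.linalg.lstsq`, which solves the least-squares problem without forming $x^T x$. The following sketch (freshly simulated data, arbitrary seed, illustrative names `coef_normal` and `coef_lstsq`) checks that the two approaches give the same coefficients.

```python
import numpy as np

# Re-simulate with the same model as above; the seed here is arbitrary.
rng = np.random.default_rng(seed=1)
n = 100
age = rng.uniform(low=10, high=20, size=n)
food = rng.gamma(shape=10, scale=8, size=n)
mean_length = 1 + .1 * age - .15 * .25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Design matrix: intercept column, length, food.
X = np.column_stack([np.ones(n), length, food])

# Coefficients from the normal equations, as above.
coef_normal = np.linalg.solve(X.T @ X, X.T @ age)

# Coefficients from the least-squares solver; the first returned item is the solution.
coef_lstsq, *_ = np.linalg.lstsq(X, age, rcond=None)

print(np.allclose(coef_normal, coef_lstsq))
```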
In [11]:
new_age_hat = intercept + a_length * length + a_food * food
# equivalently, as a matrix-vector product
# (`a` is the scalar slope from In [6], so collect the three coefficients into a vector)
coef = np.array([intercept, a_length, a_food])
new_age_hat = X.dot(coef)
In [12]:
fig, ax = plt.subplots()
ax.scatter(age, new_age_hat)
ax.set_xlabel("age"); ax.set_ylabel("estimated age");
In [13]:
np.sqrt(
    np.mean(
        (new_age_hat - age)**2
    )
)
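To address the original question of whether food availability helps, one can compare the RMSE of the length-only predictor with that of the length-and-food predictor. The self-contained sketch below does this on freshly simulated data (same model, arbitrary seed, so values will not match the cells above); the helper `rmse` and the names `X1`, `X2`, `rmse_length`, and `rmse_both` are introduced here for illustration.

```python
import numpy as np

def rmse(pred, truth):
    """Root mean squared error between predictions and true values."""
    return np.sqrt(np.mean((pred - truth) ** 2))

# Re-simulate with the same model as above; the seed here is arbitrary.
rng = np.random.default_rng(seed=1)
n = 100
age = rng.uniform(low=10, high=20, size=n)
food = rng.gamma(shape=10, scale=8, size=n)
mean_length = 1 + .1 * age - .15 * .25 * (100 - food) / 10
length = rng.normal(loc=mean_length, scale=0.25, size=n)

# Design matrices without and with food availability.
X1 = np.column_stack([np.ones(n), length])
X2 = np.column_stack([np.ones(n), length, food])

# Least-squares coefficients for each model.
coef1, *_ = np.linalg.lstsq(X1, age, rcond=None)
coef2, *_ = np.linalg.lstsq(X2, age, rcond=None)

rmse_length = rmse(X1 @ coef1, age)
rmse_both = rmse(X2 @ coef2, age)
print(rmse_length, rmse_both)
```

Because the first model is nested in the second, the in-sample RMSE cannot increase when food is added; the question is how large the decrease is relative to the RMSE itself.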