
Replace np.sum(a * b) with a @ b for better performance and accuracy #542

Merged · 7 commits · Aug 16, 2025
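For anyone who wants to check the claim in the title locally, here is a minimal benchmark sketch (array size, seed, and repetition count are illustrative; timings depend on the BLAS build NumPy links against):

```python
import numpy as np
from timeit import timeit

rng = np.random.default_rng(0)
x = rng.standard_normal(10**6)
y = rng.standard_normal(10**6)

# Both compute the inner product; `@` dispatches to a single BLAS dot call.
print(np.sum(x * y), x @ y)  # agree to within floating-point rounding

print(timeit(lambda: np.sum(x * y), number=100))  # elementwise multiply, then reduce
print(timeit(lambda: x @ y, number=100))          # one fused pass in BLAS
```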
4 changes: 2 additions & 2 deletions lectures/career.md
@@ -206,8 +206,8 @@ class CareerWorkerProblem:
 
         self.F_probs = BetaBinomial(grid_size - 1, F_a, F_b).pdf()
         self.G_probs = BetaBinomial(grid_size - 1, G_a, G_b).pdf()
-        self.F_mean = np.sum(self.θ * self.F_probs)
-        self.G_mean = np.sum(self.ϵ * self.G_probs)
+        self.F_mean = self.θ @ self.F_probs
+        self.G_mean = self.ϵ @ self.G_probs
 
         # Store these parameters for str and repr methods
         self._F_a, self._F_b = F_a, F_b
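A quick equivalence check for this hunk, using the same `BetaBinomial` import as the lecture (grid and shape parameters below are illustrative, not the lecture's calibration):

```python
import numpy as np
from quantecon.distributions import BetaBinomial  # same import the lecture uses

grid_size = 50
θ = np.linspace(1e-10, 5, grid_size)               # illustrative support grid
F_probs = BetaBinomial(grid_size - 1, 1, 1).pdf()  # illustrative shape parameters

# For 1-D arrays the two forms are mathematically identical.
assert np.isclose(np.sum(θ * F_probs), θ @ F_probs)
```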
6 changes: 4 additions & 2 deletions lectures/kalman.md
@@ -783,8 +783,10 @@ e2 = np.empty(T-1)
 
 for t in range(1, T):
     kn.update(y[:,t])
-    e1[t-1] = np.sum((x[:, t] - kn.x_hat.flatten())**2)
-    e2[t-1] = np.sum((x[:, t] - A @ x[:, t-1])**2)
+    diff1 = x[:, t] - kn.x_hat.flatten()
+    diff2 = x[:, t] - A @ x[:, t-1]
+    e1[t-1] = diff1 @ diff1
+    e2[t-1] = diff2 @ diff2
 
 fig, ax = plt.subplots(figsize=(9,6))
 ax.plot(range(1, T), e1, 'k-', lw=2, alpha=0.6,
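This rewrite relies on the identity `np.sum(d**2) == d @ d` for a 1-D residual `d`. A minimal check (the two-element array is a stand-in for `x[:, t] - kn.x_hat.flatten()`):

```python
import numpy as np

d = np.array([0.3, -1.2])  # stand-in residual vector
assert np.isclose(np.sum(d**2), d @ d)  # both give the squared Euclidean norm
```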
6 changes: 3 additions & 3 deletions lectures/lake_model.md
@@ -699,7 +699,7 @@ def _update_bellman(α, β, γ, c, σ, w_vec, p_vec, V, V_new, U):
         V_new[w_idx] = u(w, σ) + β * ((1 - α) * V[w_idx] + α * U)
 
     U_new = u(c, σ) + β * (1 - γ) * U + \
-            β * γ * np.sum(np.maximum(U, V) * p_vec)
+            β * γ * (np.maximum(U, V) @ p_vec)
 
     return U_new
 
@@ -836,8 +836,8 @@ def compute_steady_state_quantities(c, τ):
     u, e = x
 
     # Compute steady state welfare
-    w = np.sum(V * p_vec * (w_vec - τ > w_bar)) / np.sum(p_vec * (w_vec -
-                                                                  τ > w_bar))
+    mask = (w_vec - τ > w_bar)
+    w = ((V * p_vec * mask) @ np.ones_like(p_vec)) / ((p_vec * mask) @ np.ones_like(p_vec))
     welfare = e * w + u * U
 
     return e, u, welfare
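The second hunk leans on the identity `v @ np.ones_like(v) == v.sum()`. A self-contained check with illustrative values (not the lecture's calibration; `w_vec` is chosen so the mask is nonempty):

```python
import numpy as np

V = np.array([1.0, 2.0, 3.0, 4.0, 5.0])      # illustrative values
p_vec = np.array([0.1, 0.2, 0.3, 0.2, 0.2])
w_vec = np.linspace(0.0, 1.0, 5)
τ, w_bar = 0.1, 0.3                           # illustrative parameters
mask = (w_vec - τ > w_bar)

old = np.sum(V * p_vec * mask) / np.sum(p_vec * mask)
new = ((V * p_vec * mask) @ np.ones_like(p_vec)) / ((p_vec * mask) @ np.ones_like(p_vec))
assert np.isclose(old, new)
```

A shorter equivalent would be `((V * p_vec) @ mask) / (p_vec @ mask)`, since dotting with the boolean mask selects and sums in one step.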
14 changes: 12 additions & 2 deletions lectures/linear_algebra.md
@@ -246,15 +246,25 @@ Continuing on from the previous example, the inner product and norm can be compu
 follows
 
 ```{code-cell} python3
-np.sum(x * y) # Inner product of x and y
+np.sum(x * y) # Inner product of x and y, method 1
 ```
+
+```{code-cell} python3
+x @ y # Inner product of x and y, method 2 (preferred)
+```
+
+The `@` operator is preferred because it calls optimized BLAS routines that use fused multiply-add operations, giving better performance and numerical accuracy than separate multiply and sum operations.
 
 ```{code-cell} python3
 np.sqrt(np.sum(x**2)) # Norm of x, take one
 ```
 
 ```{code-cell} python3
-np.linalg.norm(x) # Norm of x, take two
+np.sqrt(x @ x) # Norm of x, take two (preferred)
 ```
+
+```{code-cell} python3
+np.linalg.norm(x) # Norm of x, take three
+```
 
 ### Span
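A quick sanity check on the three norm computations added above (a minimal sketch using a 3-4-5 example):

```python
import numpy as np

x = np.array([3.0, 4.0])
assert np.isclose(np.sqrt(np.sum(x**2)), 5.0)  # take one
assert np.isclose(np.sqrt(x @ x), 5.0)         # take two (preferred)
assert np.isclose(np.linalg.norm(x), 5.0)      # take three
```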
2 changes: 1 addition & 1 deletion lectures/mccall_model_with_separation.md
@@ -345,7 +345,7 @@ def update(model, v, d):
     " One update on the Bellman equations. "
     α, β, c, w, q = model.α, model.β, model.c, model.w, model.q
     v_new = u(w) + β * ((1 - α) * v + α * d)
-    d_new = jnp.sum(jnp.maximum(v, u(c) + β * d) * q)
+    d_new = jnp.maximum(v, u(c) + β * d) @ q
    return v_new, d_new
 
 @jax.jit
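The same identity carries over to JAX arrays; a minimal check with toy values (requires `jax`):

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])
q = jnp.array([0.2, 0.3, 0.5])
# For 1-D arrays, jnp.sum(v * q) and v @ q compute the same inner product.
assert jnp.allclose(jnp.sum(v * q), v @ q)
```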
2 changes: 1 addition & 1 deletion lectures/mix_model.md
@@ -820,7 +820,7 @@ def learn_x_bayesian(observations, α0, β0, grid_size=2000):
         post = np.exp(log_post)
         post /= post.sum()
 
-        μ_path[t + 1] = np.sum(x_grid * post)
+        μ_path[t + 1] = x_grid @ post
 
     return μ_path
 
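Here the posterior mean over a grid becomes a dot product. A quick check with a toy normalized posterior (grid size matches the function's default):

```python
import numpy as np

x_grid = np.linspace(0.0, 1.0, 2000)
post = np.full_like(x_grid, 1 / x_grid.size)  # toy uniform posterior (sums to one)
assert np.isclose(np.sum(x_grid * post), x_grid @ post)
```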