Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bass Diffusion Model #1328

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions pymc_marketing/bass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bass diffusion model for product adoption.

Adapted from Wiki: https://en.wikipedia.org/wiki/Bass_diffusion_model

"""

import pymc as pm
import pytensor.tensor as pt
from pymc.model import Model

Check warning on line 22 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L20-L22

Added lines #L20 - L22 were not covered by tests

from pymc_marketing.prior import Censored, Prior, VariableFactory, create_dim_handler

Check warning on line 24 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L24

Added line #L24 was not covered by tests


def F(p, q, t):

Check warning on line 27 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L27

Added line #L27 was not covered by tests
"""Installed base fraction."""
return (1 - pt.exp(-(p + q) * t)) / (1 + (q / p) * pt.exp(-(p + q) * t))

Check warning on line 29 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L29

Added line #L29 was not covered by tests


def f(p, q, t):

Check warning on line 32 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L32

Added line #L32 was not covered by tests
"""Installed base fraction rate of change."""
return (p + q) * pt.exp(-(p + q) * t) / (1 + (q / p) * pt.exp(-(p + q) * t)) ** 2

Check warning on line 34 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L34

Added line #L34 was not covered by tests


def create_bass_model(

Check warning on line 37 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L37

Added line #L37 was not covered by tests
t: pt.TensorLike,
observed: pt.TensorLike | None,
priors: dict[str, Prior | Censored | VariableFactory],
coords,
) -> Model:
"""Define a Bass diffusion model."""
with pm.Model(coords=coords) as model:
combined_dims = (

Check warning on line 45 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L44-L45

Added lines #L44 - L45 were not covered by tests
"date",
*set(priors["p"].dims).union(priors["q"].dims).union(priors["m"].dims),
)
dim_handler = create_dim_handler(combined_dims)

Check warning on line 49 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L49

Added line #L49 was not covered by tests

m = dim_handler(priors["m"].create_variable("m"), priors["m"].dims)
p = dim_handler(priors["p"].create_variable("p"), priors["p"].dims)
q = dim_handler(priors["q"].create_variable("q"), priors["q"].dims)

Check warning on line 53 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L51-L53

Added lines #L51 - L53 were not covered by tests

time = dim_handler(t, "date")

Check warning on line 55 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L55

Added line #L55 was not covered by tests

adopters = pm.Deterministic("adopters", m * f(p, q, time), dims=combined_dims)

Check warning on line 57 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L57

Added line #L57 was not covered by tests

pm.Deterministic(

Check warning on line 59 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L59

Added line #L59 was not covered by tests
"innovators",
m * p * (1 - F(p, q, time)),
dims=combined_dims,
)
pm.Deterministic(

Check warning on line 64 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L64

Added line #L64 was not covered by tests
"imitators",
m * q * F(p, q, time) * (1 - F(p, q, time)),
dims=combined_dims,
)

pm.Deterministic(

Check warning on line 70 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L70

Added line #L70 was not covered by tests
"peak",
(pt.log(q) - pt.log(p)) / (p + q),
dims=combined_dims[1:],
)

priors["likelihood"].dims = combined_dims
priors["likelihood"].create_likelihood_variable( # type: ignore

Check warning on line 77 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L76-L77

Added lines #L76 - L77 were not covered by tests
"y",
mu=adopters,
observed=observed,
)

return model

Check warning on line 83 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L83

Added line #L83 was not covered by tests


if __name__ == "__main__":
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

Check warning on line 89 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L86-L89

Added lines #L86 - L89 were not covered by tests

from pymc_marketing.plot import plot_curve

Check warning on line 91 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L91

Added line #L91 was not covered by tests

n_dates = 12 * 3
dates = pd.date_range(start="2020-01-01", periods=n_dates, freq="MS")
t = np.arange(n_dates)

Check warning on line 95 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L93-L95

Added lines #L93 - L95 were not covered by tests

coords = {"date": dates, "product": ["A", "B", "C"]}

Check warning on line 97 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L97

Added line #L97 was not covered by tests

priors = {

Check warning on line 99 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L99

Added line #L99 was not covered by tests
"m": Prior("DiracDelta", c=5000),
"p": Prior("Beta", alpha=13.85, beta=692.43, dims="product"),
"q": Prior("Beta", alpha=36.2, beta=54.4),
# "p": Prior("DiracDelta", c=0.01),
# "q": Prior("DiracDelta", c=0.15),
"likelihood": Prior(
"Poisson",
dims="date",
),
}
model = create_bass_model(t, observed=None, priors=priors, coords=coords)
with model:
idata = pm.sample_prior_predictive()

Check warning on line 112 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L110-L112

Added lines #L110 - L112 were not covered by tests

idata.prior["y"].pipe(plot_curve, {"date"})
plt.show()

Check warning on line 115 in pymc_marketing/bass.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/bass.py#L114-L115

Added lines #L114 - L115 were not covered by tests
Loading