This notebook demonstrates the use of Bayesian inference for sales forecasting using various probabilistic programming techniques. We will use the numpyro
library to define and fit our models, and plotnine
for visualization.
Bayesian inference allows us to incorporate prior knowledge and quantify uncertainty in our predictions. This notebook will guide you through the process of building a Bayesian model for sales forecasting, fitting the model using Markov Chain Monte Carlo (MCMC) and Stochastic Variational Inference (SVI), and visualizing the results.
library(reticulate)
import os
import sys

# Set XLA_FLAGS before JAX is imported
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import polars as pl
import pandas as pd
import numpy as np
from plotnine import (ggplot, aes, geom_point, geom_line, labs, theme_minimal, theme_bw,
                      scale_x_continuous, scale_x_discrete, scale_x_datetime,
                      theme, guides, guide_legend)
import patsy
import jax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit, logit
import numpyro
import numpyro.distributions as dist
#from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer import MCMC, NUTS
import arviz as az
import pickle
#jax.config.update("jax_enable_x64", True) # Enable float64 by default
In this section, we define the Bayesian model used for sales forecasting. The model incorporates components such as a random walk for the latent state, day-of-the-week effects, day-of-the-year effects, and price elasticity. The model is implemented using the numpyro
library, which allows for efficient and scalable Bayesian inference.
These auxiliary functions are essential for data preprocessing and transformation:
periodic_rbf
: Computes a periodic Gaussian radial basis function (RBF).
compute_doy_basis
: Computes 12 periodic Gaussian basis functions for seasonal effects.
read_data
: Reads and preprocesses the sales data from a CSV file.
init_values
: Initializes values for the model parameters.
# Define a periodic Gaussian radial basis function (RBF)
def periodic_rbf(x, mu, sigma):
    """
    Computes a periodic Gaussian radial basis function (RBF).

    Args:
        x: Scaled day-of-year values (range [0,1]).
        mu: Center of the Gaussian basis function.
        sigma: Controls the spread of the Gaussian.

    Returns:
        RBF values preserving periodicity.
    """
    # compute cyclic distance to mu
    periodic_distance = jnp.minimum(jnp.abs(x - mu), 1 - jnp.abs(x - mu))

    # compute RBF value
    return jnp.exp(- (periodic_distance ** 2) / (2 * sigma ** 2))
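As a quick sanity check (not part of the original pipeline), we can verify the periodicity: day-of-year fractions just after the start and just before the end of the year lie at the same cyclic distance from a center at 0, so they should get identical basis values.

# Illustrative check: 0.01 and 0.99 are both a cyclic distance of 0.01
# from a center at 0.0, so their RBF values should be identical.
x_check = jnp.array([0.01, 0.99])
print(periodic_rbf(x_check, mu=0.0, sigma=30/365.25))  # two equal values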
def compute_doy_basis(yday_fraction, sigma=30/365.25, n_centers=12):
    """
    Computes periodic Gaussian basis functions for seasonal effects.

    Args:
        yday_fraction: Normalized day of the year (range [0,1]).
        sigma: Width of each Gaussian basis function.
        n_centers: Number of basis functions (default 12, roughly monthly).

    Returns:
        A JAX array with n_centers columns representing the basis functions.
    """
    # Define centers of Gaussian basis functions
    month_centers = jnp.linspace(1/(2*n_centers), 1 - 1/(2*n_centers), n_centers)

    # Generate an array of shape (length of input, n_centers) with the RBF values
    doy_basis = jnp.stack([periodic_rbf(yday_fraction, mu, sigma) for mu in month_centers], axis=-1)

    # Subtract each row's mean to enforce sum-to-zero constraint
    doy_basis_centered = doy_basis - jnp.mean(doy_basis, axis=-1, keepdims=True)

    return doy_basis_centered
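A small illustrative check (assumed, not from the original notebook): the basis should have one column per center, and the centering step should make each row sum to zero.

# Illustrative check of the basis shape and the sum-to-zero constraint.
yday_grid = jnp.linspace(0.0, 1.0, 365)
basis = compute_doy_basis(yday_grid)
print(basis.shape)                         # (365, 12)
print(jnp.abs(basis.sum(axis=-1)).max())   # ~0 up to floating-point error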
def read_data(fname, n_rows=None):
    """
    Reads and preprocesses the sales data from a CSV file.

    Args:
        fname: The filename of the CSV file containing the sales data.
        n_rows: If given, keep only the first n_rows rows.

    Returns:
        A dictionary with the following keys:
        - date: An array of dates.
        - sales: An array of sales data.
        - log_price: An array of log-transformed prices.
        - wday: An array of day-of-the-week values.
        - yday_fraction: An array of normalized day-of-the-year values.
    """
    # Read the CSV file using polars
    df = pl.read_csv(fname)

    # Keep only the first n_rows if specified
    if n_rows is not None:
        df = df.head(n_rows)

    # Convert the 'date' column to date type
    df = df.with_columns(pl.col("date").str.to_date())

    # Extract sales and log price data as numpy arrays
    sales = df["sales"].to_numpy()
    log_price = df["log_price"].to_numpy()

    # Extract day-of-the-week values (polars: 1 = Monday, ..., 7 = Sunday)
    wday = df["date"].dt.weekday().to_numpy()

    # Extract day-of-the-year values
    yday = df["date"].dt.ordinal_day().to_numpy()

    # Determine if the year is a leap year
    is_leap_year = df["date"].dt.is_leap_year().to_numpy()

    # Normalize day-of-the-year values
    yday_fraction = yday / (365 + is_leap_year)

    # Return the preprocessed data as a dictionary
    return {
        "date": df["date"].to_numpy(),
        "sales": sales,
        "log_price": log_price,
        "wday": wday,
        "yday_fraction": yday_fraction
    }
def interpolate(x, downsampling_factor, n_out):
    """
    Linearly interpolates a coarse state vector x back onto a grid of
    n_out points, where x holds one value per downsampling_factor points.
    """
    idx_n_weight = jnp.array(range(0, n_out)) / jnp.float64(downsampling_factor)
    idx_1 = jnp.array(jnp.floor(idx_n_weight), dtype=int)
    idx_2 = jnp.array(jnp.ceil(idx_n_weight), dtype=int)
    weight_2 = idx_n_weight - idx_1

    state_before = x[idx_1]
    state_after = x[idx_2]

    return (1 - weight_2) * state_before + weight_2 * state_after
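To see what interpolate does, here is a hypothetical example: a coarse state with one value per two output points is stretched back onto the fine grid by linear interpolation.

# Hypothetical example: upsample a coarse state back onto a fine grid.
x_coarse = jnp.array([0.0, 10.0, 20.0])
print(interpolate(x_coarse, downsampling_factor=2, n_out=5))
# -> [ 0.  5. 10. 15. 20.]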
def chunked_mean(x, n_chunk):
    n = x.shape[0]
    pad_size = -n % n_chunk  # padding needed to make n a multiple of n_chunk
    x_padded = jnp.pad(array=x, pad_width=(0, pad_size), mode='edge')  # pad at the end
    return x_padded.reshape(-1, n_chunk).mean(axis=1)

def chunked_sum(x, n_chunk):
    n = x.shape[0]
    pad_size = -n % n_chunk  # padding needed to make n a multiple of n_chunk
    x_padded = jnp.pad(array=x, pad_width=(0, pad_size))  # pad at the end with zeros
    return x_padded.reshape(-1, n_chunk).sum(axis=1)
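The two chunking helpers differ in how they pad: chunked_mean repeats the last value so the final chunk's mean is not dragged down, while chunked_sum pads with zeros. A made-up example:

# Made-up example: length 5 is padded to 6 so it splits into chunks of 3.
x_demo = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(chunked_mean(x_demo, 3))  # [2.0, 4.67] -> means of [1,2,3] and [4,5,5]
print(chunked_sum(x_demo, 3))   # [6.0, 9.0]  -> sums of [1,2,3] and [4,5,0]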
# Creates a simple plot using plotnine
def plot_function(x, y, title, xlab, ylab):
    # Convert x to numpy array
    x = np.array(x)

    # Check if y is a callable function
    if callable(y):
        # If y is a function, apply it to x and create a DataFrame
        df = pd.DataFrame({"x": x, "y": y(x)})
    else:
        # If y is not a function, create a DataFrame directly
        df = pd.DataFrame({"x": x, "y": y})

    # Create the plot using ggplot
    plot = (ggplot(df, aes(x="x", y="y")) + geom_line() + labs(title=title, x=xlab, y=ylab) + theme_bw())

    # Return the plot
    return plot
We model the sales time series as a stochastic process where the underlying rate of sales evolves over time. This evolution follows a random walk structure, but with systematic adjustments for covariates such as price, day-of-the-week effects, and day-of-the-year effects. The rate of sales \(\lambda_t\) on day \(t\) is a function that captures (i) systematic covariate effects (\(z_t\)), (ii) a global baseline (\(\mu_\tau\)), and (iii) a latent dynamic component (\(\tau_t\)).
\[ \log \lambda_t = z_t + \mu_\tau + \tau_t \]
The baseline sales level \(\tau_t\) follows a random walk. Because all contrast matrices for structured effects are centered, \(\mu_\tau + \tau_t\) can be interpreted as the average latent sales rate on day \(t\).
\[ \tau_t = \tau_{t-1} + \delta_t, \quad \delta_t \sim \mathcal{N}(0, \sigma_\tau) \]
with:
\[ \mu_\tau \sim \text{Exponential}(1), \quad \sigma_\tau \sim \mathcal{N}^{+}(0, 1) \]
We further account for systematic effects of (i) day of the week, (ii) day of the year, and (iii) price.
The day-of-the-week effects are modeled via a centered contrast matrix \(\mathbf{C}_{\text{wday}}\). Similarly, the day-of-the-year effects are modeled using a seasonality basis matrix \(\mathbf{B}_{\text{yday}}\), which represents periodic seasonal patterns using Gaussian radial basis functions (RBFs).
\[ zw_t = \mathbf{C}_{\text{wday}} \cdot \beta_{\text{wday}}, \quad \beta_{\text{wday}} \sim \mathcal{N}(0, 1) \]
\[ zy_t = \mathbf{B}_{\text{yday}} \cdot \beta_{\text{yday}}, \quad \beta_{\text{yday}} \sim \mathcal{N}(0, 1) \]
\[ ze_t = \text{log\_price\_centered}_t \cdot e, \quad -e \sim \mathcal{N}^{+}(0, 10) \]
\[ z_t = zw_t + zy_t + ze_t \]
Observed sales are assumed to follow a Poisson distribution, ensuring discrete, non-negative observations:
\[ S_t \sim \text{Poisson}(\lambda_t) \]
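To make the generative story concrete, the following sketch simulates sales from these equations with plain JAX; all parameter values here are invented for illustration and are not the model's priors.

# Illustrative forward simulation of the generative model (made-up values).
key = random.PRNGKey(1)
key, k_rw, k_pois = random.split(key, 3)
n_days = 100

mu_tau, sigma_tau = 3.0, 0.05                                 # baseline, RW scale
tau = jnp.cumsum(random.normal(k_rw, (n_days,)) * sigma_tau)  # random walk
z = 0.2 * jnp.sin(2 * jnp.pi * jnp.arange(n_days) / 7)        # stand-in covariate effect
lam = jnp.exp(z + mu_tau + tau)                               # rate on the natural scale
sim_sales = random.poisson(k_pois, lam)                       # Poisson observations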
def model_local_level_poisson(sales: jnp.array, log_price_centered: jnp.array, wday: jnp.array, yday_fraction: jnp.array,
                              contrasts_sdif_t: jnp.array, contrasts_wday: jnp.array, contrasts_yday: jnp.array,
                              downsampling_factor=1):
    """
    Local-level Poisson sales model: a (possibly downsampled) random walk
    plus day-of-the-week, day-of-the-year, and price-elasticity effects.
    """
    n_obs = len(sales)
    n_states = contrasts_sdif_t.shape[0]

    def sample_random_walk(contrasts_sdif_t, n_states):
        log_sigma = numpyro.sample("log_sigma", dist.Gumbel(0, 5))
        sigma = numpyro.deterministic("sigma", jnp.exp(log_sigma))
        log_state_mean = numpyro.sample("log_state_mean", dist.Normal(0, 5))  # to-do: add an average drift term, as well as potentially an additional parameter governing drift dynamics
        log_state_delta = numpyro.sample("log_state_delta", dist.Normal(0, 1), sample_shape=(n_states-1,))
        log_state_base = numpyro.deterministic("log_state_base", jnp.dot(contrasts_sdif_t, log_state_delta) * sigma + log_state_mean)
        return log_state_base

    def sample_downsampled_random_walk(contrasts_sdif_t, n_obs, n_states):
        log_state_base_downsampled = sample_random_walk(contrasts_sdif_t, n_states)
        log_state_base = interpolate(log_state_base_downsampled, downsampling_factor, n_obs)
        return log_state_base

    def sample_wday_effect(contrasts_wday, wday):
        # Prior for day-of-the-week effects (6 coefficients)
        wday_coefficients = numpyro.sample("wday_coefficients", dist.Normal(0, 1), sample_shape=(6,))

        # Compute wday effect per observation (sum-to-zero constraint applied via contrasts)
        wday_effects = jnp.dot(contrasts_wday, wday_coefficients)
        return jnp.array(wday_effects[wday-1])

    def sample_yday_effect(contrasts_yday, yday_fraction):
        # Prior for yearly seasonality effects (12 coefficients)
        yday_coefficients = numpyro.sample("yday_coefficients", dist.Normal(0, 1), sample_shape=(12,))
        return jnp.dot(contrasts_yday, yday_coefficients)

    def sample_price_effect(log_price_centered):
        # Prior for price elasticity (constrained to be negative)
        elasticity_pos = numpyro.sample("elasticity_pos", dist.HalfNormal(10))
        elasticity = numpyro.deterministic("elasticity", -1 * elasticity_pos)
        return log_price_centered * elasticity

    # Sample random walk
    if n_obs == n_states:
        log_state_base = sample_random_walk(contrasts_sdif_t, n_states)
    else:
        log_state_base = sample_downsampled_random_walk(contrasts_sdif_t, n_obs, n_states)

    # Sample day-of-the-week effects
    wday_effect = sample_wday_effect(contrasts_wday, wday)

    # Sample day-of-the-year effects
    yday_effect = sample_yday_effect(contrasts_yday, yday_fraction)

    # Sample elasticity effect
    price_effect = sample_price_effect(log_price_centered)

    # Compute state
    state = numpyro.deterministic("state", jnp.exp(log_state_base + yday_effect + wday_effect + price_effect))

    # Compute log-likelihood for poisson emissions
    numpyro.sample("sales", dist.Poisson(rate=state), obs=sales)  # to-do: create a Poisson distribution parameterized by log-rate, as in the Stan manual
We use the run_nuts
function to fit the model to our sales data. The function leverages the No-U-Turn Sampler (NUTS) from the numpyro
library to perform MCMC sampling. Because the model has a large number of latent parameters, initialization to sensible start values is key.
In order to fit the model, the functions below are used:
prepare_model_arguments
: Transforms the data into a format required by the model, including sales data, log-transformed prices, day-of-the-week values, and normalized day-of-the-year values.
init_values
: Finds sensible start values for the model parameters which, in this case, is crucial for the convergence of the MCMC algorithm.
run_nuts
: Given a dataset, it calls the NUTS sampler to perform MCMC sampling.
def init_values(sales: jnp.array, log_price_centered: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor=1):
    """
    Computes sensible starting values for the model parameters from the data.
    """
    epsilon = 0.001
    log_state_est = jnp.log(sales + epsilon)
    log_state_mean_est = jnp.mean(log_state_est)
    log_state_delta_est = jnp.diff(log_state_est)
    if downsampling_factor > 1:
        log_state_delta_est = chunked_sum(log_state_delta_est, downsampling_factor)

    log_state_delta_sd_est = jnp.std(log_state_delta_est)

    return {
        "log_sigma": jnp.log(log_state_delta_sd_est),
        "log_state_mean": log_state_mean_est,
        "log_state_delta": log_state_delta_est,
        "wday_coefficients": jnp.array([0.0]*6),
        "yday_coefficients": jnp.array([0.0]*12),
        "log_elasticity": jnp.array([0.0])
    }
def prepare_model_arguments(sales: jnp.array, log_price: jnp.array, wday: jnp.array, yday_fraction: jnp.array, downsampling_factor=1):
    """
    Prepares the arguments required for the model.

    Args:
        sales: Array of sales data.
        log_price: Array of log-transformed prices.
        wday: Array of day-of-the-week values.
        yday_fraction: Array of normalized day-of-the-year values.
        downsampling_factor: Factor by which to downsample the data.

    Returns:
        A tuple containing initialized values for the model parameters and the model arguments.
    """
    n_obs = len(sales)

    # Determine the number of states based on the downsampling factor
    if downsampling_factor == 1:
        n_states = n_obs
    else:
        n_states = int(np.ceil(n_obs / downsampling_factor) + 1)

    # Define contrast matrix for random walk (T coefficients, sum-to-zero constraint)
    contrasts_sdif_t = patsy.contrasts.Diff().code_without_intercept(range(0, n_states)).matrix

    # Define contrast matrix for day-of-the-week effects (7 days → 6 contrasts, sum-to-zero constraint)
    contrasts_wday = patsy.contrasts.Diff().code_without_intercept(range(0, 7)).matrix

    # Compute yday basis per observation (sum-to-zero constraint applied via centering)
    contrasts_yday = compute_doy_basis(yday_fraction, sigma=30/365.25, n_centers=12)  # to-do: do a very approximate calibration of the RBF width parameter sigma, using something like a spline for the long term trend + RBF seasonality

    # Compute centered log prices
    log_price_centered = log_price - jnp.mean(log_price)

    # Set up the model arguments
    model_arguments = {
        'sales': sales,
        'log_price_centered': log_price_centered,
        'wday': jnp.array(wday, dtype=int),
        'yday_fraction': yday_fraction,
        'downsampling_factor': downsampling_factor,
        'contrasts_sdif_t': contrasts_sdif_t,
        'contrasts_wday': contrasts_wday,
        'contrasts_yday': contrasts_yday
    }

    # Prepare initial values for parameters
    init_params = init_values(sales, log_price_centered, wday, yday_fraction, downsampling_factor)

    return init_params, model_arguments
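As an aside, the sum-to-zero property mentioned in the comments above can be checked directly on the patsy backward-difference ("Diff") contrasts; this small demo is illustrative only.

# Illustrative: Diff contrasts for the 7 weekdays give a (7, 6) matrix
# whose columns each sum to zero, yielding the sum-to-zero constraint.
c = patsy.contrasts.Diff().code_without_intercept(list(range(7))).matrix
print(c.shape)                      # (7, 6)
print(np.abs(c.sum(axis=0)).max())  # ~0: columns are centered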
def run_nuts(sales: jnp.array, log_price: jnp.array, wday, yday_fraction: jnp.array, downsampling_factor=1, n_chains=1, num_warmup=1_000, num_samples=1_000, step_size=0.01, max_tree_depth=8):
    """
    Runs NUTS MCMC inference on the model.
    """
    # Initialize random number generator key
    rng_key = random.PRNGKey(0)

    # Get the number of observations
    n_obs = len(sales)

    # Prepare model arguments and initial parameter values
    init_params, model_arguments = prepare_model_arguments(sales=sales, log_price=log_price, wday=wday, yday_fraction=yday_fraction, downsampling_factor=downsampling_factor)

    # Split the random number generator key
    rng_key, rng_key_ = random.split(rng_key)

    # Set the number of chains for parallel sampling
    numpyro.set_host_device_count(n_chains)

    # Define the model to be used
    reparam_model = model_local_level_poisson

    # Initialize the NUTS kernel with the specified step size and tree depth
    kernel = NUTS(reparam_model, step_size=step_size, max_tree_depth=max_tree_depth)

    # Initialize the MCMC sampler with the NUTS kernel
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=n_chains)

    # Run the MCMC sampler
    mcmc.run(rng_key_, **model_arguments)  # init values disabled: init_params=init_params

    # Return the fitted MCMC object and the model arguments
    return mcmc, model_arguments
We fit the model to the synthetic data by calling the run_nuts
function. The model is fitted using the No-U-Turn Sampler (NUTS) from the numpyro
library, with 4 chains, 1,000 warmup iterations, and 1,000 sampling iterations. The step size is set to 0.01, and the maximum tree depth is 8. The fitted model is stored in the m_fit
variable.

# read in the synthetic sales data
data = read_data("sales_synthetic.csv")

# read in the parameters used to simulate the data
with open("sim_parameters.pkl", "rb") as f:
    sim_parameters = date_range, growth, growth_plus_rw, scale_factor, wday, weekly_seasonality, yearly_seasonality = pickle.load(f)

# Fit the model using NumPyro NUTS MCMC
m_fit, model_arguments = run_nuts(data['sales'], data['log_price'], data['wday'], data['yday_fraction'],
                                  downsampling_factor=7, n_chains=4, num_warmup=1_000, num_samples=1_000,
                                  step_size=0.01, max_tree_depth=8)
All effective sample sizes are decent, which is a good sign. The Gelman-Rubin statistics are close to 1, indicating convergence. Inspection of trace plots is beyond the scope of this notebook.
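For reference, one way to pull these diagnostics (a sketch; other parameter names can be substituted) is via the ess_bulk, ess_tail, and r_hat columns of the ArviZ summary:

# Illustrative diagnostics check using the ArviZ summary columns.
diag = az.summary(m_fit, var_names=["sigma", "elasticity"], filter_vars="like")
print(diag[["mean", "ess_bulk", "ess_tail", "r_hat"]])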
The model successfully reconstructs all key features of the synthetic dataset.
# Let's look at the estimated random walk component of the model.
summary = az.summary(m_fit, var_names=["sigma", "log_state_delta"], filter_vars="like")
py$summary
# Create a sequence of dates starting at data["date"].min(), with the length of rw_states, in steps of 7 days
rw_states = az.summary(m_fit, var_names=["log_state_base"], filter_vars="like")["mean"].to_numpy()
dates = pd.date_range(start=data["date"].min(), periods=len(rw_states), freq='7D')
#p = plot_function(dates, np.exp(rw_states), "Estimated Random Walk Component", "Date", "Sales") # to-do: add uncertainty bands

df_1 = pl.DataFrame({'x': date_range, 'y': growth_plus_rw*scale_factor, 'var': 'Simulated Trend + Random Walk'})
df_2 = pl.DataFrame({'x': dates, 'y': np.exp(rw_states), 'var': 'Estimated Trend + Random Walk'})
df = pl.concat([df_1, df_2])

p = ggplot(df, aes("x", "y", color="var")) + geom_line() + theme_bw() + theme(legend_position='top') + guides(color=guide_legend(title=""))
_ = p.draw(show=True)
coefs_wday = az.summary(m_fit, var_names=["wday_coefficients"], filter_vars="like")
py$coefs_wday
wday_effect = jnp.dot(model_arguments["contrasts_wday"], jnp.array(coefs_wday["mean"]))
p = plot_function(range(0,7), wday_effect, "Effect of Day of the Week", "Day of the Week", "Sales") # to-do: add uncertainty bands

df_1 = pl.DataFrame({'x': range(0,7), 'y': weekly_seasonality - np.mean(weekly_seasonality), 'var': 'Simulated Weekly Seasonality'})
df_2 = pl.DataFrame({'x': range(0,7), 'y': wday_effect.tolist(), 'var': 'Estimated Weekly Seasonality'})
df = pl.concat([df_1, df_2])

p = ggplot(df, aes("x", "y", color="var")) + geom_line() + theme_bw() + theme(legend_position='top') + guides(color=guide_legend(title=""))
_ = p.draw(show=True)
coefs_yday = az.summary(m_fit, var_names=["yday_coefficients"], filter_vars="like")
py$coefs_yday
yday_effect = jnp.dot(model_arguments["contrasts_yday"], jnp.array(coefs_yday["mean"]))
#p = plot_function(data["date"], yday_effect, "Yearly Seasonality", "Date", "Sales") # to-do: add uncertainty bands

df_1 = pl.DataFrame({'x': date_range, 'y': yearly_seasonality - np.mean(yearly_seasonality), 'var': 'Simulated Yearly Seasonality'}).with_columns(
    pl.col("x").cast(pl.Date)
)
df_2 = pl.DataFrame({'x': data["date"].tolist(), 'y': yday_effect.tolist(), 'var': 'Estimated Yearly Seasonality'})
df = pl.concat([df_1, df_2])

p = ggplot(df, aes("x", "y", color="var")) + geom_line() + theme_bw() + theme(legend_position='top') + guides(color=guide_legend(title=""))
_ = p.draw(show=True)
summary = az.summary(m_fit, var_names=["elasticity"])
py$summary
This case study demonstrates how Bayesian modeling can effectively decompose sales variance into meaningful components, providing a structured way to analyze the underlying factors driving sales fluctuations. By applying this approach to a synthetic dataset, we validated the model’s ability to separate out long-term growth, seasonal effects, and price sensitivity while simultaneously quantifying uncertainty.
The key takeaways include:
Decomposing complexity: The model successfully isolates different components influencing sales, making it easier to interpret real-world dynamics.
Quantifying uncertainty: In addition to point estimates, Bayesian inference provides full posterior distributions, enabling better risk assessment.
Informed decision-making: By accounting for all sources of variance, businesses can make more confident strategic decisions that explicitly consider uncertainty. For instance, price optimization can be performed on the entire posterior distribution of the estimate of the price elasticity of demand.
These findings highlight the advantages of probabilistic modeling in sales analysis, offering a flexible and interpretable method.
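As an illustration of the price-optimization point above, here is a minimal, purely hypothetical sketch: the constant-elasticity demand curve, base demand, and unit cost are invented, and only the elasticity draws come from the fitted posterior.

# Hypothetical sketch: pick the price maximizing expected profit over the
# posterior of the elasticity. Demand curve, base values, and costs are
# invented for illustration.
elasticity_draws = -np.asarray(m_fit.get_samples()["elasticity_pos"])
base_price, base_demand, unit_cost = 10.0, 100.0, 4.0
prices = np.linspace(6.0, 18.0, 100)

# Constant-elasticity demand per posterior draw: q = q0 * (p / p0) ** e
demand = base_demand * (prices[None, :] / base_price) ** elasticity_draws[:, None]
profit = (prices[None, :] - unit_cost) * demand
best_price = prices[np.argmax(profit.mean(axis=0))]  # maximizes expected profit
print(best_price)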