diff --git a/river/bandit/__init__.py b/river/bandit/__init__.py
index 70b0a98a72..f00db0cbfb 100644
--- a/river/bandit/__init__.py
+++ b/river/bandit/__init__.py
@@ -13,6 +13,7 @@
 from .epsilon_greedy import EpsilonGreedy
 from .evaluate import evaluate, evaluate_offline
 from .exp3 import Exp3
+from .kl_ucb import KLUCB
 from .lin_ucb import LinUCBDisjoint
 from .random import RandomPolicy
 from .thompson import ThompsonSampling
@@ -31,4 +32,5 @@
     "ThompsonSampling",
     "UCB",
     "RandomPolicy",
+    "KLUCB",
 ]
diff --git a/river/bandit/kl_ucb.py b/river/bandit/kl_ucb.py
new file mode 100644
index 0000000000..4fff36bfff
--- /dev/null
+++ b/river/bandit/kl_ucb.py
@@ -0,0 +1,196 @@
from __future__ import annotations

import math
import random


class KLUCB:
    """KL-UCB is an algorithm for solving the multi-armed bandit problem. It uses the
    Kullback-Leibler (KL) divergence to calculate an upper confidence bound (UCB) for each arm.
    The algorithm aims to balance exploration (trying different arms) and exploitation
    (selecting the best-performing arm) in a principled way.

    Parameters
    ----------
    n_arms (int):
        The total number of arms available for selection.
    horizon (int):
        The total number of time steps or trials during which the algorithm will run.
    c (float, default=0):
        A scaling parameter for the confidence bound. Larger values promote exploration,
        while smaller values favor exploitation.

    Attributes
    ----------
    arm_count (list[int]):
        A list where each element tracks the number of times an arm has been selected.
    rewards (list[float]):
        A list where each element accumulates the total reward received from pulling each arm.
    t (int):
        The current time step in the algorithm.

    Methods
    -------
    update(arm, reward):
        Updates the statistics for the selected arm based on the observed reward.

    kl_divergence(p, q):
        Computes the Kullback-Leibler (KL) divergence between probabilities `p` and `q`.
        This measures how one probability distribution differs from another.

    kl_index(arm):
        Calculates the KL-UCB index for a specific arm, using binary search to determine the upper bound.

    pull_arm(arm):
        Simulates pulling an arm by generating a reward based on the empirical mean reward for that arm.

    Examples
    --------

    >>> import random
    >>> from river.bandit import KLUCB

    >>> n_arms = 3
    >>> horizon = 100
    >>> c = 1
    >>> klucb = KLUCB(n_arms=n_arms, horizon=horizon, c=c)

    >>> random.seed(42)

    >>> def calculate_reward(arm):
    ...     # Example: Bernoulli reward based on the true probability (for testing)
    ...     true_probabilities = [0.3, 0.5, 0.7]  # Example probabilities for each arm
    ...     return 1 if random.random() < true_probabilities[arm] else 0

    >>> # Initialize tracking variables
    >>> selected_arms = []
    >>> total_reward = 0
    >>> cumulative_rewards = []

    >>> for t in range(1, horizon + 1):
    ...     klucb.t = t
    ...     indices = [klucb.kl_index(arm) for arm in range(n_arms)]
    ...     chosen_arm = indices.index(max(indices))
    ...     reward = calculate_reward(chosen_arm)
    ...     klucb.update(chosen_arm, reward)
    ...     selected_arms.append(chosen_arm)
    ...     total_reward += reward
    ...
cumulative_rewards.append(total_reward) + + + >>> print("Selected arms:", selected_arms) + Selected arms: [0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] + + + + >>> print("Cumulative rewards:", cumulative_rewards) + Cumulative rewards: [0, 1, 2, 3, 3, 3, 3, 4, 5, 6, 7, 7, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 11, 12, 12, 13, 14, 15, 15, 16, 16, 16, 17, 17, 18, 19, 19, 20, 20, 20, 20, 21, 22, 23, 24, 25, 26, 27, 27, 28, 29, 30, 31, 31, 31, 31, 32, 32, 33, 34, 34, 34, 34, 35, 35, 35, 36, 37, 38, 39, 40, 40, 40, 41, 41, 42, 42, 42, 43, 44, 44, 45, 45, 45, 46, 47, 47, 48, 49, 50, 51, 52, 52, 53, 54, 55, 55, 56, 56, 56] + + + + >>> print(f"Total Reward: {total_reward}") + Total Reward: 56 + + """ + + def __init__(self, n_arms, horizon, c=0): + self.n_arms = n_arms + self.horizon = horizon + self.c = c + self.arm_count = [1 for _ in range(n_arms)] + self.rewards = [0.0 for _ in range(n_arms)] + self.t = 0 + + def update(self, arm, reward): + """ + Updates the number of times the arm has been pulled and the cumulative reward + for the given arm. Also increments the current time step. + + Parameters + ---------- + arm (int): The index of the arm that was pulled. + reward (float): The reward obtained from pulling the arm. + """ + self.arm_count[arm] += 1 + self.rewards[arm] += reward + self.t += 1 + + def kl_divergence(self, p, q): + """ + Computes the Kullback-Leibler (KL) divergence between two probabilities `p` and `q`. + + Parameters + ---------- + p (float): The first probability (true distribution). + q (float): The second probability (approximated distribution). + + Returns + ------- + float: The KL divergence value. Returns infinity if `q` is not a valid probability. + """ + + if p == 0: + return float("inf") if q >= 1 else -math.log(1 - q) + elif p == 1: + return float("inf") if q <= 0 else -math.log(q) + elif q <= 0 or q >= 1: + return float("inf") + return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q)) + + def kl_index(self, arm): + """ + Computes the KL-UCB index for a given arm using binary search. + This determines the upper confidence bound for the arm. + + Parameters + ---------- + arm (int): The index of the arm to compute the index for. + + Returns + ------- + float: The KL-UCB index for the arm. + """ + + n_t = self.arm_count[arm] + if n_t == 0: + return float("inf") # Unseen arm + empirical_mean = self.rewards[arm] / n_t + log_t_over_n = math.log(self.t + 1) / n_t + c_factor = self.c * log_t_over_n + + # Binary search to find the q that satisfies the KL-UCB condition + low = empirical_mean + high = 1.0 + for _ in range(100): # Fixed number of iterations for binary search + mid = (low + high) / 2 + kl = self.kl_divergence(empirical_mean, mid) + if kl > c_factor: + high = mid + else: + low = mid + return low + + def pull_arm(self, arm): + """ + Simulates pulling an arm by generating a reward based on its empirical mean. + + Parameters + ---------- + arm (int): The index of the arm to pull. + + Returns + ------- + int: 1 if the arm yields a reward, 0 otherwise. + """ + prob = self.rewards[arm] / self.arm_count[arm] + return 1 if random.random() < prob else 0 + + @staticmethod + def _unit_test_params(): + """ + Returns a list of dictionaries with parameters to initialize the KLUCB class + for unit testing. 
        """
        return [
            {"n_arms": 2, "horizon": 100, "c": 0.5},
            {"n_arms": 5, "horizon": 1000, "c": 1.0},
            {"n_arms": 10, "horizon": 500, "c": 0.1},
        ]
diff --git a/river/linear_model/__init__.py b/river/linear_model/__init__.py
index 756720490a..9eaedb1dda 100644
--- a/river/linear_model/__init__.py
+++ b/river/linear_model/__init__.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 from . import base
+from .adpredictor import AdPredictor
 from .alma import ALMAClassifier
 from .bayesian_lin_reg import BayesianLinearRegression
 from .lin_reg import LinearRegression
@@ -21,4 +22,5 @@
     "PARegressor",
     "Perceptron",
     "SoftmaxRegression",
+    "AdPredictor",
 ]
diff --git a/river/linear_model/adpredictor.py b/river/linear_model/adpredictor.py
new file mode 100644
index 0000000000..67811a9529
--- /dev/null
+++ b/river/linear_model/adpredictor.py
@@ -0,0 +1,156 @@
from __future__ import annotations

import collections
import math

from river.base.classifier import Classifier


def default_mean():
    return 0.0


def default_variance():
    return 1.0


class AdPredictor(Classifier):
    """
    AdPredictor is a machine learning algorithm designed to predict the probability of user
    clicks on online advertisements. This algorithm plays a crucial role in computational advertising,
    where predicting click-through rates (CTR) is essential for optimizing ad placements and maximizing revenue.

    Parameters
    ----------
    beta (float, default=0.1):
        A smoothing parameter that regulates the weight updates. Smaller values allow for finer updates,
        while larger values can accelerate convergence but may risk instability.
    prior_probability (float, default=0.5):
        The prior click-through probability. This value sets the bias weight, influencing the model's
        predictions before observing any data.
    epsilon (float, default=0.1):
        A variance dynamics parameter that controls how the model balances prior knowledge and learned information.
        Larger values prioritize prior knowledge, while smaller values favor data-driven updates.
    num_features (int, default=10):
        The maximum number of features the model is expected to handle. It is stored as a hyperparameter
        but does not constrain the current implementation.

    Attributes
    ----------
    means (collections.defaultdict):
        Maps each feature to the current estimate (mean) of its weight.
    variances (collections.defaultdict):
        Maps each feature to the uncertainty (variance) associated with its weight estimate.
    bias_weight (float):
        The weight corresponding to the model bias, initialized using the prior_probability.
        This attribute allows the model to make predictions even when no features are active.

    Examples
    --------

    >>> from river.linear_model import AdPredictor
    >>> adpredictor = AdPredictor(beta=0.1, prior_probability=0.5, epsilon=0.1, num_features=5)

    >>> data = [
    ...     ({"feature1": 1, "feature2": 1}, 1),
    ...     ({"feature1": 1, "feature3": 1}, 0),
    ...     ({"feature2": 1, "feature4": 1}, 1),
    ...     ({"feature1": 1, "feature2": 1, "feature3": 1}, 0),
    ...     ({"feature4": 1, "feature5": 1}, 1),
    ... ]

    >>> def train_and_test(model, data):
    ...     for x, y in data:
    ...         pred_before = model.predict_one(x)
    ...         model.learn_one(x, y)
    ...         pred_after = model.predict_one(x)
    ...
print(f"Features: {x} | True label: {y} | Prediction before training: {pred_before:.4f} | Prediction after training: {pred_after:.4f}") + + >>> train_and_test(adpredictor, data) + Features: {'feature1': 1, 'feature2': 1} | True label: 1 | Prediction before training: 0.5000 | Prediction after training: 0.7230 + Features: {'feature1': 1, 'feature3': 1} | True label: 0 | Prediction before training: 0.6065 | Prediction after training: 0.3650 + Features: {'feature2': 1, 'feature4': 1} | True label: 1 | Prediction before training: 0.6065 | Prediction after training: 0.7761 + Features: {'feature1': 1, 'feature2': 1, 'feature3': 1} | True label: 0 | Prediction before training: 0.5455 | Prediction after training: 0.3197 + Features: {'feature4': 1, 'feature5': 1} | True label: 1 | Prediction before training: 0.5888 | Prediction after training: 0.7699 + + """ + + def __init__(self, beta=0.1, prior_probability=0.5, epsilon=0.1, num_features=10): + # Initialization of model parameters + self.beta = beta + self.prior_probability = prior_probability + self.epsilon = epsilon + self.num_features = num_features + # Initialize weights as a defaultdict for each feature, with mean and variance attributes + + self.means = collections.defaultdict(default_mean) + self.variances = collections.defaultdict(default_variance) + + # Initialize bias weight based on prior probability + self.bias_weight = self.prior_bias_weight() + + def prior_bias_weight(self): + # Calculate initial bias weight using prior probability + + return math.log(self.prior_probability / (1 - self.prior_probability)) / self.beta + + def _active_mean_variance(self, features): + """_active_mean_variance(features) (method): + Computes the cumulative mean and variance for all active features in a sample, + including the bias. 
        This is crucial for making predictions."""
        # Calculate total mean and variance for all active features, plus the bias term

        total_mean = sum(self.means[f] for f in features) + self.bias_weight
        total_variance = sum(self.variances[f] for f in features) + self.beta**2
        return total_mean, total_variance

    def predict_one(self, x):
        # Generate a probability prediction for one sample
        features = x.keys()
        total_mean, total_variance = self._active_mean_variance(features)
        # Logistic sigmoid of the normalised score, used as an approximation of the Gaussian CDF
        return 1 / (1 + math.exp(-total_mean / math.sqrt(total_variance)))

    def learn_one(self, x, y):
        # Online learning step to update the model with one sample
        features = x.keys()
        y = 1 if y else -1
        total_mean, total_variance = self._active_mean_variance(features)
        v, w = self.gaussian_corrections(y * total_mean / math.sqrt(total_variance))

        # Update the mean and variance of each feature in the sample
        for feature in features:
            mean = self.means[feature]
            variance = self.variances[feature]

            mean_delta = y * variance / math.sqrt(total_variance) * v  # Shift of the mean
            variance_multiplier = 1.0 - variance / total_variance * w  # Shrinkage of the variance

            # Update the weight's posterior
            self.means[feature] = mean + mean_delta
            self.variances[feature] = variance * variance_multiplier

    def gaussian_corrections(self, score):
        """gaussian_corrections(score) (method):
        Computes the Bayesian update correction factors from the Gaussian probability density
        function (PDF) and a logistic approximation of the cumulative distribution function (CDF)."""
        # Logistic approximation of the Gaussian CDF
        cdf = 1 / (1 + math.exp(-score))
        pdf = math.exp(-0.5 * score**2) / math.sqrt(2 * math.pi)  # Standard normal PDF
        v = pdf / cdf  # Correction factor for the mean update
        w = v * (v + score)  # Correction factor for the variance update
        return v, w

    def _apply_dynamics(self, weight):
        """_apply_dynamics(weight) (method):
        Regularizes the variance of a feature weight using a combination of the prior variance and the
        learned variance. This helps maintain a balance between prior beliefs and observed data.
        `weight` is expected to be a dict with "mean" and "variance" keys."""
        # Apply variance dynamics for regularization
        prior_variance = 1.0
        # Blend the learned variance with the prior variance, weighted by epsilon
        adjusted_variance = (
            weight["variance"]
            * prior_variance
            / ((1.0 - self.epsilon) * prior_variance + self.epsilon * weight["variance"])
        )
        # Adjust the mean accordingly (the prior mean is 0, so the second term vanishes)
        adjusted_mean = adjusted_variance * (
            (1.0 - self.epsilon) * weight["mean"] / weight["variance"]
            + self.epsilon * 0 / prior_variance
        )
        return {"mean": adjusted_mean, "variance": adjusted_variance}
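Note on `kl_index` in kl_ucb.py: the value it returns is (approximately) the largest mean `q` whose Bernoulli KL divergence from the arm's empirical mean stays within the exploration budget `c * log(t + 1) / n`. Below is a quick standalone check of that property using the class as added in this PR; the arm counts, rewards, and time step are made up purely for illustration.

import math

from river.bandit import KLUCB  # as added by this PR

policy = KLUCB(n_arms=1, horizon=100, c=1.0)
# Pretend the single arm was pulled 10 times for a total reward of 4.
policy.arm_count[0] = 10
policy.rewards[0] = 4.0
policy.t = 100

index = policy.kl_index(0)
budget = policy.c * math.log(policy.t + 1) / policy.arm_count[0]

print(index > 0.4)  # True: the bound sits above the empirical mean of 0.4
print(abs(policy.kl_divergence(0.4, index) - budget) < 1e-6)  # True: the KL constraint is tight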
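Note on `gaussian_corrections` in adpredictor.py: the method approximates the Gaussian CDF with a logistic sigmoid. For reference, the probit-style corrections v(t) = pdf(t) / cdf(t) and w(t) = v(t) * (v(t) + t) can also be computed with the exact standard normal CDF via `math.erf`. The sketch below is only a reference point, not part of the diff; the helper name and the sample scores are illustrative.

import math


def gaussian_corrections_exact(score):
    """v and w computed with the exact standard normal PDF and CDF."""
    pdf = math.exp(-0.5 * score**2) / math.sqrt(2 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(score / math.sqrt(2.0)))
    v = pdf / cdf        # correction factor for the mean update
    w = v * (v + score)  # correction factor for the variance update
    return v, w


# A surprising observation (low score) yields larger corrections than a
# well-predicted one (high score), i.e. it moves the means and shrinks the
# variances more.
print(gaussian_corrections_exact(-0.5))
print(gaussian_corrections_exact(2.0))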