# Source code for bandit.reward

"""
Classes for the environment and the reward model.
"""

from typing import Callable, List, Union

import numpy as np
import scipy.stats as ss

from abc import ABC, abstractmethod


class BaseReward(ABC):
    """
    Base class for rewards backed by a random-variable distribution.

    Subclasses must override `get_reward` and `expected_reward`. Both have
    default implementations here that subclasses can reuse via `super()`.

    Args:
        dist (Callable): a random variable distribution object that has an
            `rvs` method (used to draw a reward) and a `stats` method (used
            to query moments). The object itself is never called; only these
            two methods are used.

    Raises:
        AssertionError: if `dist` has no `rvs` or no `stats` attribute.
    """

    def __init__(self, dist: Callable):
        # Raise explicitly rather than using `assert`: assert statements are
        # removed when Python runs with -O, which would silently disable this
        # validation. The exception type and messages are kept unchanged.
        if not hasattr(dist, "rvs"):
            raise AssertionError("distribution must have rvs() method")
        if not hasattr(dist, "stats"):
            raise AssertionError("distribution must have a stats method")
        self.dist = dist

    @abstractmethod
    def get_reward(self) -> Union[float, int]:
        """Draw a single reward by calling `dist.rvs()` with no arguments."""
        return self.dist.rvs()

    @abstractmethod
    def expected_reward(self) -> Union[float, int]:
        """
        Return the result of `dist.stats("m")`.

        For scipy.stats distributions this is the mean; the value is returned
        as-is, without conversion to a built-in number type.
        """
        return self.dist.stats("m")

    def moments(self, kind: str = "mv") -> List[float]:
        """
        Return the result of `dist.stats(kind)`.

        Args:
            kind (str): moment selector forwarded unchanged to `dist.stats`
                (default "mv"; for scipy.stats distributions that is mean
                and variance).

        Returns:
            Whatever `dist.stats(kind)` returns, unmodified.
            NOTE(review): the container type depends on the distribution
            object and may not be a list -- confirm against callers.
        """
        return self.dist.stats(kind)
class GaussianReward(BaseReward):
    """
    A Gaussian random variable as a reward.

    Args:
        mean (float): mean of the Gaussian reward
        var (float): variance of the Gaussian reward; must be positive

    Raises:
        AssertionError: if `var` is not strictly positive.
    """

    def __init__(self, mean: float = 0, var: float = 1):
        # Raise explicitly rather than using `assert`, which is removed under
        # `python -O`. Written as `not var > 0` (instead of `var <= 0`) so a
        # NaN variance is rejected too, exactly as the assert did.
        if not var > 0:
            raise AssertionError("variance must be positive")
        # scipy's normal distribution takes the standard deviation as `scale`,
        # hence the square root of the variance.
        super().__init__(ss.norm(loc=mean, scale=np.sqrt(var)))

    def get_reward(self) -> float:
        """Draw one sample from the Gaussian via the base implementation."""
        return super().get_reward()

    def expected_reward(self) -> float:
        """Return the distribution mean via the base implementation."""
        return super().expected_reward()
class PoissonReward(BaseReward):
    """
    Poisson random variable reward.

    Args:
        mu (float): rate parameter (mean and var); must be positive
        loc (float): constant shift

    Raises:
        AssertionError: if `mu` is not strictly positive.
    """

    def __init__(self, mu: float = 1, loc: float = 0):
        # Raise explicitly rather than using `assert`, which is removed under
        # `python -O`. Written as `not mu > 0` (instead of `mu <= 0`) so a
        # NaN rate is rejected too, exactly as the assert did.
        if not mu > 0:
            raise AssertionError("poisson rate must be positive")
        super().__init__(ss.poisson(mu=mu, loc=loc))

    def get_reward(self) -> int:
        """Draw one sample from the Poisson via the base implementation."""
        return super().get_reward()

    def expected_reward(self) -> float:
        """Return the distribution mean via the base implementation."""
        return super().expected_reward()