Skip to content

Commit

Permalink
Adds HTTPSource and a lot of tests (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
btamadio authored Feb 25, 2021
1 parent 254fa80 commit 541aa5d
Show file tree
Hide file tree
Showing 24 changed files with 1,492 additions and 499 deletions.
58 changes: 38 additions & 20 deletions bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,33 @@ import (
"context"
)

// Bandit gets reward values from a RewardSource, computes selection probabilities using a Strategy, and selects
// A Bandit gets reward values from a RewardSource, computes selection probabilities using a Strategy, and selects
// an arm using a Sampler.
type Bandit struct {
	// RewardSource supplies the current per-arm reward estimates through GetRewards.
	RewardSource
	// Strategy turns the reward estimates into arm selection probabilities through ComputeProbs.
	Strategy
	// Sampler picks an arm index from the probabilities and a unit string through Sample.
	Sampler
}

type Result struct {
Rewards []Dist
Probs []float64
Arm int
}

// SelectArm gets the current reward estimates, computes the arm selection probability, and selects an arm index.
func (b *Bandit) SelectArm(ctx context.Context, unit string) (Result, error) {
// SelectArm gets the current reward estimates, computes the arm selection probabilities, and selects an arm index.
// Returns a partial result and an error message if an error is encountered at any point.
// For example, if the reward estimates were retrieved, but an error was encountered during the probability computation,
// the result will contain the reward estimates, but no probabilities or arm index.
// There is an unfortunate name collision between a multi-armed bandit context and Go's context.Context type.
// The context.Context argument should only be used for passing request-scoped data to an external reward service, such
// as timeouts and cancellation propagation.
// The banditContext argument is used to pass bandit context features to the reward source for contextual bandits.
// The unit argument is a string that will be hashed to select an arm with the pseudo-random sampler.
// SelectArm is deterministic for a fixed unit and set of reward estimates from the RewardSource.
func (b *Bandit) SelectArm(ctx context.Context, unit string, banditContext interface{}) (Result, error) {

res := Result{
Rewards: make([]Dist, 0),
Probs: make([]float64, 0),
Arm: -1,
}

rewards, err := b.GetRewards(ctx)
rewards, err := b.GetRewards(ctx, banditContext)
if err != nil {
return res, err
}
Expand All @@ -51,14 +54,19 @@ func (b *Bandit) SelectArm(ctx context.Context, unit string) (Result, error) {
return res, nil
}

// RewardSource provides the current reward estimates, as a Dist for each arm.
// Features can be passed to the RewardSource using the Context argument, which is useful for contextual bandits.
// The RewardSource should provide the reward estimates conditioned on those context features.
type RewardSource interface {
GetRewards(context.Context) ([]Dist, error)
// Result is the return type for a call to Bandit.SelectArm.
// It will contain the reward estimates provided by the RewardSource, the computed arm selection probabilities,
// and the index of the selected arm.
type Result struct {
Rewards []Dist
Probs []float64
Arm int
}

// Dist represents a one-dimensional probability distribution.
// A Dist represents a one-dimensional probability distribution.
// Reward estimates are represented as a Dist for each arm.
// Strategies compute arm-selection probabilities using the Dist interface.
// This allows for combining different distributions with different strategies.
type Dist interface {
// CDF returns the cumulative distribution function evaluated at x.
CDF(x float64) float64
Expand All @@ -76,14 +84,24 @@ type Dist interface {
Support() (float64, float64)
}

// Strategy computes arm selection probabilities from a slice of Distributions.
// The output probabilities slice should be the same length as the input Dist slice.
// A RewardSource provides the current reward estimates, in the form of a Dist for each arm.
// There is an unfortunate name collision between a multi-armed bandit context and Go's Context type.
// The first argument is a context.Context and should only be used for passing request-scoped data to an external reward service.
// If the RewardSource does not require an external request, this first argument should always be context.Background().
// The second argument is used to pass context values to the reward source for contextual bandits.
// A RewardSource implementation should provide the reward estimates conditioned on the value of banditContext.
// For non-contextual bandits, banditContext can be nil.
type RewardSource interface {
	// GetRewards returns the reward estimates as one Dist per arm, or an error.
	// ctx carries request-scoped data for an external reward service; banditContext carries the
	// contextual-bandit features and can be nil for non-contextual bandits.
	GetRewards(ctx context.Context, banditContext interface{}) ([]Dist, error)
}

// A Strategy computes arm selection probabilities from a slice of Distributions.
type Strategy interface {
	// ComputeProbs maps the per-arm reward estimates to arm selection probabilities.
	// NOTE(review): presumably the result holds one probability per input Dist, in the same order — confirm.
	ComputeProbs([]Dist) ([]float64, error)
}

// Sampler returns a pseudo-random arm index given a set of probabilities and a unit.
// Samplers should always return the same arm index for the same set of probabilities and unit.
// A Sampler returns a pseudo-random arm index given a set of probabilities and a string to hash.
// Samplers should always return the same arm index for the same set of probabilities and unit value.
type Sampler interface {
	// Sample returns the selected arm index for the given probabilities, using unit as the string to hash.
	Sample(probs []float64, unit string) (int, error)
}
2 changes: 1 addition & 1 deletion bandit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func ExampleBandit_SelectArm() {
Sampler: NewSha1Sampler(),
}

result, err := b.SelectArm(context.Background(), "12345")
result, err := b.SelectArm(context.Background(), "12345", nil)
if err != nil {
panic(err)
}
Expand Down
31 changes: 22 additions & 9 deletions dists.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"gonum.org/v1/gonum/stat/distuv"
)

// Normal builds a NormalDist with mean mu and standard deviation sigma, for use with any bandit strategy.
// For the purposes of Thompson sampling, it is truncated at mean +/- 4*sigma.
func Normal(mu, sigma float64) NormalDist {
	base := distuv.Normal{Mu: mu, Sigma: sigma}
	return NormalDist{base}
}
Expand All @@ -24,6 +26,7 @@ func (n NormalDist) String() string {
return fmt.Sprintf("Normal(%f,%f)", n.Mu, n.Sigma)
}

// Beta builds a BetaDist with parameters alpha and beta, for use with any bandit strategy.
func Beta(alpha, beta float64) BetaDist {
	base := distuv.Beta{Alpha: alpha, Beta: beta}
	return BetaDist{base}
}
Expand All @@ -40,40 +43,50 @@ func (b BetaDist) String() string {
return fmt.Sprintf("Beta(%f,%f)", b.Beta.Alpha, b.Beta.Beta)
}

func Point(x float64) PointDist {
return PointDist{x}
// Point is used for reward models that just provide point estimates. Don't use with Thompson sampling.
func Point(mu float64) PointDist {
return PointDist{mu}
}

type PointDist struct {
X float64
Mu float64
}

func (p PointDist) Mean() float64 {
return p.X
return p.Mu
}

func (p PointDist) CDF(x float64) float64 {
if x >= p.X {
if x >= p.Mu {
return 1
}
return 0
}

func (p PointDist) Prob(x float64) float64 {
if x == p.X {
if x == p.Mu {
return math.NaN()
}
return 0
}

func (p PointDist) Rand() float64 {
return p.X
return p.Mu
}

func (p PointDist) Support() (float64, float64) {
return p.X, p.X
return p.Mu, p.Mu
}

func (p PointDist) String() string {
return fmt.Sprintf("Point(%f)", p.X)
if math.IsInf(p.Mu, -1) {
return "Null()"
}
return fmt.Sprintf("Point(%f)", p.Mu)
}

// Null returns a PointDist whose mean is negative infinity. That value is a sentinel telling
// a Strategy that the arm should receive a selection probability of zero.
func Null() PointDist {
	negInf := math.Inf(-1)
	return PointDist{negInf}
}
42 changes: 39 additions & 3 deletions epsilon_greedy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,34 @@ import (
"math"
)

// NewEpsilonGreedy returns a pointer to an EpsilonGreedy strategy with Epsilon set to e.
// The value of e is not validated here.
func NewEpsilonGreedy(e float64) *EpsilonGreedy {
	strategy := new(EpsilonGreedy)
	strategy.Epsilon = e
	return strategy
}

// EpsilonGreedy implements the epsilon-greedy bandit strategy.
// The Epsilon parameter must be between 0 and 1, inclusive.
// If any arm has a Null distribution, it will have zero selection probability, and the other
// arms' probabilities will be computed as if the Null arms are not present.
// Ties are accounted for, so if multiple arms have the maximum mean reward estimate, they will have equal probabilities.
type EpsilonGreedy struct {
	// Epsilon is the exploration parameter; values outside [0, 1] are rejected with an error.
	Epsilon float64
	// meanRewards holds the mean of each arm's reward Dist and is overwritten by every ComputeProbs call.
	// NOTE(review): because ComputeProbs writes this field, concurrent calls on a shared
	// *EpsilonGreedy look racy — confirm whether that usage is expected.
	meanRewards []float64
}

// ComputeProbs computes the arm selection probabilities from the set of reward estimates, accounting for Nulls and ties.
// Returns an error if Epsilon is less than zero or greater than one.
func (e *EpsilonGreedy) ComputeProbs(rewards []Dist) ([]float64, error) {

if err := e.validateEpsilon(); err != nil {
return nil, err
}

if len(rewards) == 0 {
return []float64{}, nil
}

e.meanRewards = make([]float64, len(rewards))
for i, dist := range rewards {
e.meanRewards[i] = dist.Mean()
Expand All @@ -26,23 +43,42 @@ func (e *EpsilonGreedy) ComputeProbs(rewards []Dist) ([]float64, error) {
}

func (e EpsilonGreedy) computeProbs() []float64 {

probs := make([]float64, len(e.meanRewards))

nonNullArms := e.numNonNullArms()
if nonNullArms == 0 {
return probs
}

maxRewardArmIndices := argsMax(e.meanRewards)
numMaxima := len(maxRewardArmIndices)
numArms := len(e.meanRewards)

for i := range e.meanRewards {
if isIn(maxRewardArmIndices, i) {
probs[i] = (1-e.Epsilon)/float64(numMaxima) + e.Epsilon/float64(numArms)
probs[i] = (1-e.Epsilon)/float64(numMaxima) + e.Epsilon/float64(nonNullArms)
} else {
probs[i] = e.Epsilon / float64(len(e.meanRewards))
if math.IsInf(e.meanRewards[i], -1) {
probs[i] = 0
} else {
probs[i] = e.Epsilon / float64(nonNullArms)
}
}
}

return probs
}

// numNonNullArms counts the arms whose mean reward is strictly greater than negative infinity,
// i.e. the arms that do not carry the Null sentinel. A NaN mean is not counted, because the
// comparison against negative infinity is false for NaN.
func (e EpsilonGreedy) numNonNullArms() int {
	negInf := math.Inf(-1)
	nonNull := 0
	for i := 0; i < len(e.meanRewards); i++ {
		if e.meanRewards[i] > negInf {
			nonNull++
		}
	}
	return nonNull
}

func (e EpsilonGreedy) validateEpsilon() error {
if e.Epsilon < 0 || e.Epsilon > 1 {
return fmt.Errorf("invalid Epsilon value: %v. Must be between 0 and 1", e.Epsilon)
Expand Down
9 changes: 4 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ module github.com/stitchfix/mab
go 1.14

require (
cloud.google.com/go/datastore v1.4.0 // indirect
github.com/gomodule/redigo v1.8.3 // indirect
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad // indirect
golang.org/x/net v0.0.0-20210119194325-5f4716e94777 // indirect
golang.org/x/tools v0.1.0 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/stretchr/testify v1.5.1
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 // indirect
gonum.org/v1/gonum v0.8.2
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
)
Loading

0 comments on commit 541aa5d

Please sign in to comment.