diff --git a/.gitignore b/.gitignore index db01cb35..425d35c5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ go.work /*.zip -tmp \ No newline at end of file +tmp +run.sh diff --git a/internal/meters/amberflo.go b/internal/meters/amberflo.go index c6c63feb..c5848574 100644 --- a/internal/meters/amberflo.go +++ b/internal/meters/amberflo.go @@ -12,7 +12,7 @@ import ( "github.com/xtgo/uuid" ) -type AmberFlo struct { +type Amberflo struct { apikey string interval time.Duration client *metering.Metering @@ -20,7 +20,7 @@ type AmberFlo struct { cfgs map[string]amberFloConfig } -func NewAmberFlo(apikey string, interval time.Duration, batchSize int) *AmberFlo { +func NewAmberflo(apikey string, interval time.Duration, batchSize int) *Amberflo { afLog := &amberfloLogger{logger: log.Logger} meteringClient := metering.NewMeteringClient( apikey, @@ -32,7 +32,7 @@ func NewAmberFlo(apikey string, interval time.Duration, batchSize int) *AmberFlo apikey, metering.WithCustomLogger(afLog), ) - return &AmberFlo{ + return &Amberflo{ apikey: apikey, interval: interval, client: meteringClient, @@ -42,13 +42,13 @@ func NewAmberFlo(apikey string, interval time.Duration, batchSize int) *AmberFlo } type amberFloConfig struct { - Name string `json:"name,omitempty"` - DefaultUser string `json:"default_user,omitempty"` - ExternalIDKey string `json:"external_id_key,omitempty"` - Dimensions map[string]string `json:"dimensions,omitempty"` + Name string `json:"name,omitempty"` + DefaultUser string `json:"default_user,omitempty"` + ExternalIDKey string `json:"external_id_key,omitempty"` + Dimensions Dimensions `json:"dimensions,omitempty"` } -func (m *AmberFlo) LoadConfig(path string) error { +func (m *Amberflo) LoadConfig(path string) error { cfgs := map[string]amberFloConfig{} data, err := ioutil.ReadFile(path) if err != nil { @@ -61,26 +61,24 @@ func (m *AmberFlo) LoadConfig(path string) error { return nil } -func (m *AmberFlo) NewMeter(user MeterUser) ApiMeter { +func (m *Amberflo) NewMeter(user MeterUser) ApiMeter { return &amberFloMeter{ user: user, mp: m, } } -func (m *AmberFlo) Close() error { +func (m *Amberflo) Close() error { return m.client.Shutdown() } -func (m *AmberFlo) Flush() error { +func (m *Amberflo) Flush() error { // metering.Flush() // in API docs but not in library time.Sleep(m.interval) return nil } -func (m *AmberFlo) getValue(user MeterUser, meterName string) (float64, bool) { - // TODO: batch and cache - // TODO: time period and aggregation is hardcoded as 1 day +func (m *Amberflo) getValue(user MeterUser, meterName string, startTime time.Time, endTime time.Time, checkDims Dimensions) (float64, bool) { cfg, ok := m.getcfg(meterName) if !ok { return 0, false @@ -92,17 +90,34 @@ func (m *AmberFlo) getValue(user MeterUser, meterName string) (float64, bool) { if cfg.Name == "" { return 0, false } - - startTimeInSeconds := (time.Now().In(time.UTC).UnixNano() / int64(time.Second)) - (24 * 60 * 60) timeRange := &metering.TimeRange{ - StartTimeInSeconds: startTimeInSeconds, + StartTimeInSeconds: startTime.In(time.UTC).Unix(), + EndTimeInSeconds: endTime.In(time.UTC).Unix(), + } + if timeRange.EndTimeInSeconds > time.Now().In(time.UTC).Unix() { + timeRange.EndTimeInSeconds = 0 } + filter := make(map[string][]string) filter["customerId"] = []string{customerId} + for _, dim := range checkDims { + filter[dim.Key] = []string{dim.Value} + } + + timeGroupingInterval := metering.Hour + switch timeSpan := endTime.Unix() - startTime.Unix(); { + case timeSpan > 24*60*60: + timeGroupingInterval = metering.Month + case timeSpan > 60*60: + timeGroupingInterval = metering.Day + default: + timeGroupingInterval = metering.Hour + } + usageResult, err := m.usageClient.GetUsage(&metering.UsagePayload{ MeterApiName: cfg.Name, Aggregation: metering.Sum, - TimeGroupingInterval: metering.Day, + TimeGroupingInterval: timeGroupingInterval, GroupBy: []string{"customerId"}, TimeRange: timeRange, Filter: filter, @@ -111,16 +126,19 @@ func (m *AmberFlo) getValue(user MeterUser, meterName string) (float64, bool) { log.Error().Err(err).Str("user", user.ID()).Msg("could not get value") return 0, false } + // jj, _ := json.Marshal(&usageResult) + // fmt.Println("usageResult:", string(jj)) + if usageResult == nil || len(usageResult.ClientMeters) == 0 || len(usageResult.ClientMeters[0].Values) == 0 { log.Error().Err(err).Str("user", user.ID()).Msg("could not get value; no client value meter") return 0, false } - cm := usageResult.ClientMeters[0].Values - cmv := cm[len(cm)-1].Value - return cmv, true + + total := usageResult.ClientMeters[0].GroupValue + return total, true } -func (m *AmberFlo) sendMeter(user MeterUser, meterName string, value float64, extraDimensions map[string]string) error { +func (m *Amberflo) sendMeter(user MeterUser, meterName string, value float64, extraDimensions Dimensions) error { cfg, ok := m.getcfg(meterName) if !ok { return nil @@ -132,12 +150,12 @@ func (m *AmberFlo) sendMeter(user MeterUser, meterName string, value float64, ex } uniqueId := uuid.NewRandom().String() utcMillis := time.Now().In(time.UTC).UnixNano() / int64(time.Millisecond) - dimensions := map[string]string{} - for k, v := range cfg.Dimensions { - dimensions[k] = v + amberFloDims := map[string]string{} + for _, v := range cfg.Dimensions { + amberFloDims[v.Key] = v.Value } - for k, v := range extraDimensions { - dimensions[k] = v + for _, v := range extraDimensions { + amberFloDims[v.Key] = v.Value } return m.client.Meter(&metering.MeterMessage{ MeterApiName: cfg.Name, @@ -145,11 +163,11 @@ func (m *AmberFlo) sendMeter(user MeterUser, meterName string, value float64, ex MeterTimeInMillis: utcMillis, CustomerId: customerId, MeterValue: value, - Dimensions: dimensions, + Dimensions: amberFloDims, }) } -func (m *AmberFlo) getCustomerID(cfg amberFloConfig, user MeterUser) (string, bool) { +func (m *Amberflo) getCustomerID(cfg amberFloConfig, user MeterUser) (string, bool) { customerId := cfg.DefaultUser if user != nil { eidKey := cfg.ExternalIDKey @@ -166,7 +184,7 @@ func (m *AmberFlo) getCustomerID(cfg amberFloConfig, user MeterUser) (string, bo return customerId, customerId != "" } -func (m *AmberFlo) getcfg(meterName string) (amberFloConfig, bool) { +func (m *Amberflo) getcfg(meterName string) (amberFloConfig, bool) { cfg, ok := m.cfgs[meterName] if !ok { cfg = amberFloConfig{ @@ -183,41 +201,34 @@ func (m *AmberFlo) getcfg(meterName string) (amberFloConfig, bool) { ////////// type amberFloMeter struct { - user MeterUser - dims []string - mp *AmberFlo + user MeterUser + addDims []eventAddDim + mp *Amberflo } -func (m *amberFloMeter) Meter(meterName string, value float64, extraDimensions map[string]string) error { - var dm2 map[string]string - if len(extraDimensions) > 0 || len(m.dims) > 0 { - dm2 = map[string]string{} - } - for k, v := range extraDimensions { - dm2[k] = v - } - for i := 0; i < len(m.dims); i += 3 { - a := m.dims[i] - k := m.dims[i+1] - v := m.dims[i+2] - if a == "" || a == meterName { - dm2[k] = v +func (m *amberFloMeter) Meter(meterName string, value float64, extraDimensions Dimensions) error { + var eventDims []Dimension + // Copy in matching dimensions set through AddDimension + for _, addDim := range m.addDims { + if addDim.MeterName == meterName { + eventDims = append(eventDims, Dimension{Key: addDim.Key, Value: addDim.Value}) } } + eventDims = append(eventDims, extraDimensions...) log.Trace(). Str("user", m.user.ID()). Str("meter", meterName). Float64("meter_value", value). Msg("meter") - return m.mp.sendMeter(m.user, meterName, value, dm2) + return m.mp.sendMeter(m.user, meterName, value, eventDims) } func (m *amberFloMeter) AddDimension(meterName string, key string, value string) { - m.dims = append(m.dims, meterName, key, value) + m.addDims = append(m.addDims, eventAddDim{MeterName: meterName, Key: key, Value: value}) } -func (m *amberFloMeter) GetValue(meterName string) (float64, bool) { - return m.mp.getValue(m.user, meterName) +func (m *amberFloMeter) GetValue(meterName string, startTime time.Time, endTime time.Time, dims Dimensions) (float64, bool) { + return m.mp.getValue(m.user, meterName, startTime, endTime, dims) } ///////// diff --git a/internal/meters/amberflo_test.go b/internal/meters/amberflo_test.go index 181ce437..6f69737d 100644 --- a/internal/meters/amberflo_test.go +++ b/internal/meters/amberflo_test.go @@ -1,6 +1,7 @@ package meters import ( + "errors" "os" "testing" "time" @@ -8,23 +9,16 @@ import ( "github.com/interline-io/transitland-server/internal/testutil" ) -type amberfloTestUser struct { - name string -} - -func (u *amberfloTestUser) ID() string { - return u.name -} - -func (u *amberfloTestUser) GetExternalData(eid string) (string, bool) { - // must match key given in config below - if eid == "amberflo" { - return u.name, true +func TestAmberfloMeter(t *testing.T) { + mp, testConfig, err := getTestAmberfloMeter() + if err != nil { + t.Skip(err.Error()) + return } - return "", false + testMeter(t, mp, testConfig) } -func TestAmberFloMeter(t *testing.T) { +func getTestAmberfloMeter() (*Amberflo, testMeterConfig, error) { checkKeys := []string{ "TL_TEST_AMBERFLO_APIKEY", "TL_TEST_AMBERFLO_METER1", @@ -36,19 +30,28 @@ func TestAmberFloMeter(t *testing.T) { for _, k := range checkKeys { _, a, ok := testutil.CheckEnv(k) if !ok { - t.Skip(a) - return + return nil, testMeterConfig{}, errors.New(a) } } + eidKey := "amberflo" testConfig := testMeterConfig{ testMeter1: os.Getenv("TL_TEST_AMBERFLO_METER1"), testMeter2: os.Getenv("TL_TEST_AMBERFLO_METER2"), - user1: &amberfloTestUser{name: os.Getenv("TL_TEST_AMBERFLO_USER1")}, - user2: &amberfloTestUser{name: os.Getenv("TL_TEST_AMBERFLO_USER2")}, - user3: &amberfloTestUser{name: os.Getenv("TL_TEST_AMBERFLO_USER3")}, + user1: &testUser{ + name: os.Getenv("TL_TEST_AMBERFLO_USER1"), + data: map[string]string{eidKey: os.Getenv("TL_TEST_AMBERFLO_USER1")}, + }, + user2: &testUser{ + name: os.Getenv("TL_TEST_AMBERFLO_USER2"), + data: map[string]string{eidKey: os.Getenv("TL_TEST_AMBERFLO_USER2")}, + }, + user3: &testUser{ + name: os.Getenv("TL_TEST_AMBERFLO_USER3"), + data: map[string]string{eidKey: os.Getenv("TL_TEST_AMBERFLO_USER3")}, + }, } - mp := NewAmberFlo(os.Getenv("TL_TEST_AMBERFLO_APIKEY"), 1*time.Second, 1) - mp.cfgs[testConfig.testMeter1] = amberFloConfig{Name: testConfig.testMeter1, ExternalIDKey: "amberflo"} - mp.cfgs[testConfig.testMeter2] = amberFloConfig{Name: testConfig.testMeter2, ExternalIDKey: "amberflo"} - testMeter(t, mp, testConfig) + mp := NewAmberflo(os.Getenv("TL_TEST_AMBERFLO_APIKEY"), 1*time.Second, 1) + mp.cfgs[testConfig.testMeter1] = amberFloConfig{Name: testConfig.testMeter1, ExternalIDKey: eidKey} + mp.cfgs[testConfig.testMeter2] = amberFloConfig{Name: testConfig.testMeter2, ExternalIDKey: eidKey} + return mp, testConfig, nil } diff --git a/internal/meters/default.go b/internal/meters/default.go index cebf2d9c..addd94db 100644 --- a/internal/meters/default.go +++ b/internal/meters/default.go @@ -2,100 +2,125 @@ package meters import ( "sync" + "time" "github.com/interline-io/transitland-lib/log" ) -type DefaultMeter struct { - values map[string]map[string]float64 +type DefaultMeterProvider struct { + values map[string]defaultMeterUserEvents lock sync.Mutex } -func NewDefaultMeter() *DefaultMeter { - return &DefaultMeter{ - values: map[string]map[string]float64{}, +func NewDefaultMeterProvider() *DefaultMeterProvider { + return &DefaultMeterProvider{ + values: map[string]defaultMeterUserEvents{}, } } -func (m *DefaultMeter) Flush() error { +func (m *DefaultMeterProvider) Flush() error { return nil } -func (m *DefaultMeter) Close() error { +func (m *DefaultMeterProvider) Close() error { return nil } -func (m *DefaultMeter) NewMeter(user MeterUser) ApiMeter { +func (m *DefaultMeterProvider) NewMeter(user MeterUser) ApiMeter { return &defaultUserMeter{ user: user, mp: m, } } -func (m *DefaultMeter) sendMeter(u MeterUser, meterName string, value float64, dims map[string]string) error { +func (m *DefaultMeterProvider) sendMeter(u MeterUser, meterName string, value float64, dims []Dimension) error { m.lock.Lock() defer m.lock.Unlock() a, ok := m.values[meterName] if !ok { - a = map[string]float64{} + a = defaultMeterUserEvents{} m.values[meterName] = a } userName := "" if u != nil { userName = u.ID() } - - a[userName] += value + event := defaultMeterEvent{ + value: value, + time: time.Now().In(time.UTC), + dims: dims, + } + a[userName] = append(a[userName], event) log.Trace(). Str("user", userName). Str("meter", meterName). Float64("meter_value", value). - Float64("total_value", a[userName]). Msg("meter") return nil } -func (m *DefaultMeter) getValue(u MeterUser, meterName string) (float64, bool) { +func (m *DefaultMeterProvider) getValue(u MeterUser, meterName string, startTime time.Time, endTime time.Time, checkDims Dimensions) (float64, bool) { m.lock.Lock() defer m.lock.Unlock() a, ok := m.values[meterName] if !ok { - a = map[string]float64{} - m.values[meterName] = a + return 0, false + } + total := 0.0 + for _, userEvent := range a[u.ID()] { + match := true + if userEvent.time.Equal(endTime) || userEvent.time.After(endTime) { + // fmt.Println("not matched on end time", userEvent.time, endTime) + match = false + } + if userEvent.time.Before(startTime) { + // fmt.Println("not matched on start time", userEvent.time, startTime) + match = false + } + if !dimsContainedIn(checkDims, userEvent.dims) { + // fmt.Println("not matched on dims") + match = false + } + if match { + // fmt.Println("matched:", userEvent.value) + total += userEvent.value + } } - v, ok := a[u.ID()] - return v, ok + return total, ok } type defaultUserMeter struct { - user MeterUser - dims []string - mp *DefaultMeter + user MeterUser + addDims []eventAddDim + mp *DefaultMeterProvider } -func (m *defaultUserMeter) Meter(meterName string, value float64, extraDimensions map[string]string) error { - var dm2 map[string]string - if len(extraDimensions) > 0 || len(m.dims) > 0 { - dm2 = map[string]string{} - } - for i := 0; i < len(m.dims); i += 3 { - a := m.dims[i] - k := m.dims[i+1] - v := m.dims[i+2] - if a == meterName { - dm2[k] = v +func (m *defaultUserMeter) Meter(meterName string, value float64, extraDimensions Dimensions) error { + // Copy in matching dimensions set through AddDimension + var eventDims []Dimension + for _, addDim := range m.addDims { + if addDim.MeterName == meterName { + eventDims = append(eventDims, Dimension{Key: addDim.Key, Value: addDim.Value}) } } - for k, v := range extraDimensions { - dm2[k] = v - } - return m.mp.sendMeter(m.user, meterName, value, dm2) + eventDims = append(eventDims, extraDimensions...) + return m.mp.sendMeter(m.user, meterName, value, eventDims) } func (m *defaultUserMeter) AddDimension(meterName string, key string, value string) { - m.dims = append(m.dims, meterName, key, value) + m.addDims = append(m.addDims, eventAddDim{MeterName: meterName, Key: key, Value: value}) } -func (m *defaultUserMeter) GetValue(meterName string) (float64, bool) { - return m.mp.getValue(m.user, meterName) +func (m *defaultUserMeter) GetValue(meterName string, startTime time.Time, endTime time.Time, dims Dimensions) (float64, bool) { + return m.mp.getValue(m.user, meterName, startTime, endTime, dims) } + +/////////// + +type defaultMeterEvent struct { + time time.Time + dims []Dimension + value float64 +} + +type defaultMeterUserEvents map[string][]defaultMeterEvent diff --git a/internal/meters/default_test.go b/internal/meters/default_test.go index 7ddee757..7c0ea275 100644 --- a/internal/meters/default_test.go +++ b/internal/meters/default_test.go @@ -5,7 +5,7 @@ import ( ) func TestDefaultMeter(t *testing.T) { - mp := NewDefaultMeter() + mp := NewDefaultMeterProvider() testConfig := testMeterConfig{ testMeter1: "test1", testMeter2: "test2", diff --git a/internal/meters/limit.go b/internal/meters/limit.go new file mode 100644 index 00000000..c3dd91c1 --- /dev/null +++ b/internal/meters/limit.go @@ -0,0 +1,129 @@ +package meters + +import ( + "errors" + "fmt" + "time" + + "github.com/interline-io/transitland-lib/log" + "github.com/tidwall/gjson" +) + +func init() { + var _ MeterProvider = &LimitMeterProvider{} +} + +type LimitMeterProvider struct { + Enabled bool + DefaultLimits []UserMeterLimit + MeterProvider +} + +func NewLimitMeterProvider(provider MeterProvider) *LimitMeterProvider { + return &LimitMeterProvider{ + MeterProvider: provider, + } +} + +func (c *LimitMeterProvider) NewMeter(u MeterUser) ApiMeter { + userData, _ := u.GetExternalData("gatekeeper") + return &LimitMeter{ + userId: u.ID(), + userData: userData, + provider: c, + ApiMeter: c.MeterProvider.NewMeter(u), + } +} + +type LimitMeter struct { + userId string + userData string + provider *LimitMeterProvider + ApiMeter +} + +func (c *LimitMeter) GetLimits(meterName string, checkDims Dimensions) []UserMeterLimit { + // The limit matches the event dimensions if all of the LIMIT dimensions are contained in event + var lims []UserMeterLimit + for _, userLimit := range parseGkUserLimits(c.userData) { + if userLimit.MeterName == meterName && dimsContainedIn(userLimit.Dims, checkDims) { + lims = append(lims, userLimit) + } + } + for _, defaultLimit := range c.provider.DefaultLimits { + if defaultLimit.MeterName == meterName && dimsContainedIn(defaultLimit.Dims, checkDims) { + lims = append(lims, defaultLimit) + } + } + return lims +} + +func (c *LimitMeter) Meter(meterName string, value float64, extraDimensions Dimensions) error { + if c.provider.Enabled { + for _, lim := range c.GetLimits(meterName, extraDimensions) { + d1, d2 := lim.Span() + currentValue, _ := c.GetValue(meterName, d1, d2, lim.Dims) + if currentValue+value > lim.Limit { + log.Info().Str("meter", meterName).Str("user", c.userId).Float64("limit", lim.Limit).Float64("current", currentValue).Float64("add", value).Str("dims", fmt.Sprintf("%v", lim.Dims)).Msg("rate limited") + return errors.New("rate check: limited") + } else { + log.Info().Str("meter", meterName).Str("user", c.userId).Float64("limit", lim.Limit).Float64("current", currentValue).Float64("add", value).Str("dims", fmt.Sprintf("%v", lim.Dims)).Msg("rate check: ok") + } + } + } + return c.ApiMeter.Meter(meterName, value, extraDimensions) +} + +type UserMeterLimit struct { + User string + MeterName string + Dims Dimensions + Period string + Limit float64 +} + +func (lim *UserMeterLimit) Span() (time.Time, time.Time) { + now := time.Now().In(time.UTC) + d1 := now + d2 := now + if lim.Period == "hourly" { + d1 = time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, time.UTC) + d2 = d1.Add(3600 * time.Second) + } else if lim.Period == "daily" { + d1 = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + d2 = d1.AddDate(0, 0, 1) + } else if lim.Period == "monthly" { + d1 = time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + d2 = d1.AddDate(0, 1, 0) + } else if lim.Period == "yearly" { + d1 = time.Date(now.Year(), 1, 1, 0, 0, 0, 0, time.UTC) + d2 = d1.AddDate(1, 0, 0) + } else if lim.Period == "total" { + d1 = time.Unix(0, 0) + d2 = time.Unix(1<<63-1, 0) + } else { + return now, now + } + return d1, d2 +} + +func parseGkUserLimits(v string) []UserMeterLimit { + var lims []UserMeterLimit + for _, productLimit := range gjson.Get(v, "product_limits").Map() { + for _, plim := range productLimit.Array() { + lim := UserMeterLimit{ + MeterName: plim.Get("amberflo_meter").String(), + Limit: plim.Get("limit_value").Float(), + Period: plim.Get("time_period").String(), + } + if dim := plim.Get("amberflo_dimension").String(); dim != "" { + lim.Dims = append(lim.Dims, Dimension{ + Key: dim, + Value: plim.Get("amberflo_dimension_value").String(), + }) + } + lims = append(lims, lim) + } + } + return lims +} diff --git a/internal/meters/limit_test.go b/internal/meters/limit_test.go new file mode 100644 index 00000000..efa33a34 --- /dev/null +++ b/internal/meters/limit_test.go @@ -0,0 +1,162 @@ +package meters + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLimitMeter(t *testing.T) { + meterName := "testmeter" + user := testUser{name: "testuser"} + // cmp.DefaultLimits = testLims(meterName) + // for _, lim := range cmp.DefaultLimits { + for _, lim := range testLims(meterName) { + t.Run("", func(t *testing.T) { + mp := NewDefaultMeterProvider() + cmp := NewLimitMeterProvider(mp) + cmp.Enabled = true + cmp.DefaultLimits = []UserMeterLimit{lim} + testLimitMeter(t, + cmp, + lim.MeterName, + user, + lim, + ) + }) + } +} + +func TestLimitMeter_Amberflo(t *testing.T) { + mp, testConfig, err := getTestAmberfloMeter() + if err != nil { + t.Skip(err.Error()) + return + } + user := testUser{ + name: testConfig.user1.ID(), + data: map[string]string{"amberflo": "amberflo"}, + } + for _, lim := range testLims(testConfig.testMeter1) { + t.Run("", func(t *testing.T) { + cmp := NewLimitMeterProvider(mp) + cmp.Enabled = true + cmp.DefaultLimits = []UserMeterLimit{lim} + testLimitMeter(t, + cmp, + lim.MeterName, + user, + lim, + ) + }) + } +} + +func TestLimitMeter_Gatekeeper(t *testing.T) { + // JSON blob + gkData := ` + { + "product_limits": { + "tlv2_api": [ + { + "amberflo_dimension": "fv", + "amberflo_dimension_value": true, + "amberflo_meter": "testmeter", + "limit_value": 100, + "time_period": "monthly" + }, + { + "amberflo_dimension": "fv", + "amberflo_dimension_value": false, + "amberflo_meter": "testmeter", + "limit_value": 500, + "time_period": "monthly" + } + ] + }, + }` + user := testUser{name: "testuser"} + user.data = map[string]string{"gatekeeper": gkData} + lims := parseGkUserLimits(gkData) + for _, lim := range lims { + t.Run("", func(t *testing.T) { + mp := NewDefaultMeterProvider() + cmp := NewLimitMeterProvider(mp) + cmp.Enabled = true + testLimitMeter(t, + cmp, + lim.MeterName, + user, + lim, + ) + }) + } +} + +func testLims(meterName string) []UserMeterLimit { + testKey := 1 // time.Now().In(time.UTC).Unix() + lims := []UserMeterLimit{ + // foo tests + { + MeterName: meterName, + Period: "hourly", + Limit: 50.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("foo:%d", testKey)}}, + }, + { + MeterName: meterName, + Period: "daily", + Limit: 80.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("foo:%d", testKey)}}, + }, + { + MeterName: meterName, + Period: "monthly", + Limit: 110.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("foo:%d", testKey)}}, + }, + // bar tests + { + MeterName: meterName, + Period: "hourly", + Limit: 140.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("bar:%d", testKey)}}, + }, + { + MeterName: meterName, + Period: "daily", + Limit: 170.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("bar:%d", testKey)}}, + }, + { + MeterName: meterName, + Period: "monthly", + Limit: 200.0, + Dims: Dimensions{{Key: "ok", Value: fmt.Sprintf("bar:%d", testKey)}}, + }, + } + return lims +} + +func testLimitMeter(t *testing.T, cmp *LimitMeterProvider, meterName string, user testUser, lim UserMeterLimit) { + incr := 1.0 + m := cmp.NewMeter(user) + startTime, endTime := lim.Span() + base, _ := m.GetValue(meterName, startTime, endTime, lim.Dims) + + // Probably ok + if err := m.Meter(meterName, incr, lim.Dims); err != nil { + t.Error(err) + } + cmp.MeterProvider.Flush() + + // push past limit + if err := m.Meter(meterName, incr+lim.Limit, lim.Dims); err == nil { + t.Error("expected error, got none") + } + + // Check updated value + total, _ := m.GetValue(meterName, startTime, endTime, lim.Dims) + assert.Equal(t, base+incr, total, "expected total") +} diff --git a/internal/meters/meters.go b/internal/meters/meters.go index 02acac0b..5b75589d 100644 --- a/internal/meters/meters.go +++ b/internal/meters/meters.go @@ -3,6 +3,7 @@ package meters import ( "context" "net/http" + "time" "github.com/interline-io/transitland-server/auth/authn" ) @@ -10,9 +11,9 @@ import ( var meterCtxKey = struct{ name string }{"apiMeter"} type ApiMeter interface { - Meter(string, float64, map[string]string) error + Meter(string, float64, Dimensions) error AddDimension(string, string, string) - GetValue(string) (float64, bool) + GetValue(string, time.Time, time.Time, Dimensions) (float64, bool) } type MeterProvider interface { @@ -26,7 +27,7 @@ type MeterUser interface { GetExternalData(string) (string, bool) } -func WithMeter(apiMeter MeterProvider, meterName string, meterValue float64, dims map[string]string) func(http.Handler) http.Handler { +func WithMeter(apiMeter MeterProvider, meterName string, meterValue float64, dims Dimensions) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Make ctxMeter available in context @@ -43,3 +44,31 @@ func ForContext(ctx context.Context) ApiMeter { raw, _ := ctx.Value(meterCtxKey).(ApiMeter) return raw } + +type Dimension struct { + Key string + Value string +} + +type Dimensions []Dimension + +type eventAddDim struct { + MeterName string + Key string + Value string +} + +func dimsContainedIn(checkDims Dimensions, eventDims Dimensions) bool { + for _, matchDim := range checkDims { + match := false + for _, ed := range eventDims { + if ed.Key == matchDim.Key && ed.Value == matchDim.Value { + match = true + } + } + if !match { + return false + } + } + return true +} diff --git a/internal/meters/meters_test.go b/internal/meters/meters_test.go index 11ce0142..96ae6799 100644 --- a/internal/meters/meters_test.go +++ b/internal/meters/meters_test.go @@ -8,14 +8,19 @@ import ( type testUser struct { name string + data map[string]string } func (u testUser) ID() string { return u.name } -func (u testUser) GetExternalData(string) (string, bool) { - return "test", true +func (u testUser) GetExternalData(key string) (string, bool) { + if u.data == nil { + return "", false + } + a, ok := u.data[key] + return a, ok } type testMeterConfig struct { @@ -27,58 +32,96 @@ type testMeterConfig struct { } func testMeter(t *testing.T, mp MeterProvider, cfg testMeterConfig) { + d1, d2 := (&UserMeterLimit{Period: "hourly"}).Span() t.Run("Meter", func(t *testing.T) { m := mp.NewMeter(cfg.user1) - v, _ := m.GetValue(cfg.testMeter1) + v, _ := m.GetValue(cfg.testMeter1, d1, d2, nil) m.Meter(cfg.testMeter1, 1, nil) mp.Flush() - a, _ := m.GetValue(cfg.testMeter1) + a, _ := m.GetValue(cfg.testMeter1, d1, d2, nil) assert.Equal(t, 1.0, a-v) m.Meter(cfg.testMeter1, 1, nil) mp.Flush() - b, _ := m.GetValue(cfg.testMeter1) + b, _ := m.GetValue(cfg.testMeter1, d1, d2, nil) assert.Equal(t, 2.0, b-v) }) t.Run("NewMeter", func(t *testing.T) { m1 := mp.NewMeter(cfg.user1) - v1, _ := m1.GetValue(cfg.testMeter1) - v2, _ := m1.GetValue(cfg.testMeter2) + v1, _ := m1.GetValue(cfg.testMeter1, d1, d2, nil) + v2, _ := m1.GetValue(cfg.testMeter2, d1, d2, nil) m1.Meter(cfg.testMeter1, 1, nil) m1.Meter(cfg.testMeter2, 2, nil) mp.Flush() - va1, _ := m1.GetValue(cfg.testMeter1) + va1, _ := m1.GetValue(cfg.testMeter1, d1, d2, nil) assert.Equal(t, 1.0, va1-v1) - va2, _ := m1.GetValue(cfg.testMeter2) + va2, _ := m1.GetValue(cfg.testMeter2, d1, d2, nil) assert.Equal(t, 2.0, va2-v2) }) t.Run("GetValue", func(t *testing.T) { m1 := mp.NewMeter(cfg.user1) m2 := mp.NewMeter(cfg.user2) m3 := mp.NewMeter(cfg.user3) - v1, _ := m1.GetValue(cfg.testMeter1) - v2, _ := m2.GetValue(cfg.testMeter1) - v3, _ := m3.GetValue(cfg.testMeter1) + v1, _ := m1.GetValue(cfg.testMeter1, d1, d2, nil) + v2, _ := m2.GetValue(cfg.testMeter1, d1, d2, nil) + v3, _ := m3.GetValue(cfg.testMeter1, d1, d2, nil) m1.Meter(cfg.testMeter1, 1, nil) m2.Meter(cfg.testMeter1, 2.0, nil) mp.Flush() - a, ok := m1.GetValue(cfg.testMeter1) + a, ok := m1.GetValue(cfg.testMeter1, d1, d2, nil) assert.Equal(t, 1.0, a-v1) assert.Equal(t, true, ok) - a, ok = m2.GetValue(cfg.testMeter1) + a, ok = m2.GetValue(cfg.testMeter1, d1, d2, nil) assert.Equal(t, 2.0, a-v2) assert.Equal(t, true, ok) - a, _ = m3.GetValue(cfg.testMeter1) + a, _ = m3.GetValue(cfg.testMeter1, d1, d2, nil) + assert.Equal(t, 0.0, a-v3) + }) + + t.Run("GetValue match dims", func(t *testing.T) { + addDims1 := []Dimension{{Key: "test", Value: "a"}, {Key: "other", Value: "boo"}} + addDims2 := []Dimension{{Key: "test", Value: "b"}} + checkDims1 := []Dimension{{Key: "test", Value: "a"}} + checkDims2 := []Dimension{{Key: "test", Value: "b"}} + + m1 := mp.NewMeter(cfg.user1) + m2 := mp.NewMeter(cfg.user2) + m3 := mp.NewMeter(cfg.user3) + + // Initial values + v1, _ := m1.GetValue(cfg.testMeter1, d1, d2, checkDims1) + v2, _ := m2.GetValue(cfg.testMeter1, d1, d2, checkDims2) + v3, _ := m3.GetValue(cfg.testMeter1, d1, d2, checkDims1) + + // m1 meter + m1.Meter(cfg.testMeter1, 1, addDims1) + // m2 uses different dimension + m2.Meter(cfg.testMeter1, 2.0, addDims2) + mp.Flush() + + a, ok := m1.GetValue(cfg.testMeter1, d1, d2, checkDims1) + assert.Equal(t, 1.0, a-v1) + assert.Equal(t, true, ok) + + a, ok = m2.GetValue(cfg.testMeter1, d1, d2, checkDims1) + assert.Equal(t, 0.0, a) + assert.Equal(t, true, ok) + + a, ok = m2.GetValue(cfg.testMeter1, d1, d2, checkDims2) + assert.Equal(t, 2.0, a-v2) + assert.Equal(t, true, ok) + + a, _ = m3.GetValue(cfg.testMeter1, d1, d2, checkDims1) assert.Equal(t, 0.0, a-v3) }) } diff --git a/server/rest/feed_version_download.go b/server/rest/feed_version_download.go index d18ed4e6..bf5fdc91 100644 --- a/server/rest/feed_version_download.go +++ b/server/rest/feed_version_download.go @@ -76,10 +76,10 @@ func feedVersionDownloadLatestHandler(cfg restConfig, w http.ResponseWriter, r * // Send request to metering if apiMeter := meters.ForContext(r.Context()); apiMeter != nil { - dims := map[string]string{ - "fv_sha1": fvsha1, - "feed_onestop_id": fid, - "is_latest_feed_version": "true", + dims := []meters.Dimension{ + {Key: "fv_sha1", Value: fvsha1}, + {Key: "feed_onestop_id", Value: fid}, + {Key: "is_latest_feed_version", Value: "true"}, } apiMeter.Meter("feed-version-downloads", 1.0, dims) } @@ -151,10 +151,10 @@ func feedVersionDownloadHandler(cfg restConfig, w http.ResponseWriter, r *http.R // Send request to metering if apiMeter := meters.ForContext(r.Context()); apiMeter != nil { - dims := map[string]string{ - "fv_sha1": fvsha1, - "feed_onestop_id": fid, - "is_latest_feed_version": "false", + dims := []meters.Dimension{ + {Key: "fv_sha1", Value: fvsha1}, + {Key: "feed_onestop_id", Value: fid}, + {Key: "is_latest_feed_version", Value: "false"}, } apiMeter.Meter("feed-version-downloads", 1.0, dims) } diff --git a/server/server_cmd.go b/server/server_cmd.go index 596fd3ba..fa80054c 100644 --- a/server/server_cmd.go +++ b/server/server_cmd.go @@ -51,6 +51,7 @@ type Command struct { EnableJobsApi bool EnableWorkers bool EnableProfiler bool + EnableRateLimits bool LoadAdmins bool QueuePrefix string SecretsFile string @@ -129,7 +130,8 @@ func (cmd *Command) Parse(args []string) error { // Metering // fl.BoolVar(&cmd.EnableMetering, "enable-metering", false, "Enable metering") fl.StringVar(&cmd.metersConfig.MeteringProvider, "metering-provider", "", "Use metering provider") - fl.StringVar(&cmd.metersConfig.MeteringAmberfloConfig, "metering-amberflo-config", "", "Use provided config for AmberFlo metering") + fl.StringVar(&cmd.metersConfig.MeteringAmberfloConfig, "metering-amberflo-config", "", "Use provided config for Amberflo metering") + fl.BoolVar(&cmd.EnableRateLimits, "enable-rate-limits", false, "Enable rate limits") // Jobs fl.BoolVar(&cmd.EnableJobsApi, "enable-jobs-api", false, "Enable job api") @@ -233,10 +235,10 @@ func (cmd *Command) Run() error { // Setup metering var meterProvider meters.MeterProvider - meterProvider = meters.NewDefaultMeter() + meterProvider = meters.NewDefaultMeterProvider() if cmd.metersConfig.EnableMetering { if cmd.metersConfig.MeteringProvider == "amberflo" { - a := meters.NewAmberFlo(os.Getenv("AMBERFLO_APIKEY"), 30*time.Second, 100) + a := meters.NewAmberflo(os.Getenv("AMBERFLO_APIKEY"), 30*time.Second, 100) if cmd.metersConfig.MeteringAmberfloConfig != "" { if err := a.LoadConfig(cmd.metersConfig.MeteringAmberfloConfig); err != nil { return err @@ -244,6 +246,12 @@ func (cmd *Command) Run() error { } meterProvider = a } + if cmd.EnableRateLimits { + mp := meters.NewLimitMeterProvider(meterProvider) + mp.Enabled = true + // mp.DefaultLimits = append(mp.DefaultLimits, meters.UserMeterLimit{Limit: 10, Period: "monthly", MeterName: "rest"}) + meterProvider = mp + } defer meterProvider.Close() }