diff --git a/lib/rack/attack/cache.rb b/lib/rack/attack/cache.rb index 0e1f606e..75bdaf6a 100644 --- a/lib/rack/attack/cache.rb +++ b/lib/rack/attack/cache.rb @@ -22,9 +22,9 @@ def store=(store) end end - def count(unprefixed_key, period) + def count(unprefixed_key, period, weight = 1) key, expires_in = key_and_expiry(unprefixed_key, period) - do_count(key, expires_in) + do_count(key, expires_in, weight) end def read(unprefixed_key) @@ -67,19 +67,19 @@ def key_and_expiry(unprefixed_key, period) ["#{prefix}:#{(@last_epoch_time / period).to_i}:#{unprefixed_key}", expires_in] end - def do_count(key, expires_in) + def do_count(key, expires_in, weight) enforce_store_presence! enforce_store_method_presence!(:increment) - result = store.increment(key, 1, expires_in: expires_in) + result = store.increment(key, weight, expires_in: expires_in) # NB: Some stores return nil when incrementing uninitialized values if result.nil? enforce_store_method_presence!(:write) - store.write(key, 1, expires_in: expires_in) + store.write(key, weight, expires_in: expires_in) end - result || 1 + result || weight end def enforce_store_presence! diff --git a/lib/rack/attack/throttle.rb b/lib/rack/attack/throttle.rb index 1cc50f40..badf5832 100644 --- a/lib/rack/attack/throttle.rb +++ b/lib/rack/attack/throttle.rb @@ -5,7 +5,7 @@ class Attack class Throttle MANDATORY_OPTIONS = [:limit, :period].freeze - attr_reader :name, :limit, :period, :block, :type + attr_reader :name, :limit, :period, :weight, :block, :type def initialize(name, options, &block) @name = name @@ -15,6 +15,7 @@ def initialize(name, options, &block) end @limit = options[:limit] @period = options[:period].respond_to?(:call) ? options[:period] : options[:period].to_i + @weight = options[:weight].respond_to?(:call) ? options[:weight] : (options[:weight] || 1).to_i @type = options.fetch(:type, :throttle) end @@ -28,7 +29,8 @@ def matched_by?(request) current_period = period_for(request) current_limit = limit_for(request) - count = cache.count("#{name}:#{discriminator}", current_period) + current_weight = weight_for(request) + count = cache.count("#{name}:#{discriminator}", current_period, current_weight) data = { discriminator: discriminator, @@ -65,6 +67,10 @@ def limit_for(request) limit.respond_to?(:call) ? limit.call(request) : limit end + def weight_for(request) + weight.respond_to?(:call) ? weight.call(request) : weight + end + def annotate_request_with_throttle_data(request, data) (request.env['rack.attack.throttle_data'] ||= {})[name] = data end diff --git a/spec/acceptance/throttling_spec.rb b/spec/acceptance/throttling_spec.rb index 0db89dd6..09dd39d1 100644 --- a/spec/acceptance/throttling_spec.rb +++ b/spec/acceptance/throttling_spec.rb @@ -34,6 +34,79 @@ end end + it "supports a non-1 constant weight" do + Rack::Attack.throttle("by ip", limit: 4, period: 60, weight: 2) do |request| + request.ip + end + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 429, last_response.status + assert_nil last_response.headers["Retry-After"] + assert_equal "Retry later\n", last_response.body + + get "/", {}, "REMOTE_ADDR" => "5.6.7.8" + + assert_equal 200, last_response.status + + Timecop.travel(60) do + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + end + end + + it "supports a dynamic weight" do + weight_proc = lambda do |request| + if request.env["X-APIKey"] == "private-secret" + 3 + else + 2 + end + end + Rack::Attack.throttle("by ip", limit: 4, period: 60, weight: weight_proc) do |request| + request.ip + end + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 429, last_response.status + assert_nil last_response.headers["Retry-After"] + assert_equal "Retry later\n", last_response.body + + get "/", {}, "REMOTE_ADDR" => "5.6.7.8", "X-APIKey" => "private-secret" + + assert_equal 200, last_response.status + + get "/", {}, "REMOTE_ADDR" => "5.6.7.8", "X-APIKey" => "private-secret" + + assert_equal 429, last_response.status + assert_nil last_response.headers["Retry-After"] + assert_equal "Retry later\n", last_response.body + + Timecop.travel(60) do + get "/", {}, "REMOTE_ADDR" => "1.2.3.4" + + assert_equal 200, last_response.status + end + end + it "returns correct Retry-After header if enabled" do Rack::Attack.throttled_response_retry_after_header = true