diff --git a/lib/rack/attack.rb b/lib/rack/attack.rb index 0f375677..d7242dc2 100644 --- a/lib/rack/attack.rb +++ b/lib/rack/attack.rb @@ -13,6 +13,7 @@ class Rack::Attack autoload :DalliProxy, 'rack/attack/store_proxy/dalli_proxy' autoload :MemCacheProxy, 'rack/attack/store_proxy/mem_cache_proxy' autoload :RedisStoreProxy, 'rack/attack/store_proxy/redis_store_proxy' + autoload :RedisDistributedStoreProxy, 'rack/attack/store_proxy/redis_distributed_store_proxy' autoload :Fail2Ban, 'rack/attack/fail2ban' autoload :Allow2Ban, 'rack/attack/allow2ban' autoload :Request, 'rack/attack/request' diff --git a/lib/rack/attack/store_proxy.rb b/lib/rack/attack/store_proxy.rb index 4d698538..165e85da 100644 --- a/lib/rack/attack/store_proxy.rb +++ b/lib/rack/attack/store_proxy.rb @@ -1,10 +1,10 @@ module Rack class Attack module StoreProxy - PROXIES = [DalliProxy, MemCacheProxy, RedisStoreProxy] + PROXIES = [DalliProxy, MemCacheProxy, RedisStoreProxy, RedisDistributedStoreProxy] ACTIVE_SUPPORT_WRAPPER_CLASSES = Set.new(['ActiveSupport::Cache::MemCacheStore', 'ActiveSupport::Cache::RedisStore']).freeze - ACTIVE_SUPPORT_CLIENTS = Set.new(['Redis::Store', 'Dalli::Client', 'MemCache']).freeze + ACTIVE_SUPPORT_CLIENTS = Set.new(['Redis::Store', 'Redis::DistributedStore', 'Dalli::Client', 'MemCache']).freeze def self.build(store) client = unwrap_active_support_stores(store) diff --git a/lib/rack/attack/store_proxy/redis_distributed_store_proxy.rb b/lib/rack/attack/store_proxy/redis_distributed_store_proxy.rb new file mode 100644 index 00000000..acca3f56 --- /dev/null +++ b/lib/rack/attack/store_proxy/redis_distributed_store_proxy.rb @@ -0,0 +1,40 @@ +require 'delegate' +require 'rack/attack/store_proxy/redis_store_proxy' + +module Rack + class Attack + module StoreProxy + class RedisDistributedStoreProxy < RedisStoreProxy + def self.handle?(store) + defined?(::Redis::DistributedStore) && store.is_a?(::Redis::DistributedStore) + end + + # overrride #increment to use a Lua script as Redis::Distributed + # does not support pipelining (even when all keys got to the same node) + def increment(key, amount, options={}) + evalsha(script_sha, keys: [key], argv:[amount, options[:expires_in]]) + rescue Redis::BaseError + end + + private + + def script_sha + @script_sha ||= begin + shas = script 'load', %{ + -- KEYS[1]: key to increment + -- ARGV[1]: amount to increment by + -- ARGV[2]: updated TTL if any + local value = redis.call('INCRBY', KEYS[1], tonumber(ARGV[1])) + local ttl = tonumber(ARGV[2]) + if ttl then + redis.call('EXPIRE', KEYS[1], ttl) + end + return value + } + shas.kind_of?(Array) ? shas.first : shas + end + end + end + end + end +end diff --git a/spec/integration/rack_attack_cache_spec.rb b/spec/integration/rack_attack_cache_spec.rb index 6eb27eff..f50dc2b0 100644 --- a/spec/integration/rack_attack_cache_spec.rb +++ b/spec/integration/rack_attack_cache_spec.rb @@ -24,6 +24,7 @@ def sleep_until_expired ActiveSupport::Cache::MemoryStore.new, ActiveSupport::Cache::DalliStore.new("127.0.0.1"), ActiveSupport::Cache::RedisStore.new("127.0.0.1"), + ActiveSupport::Cache::RedisStore.new(%w[127.0.0.1/1 127.0.0.1/2]), ActiveSupport::Cache::MemCacheStore.new("127.0.0.1"), Dalli::Client.new, ConnectionPool.new { Dalli::Client.new },