From fcbcc5db3a2a375ebb0d4a2f02c7dcf5c91e261e Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Thu, 21 Jan 2021 21:28:01 -0500 Subject: [PATCH 1/7] wip: adds time and cache effects --- geocode-city-api.cabal | 7 ++++- package.yaml | 1 + src/Effects.hs | 6 ++++ src/Effects/Cache.hs | 62 ++++++++++++++++++++++++++++++++++++++++++ src/Effects/Time.hs | 32 ++++++++++++++++++++++ src/Server/Auth.hs | 24 ++++++++++++++++ src/Server/Handlers.hs | 10 +++++++ src/Server/Run.hs | 40 +++++++++++++++++---------- src/Server/Types.hs | 6 ++-- 9 files changed, 171 insertions(+), 17 deletions(-) create mode 100644 src/Effects/Cache.hs create mode 100644 src/Effects/Time.hs diff --git a/geocode-city-api.cabal b/geocode-city-api.cabal index b885cc2..d5951d9 100644 --- a/geocode-city-api.cabal +++ b/geocode-city-api.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: 525eaceaf2e2d5cc83f923332a87f464440a973b44bfca60b8e8cce1368f11e3 +-- hash: 5d0596bb832674fa5ce623837a77f59bb8ffde9e470b037e2efa79949410c210 name: geocode-city-api version: 0.1.0.0 @@ -32,8 +32,10 @@ library Database.Pool Database.Queries Effects + Effects.Cache Effects.Database Effects.Log + Effects.Time Import Server.Auth Server.Handlers @@ -52,6 +54,7 @@ library , containers , envy , fused-effects >=1.1.1.0 && <1.2 + , hedis , http-api-data , http-types , lens @@ -86,6 +89,7 @@ executable geocode-city-api-exe , envy , fused-effects >=1.1.1.0 && <1.2 , geocode-city-api + , hedis , http-api-data , http-types , lens @@ -125,6 +129,7 @@ test-suite geocode-city-api-test , envy , fused-effects >=1.1.1.0 && <1.2 , geocode-city-api + , hedis , hspec , hspec-wai , hspec-wai-json diff --git a/package.yaml b/package.yaml index 9d4bf75..59825b1 100644 --- a/package.yaml +++ b/package.yaml @@ -42,6 +42,7 @@ dependencies: - resource-pool ^>= 0.2.3.2 - containers - lens +- hedis ghc-options: - -Wall diff --git a/src/Effects.hs b/src/Effects.hs index 4d135f1..8a31aee 100644 --- a/src/Effects.hs +++ b/src/Effects.hs @@ -1,9 +1,15 @@ module Effects ( module Effects.Log , module Effects.Database + , module Effects.Time + , module Effects.Cache ) where import Effects.Log import Effects.Database +-- +import Effects.Cache +-- +import Effects.Time diff --git a/src/Effects/Cache.hs b/src/Effects/Cache.hs new file mode 100644 index 0000000..02488ef --- /dev/null +++ b/src/Effects/Cache.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module Effects.Cache where + +import Control.Algebra +import Control.Carrier.Error.Either (ErrorC, Throw, runError) +import Control.Carrier.Reader +import Control.Carrier.Throw.Either (throwError) +import qualified Database.Redis as R +import Import + +--- + +data Cache (m :: Type -> Type) k where + HllAdd :: ByteString -> [ByteString] -> Cache m Integer + HllCount :: [ByteString] -> Cache m Integer + +hllAdd :: (Has Cache sig m) => ByteString -> [ByteString] -> m Integer +hllAdd key values = send $ HllAdd key values + +hllCount :: (Has Cache sig m) => [ByteString] -> m Integer +hllCount = send . HllCount + +newtype CacheIOC m a = CacheIOC {runCacheIO :: ReaderC R.Connection m a} + deriving (Applicative, Functor, Monad, MonadIO) + +newtype CacheError = CacheError R.Reply + deriving (Eq, Show) + +runCacheWithConnection :: R.Connection -> CacheIOC m hs -> m hs +runCacheWithConnection conn = runReader conn . runCacheIO + +instance + (Has (Throw CacheError) sig m, MonadIO m, Algebra sig m) => + Algebra (Cache :+: sig) (CacheIOC m) + where + alg hdl sig ctx = CacheIOC $ case sig of + L (HllAdd key values) -> do + conn <- ask + added <- liftIO $ + R.runRedis conn $ do + R.pfadd key values + (<$ ctx) <$> either (throwError . CacheError) pure added + L (HllCount keys) -> do + conn <- ask + count <- liftIO $ + R.runRedis conn $ do + R.pfcount keys + (<$ ctx) <$> either (throwError . CacheError) pure count + R other -> alg (runCacheIO . hdl) (R other) ctx + +runCacheEither :: ErrorC CacheError m a -> m (Either CacheError a) +runCacheEither = runError @CacheError diff --git a/src/Effects/Time.hs b/src/Effects/Time.hs new file mode 100644 index 0000000..4d15e6d --- /dev/null +++ b/src/Effects/Time.hs @@ -0,0 +1,32 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module Effects.Time where + +import Control.Algebra +import Data.Time (UTCTime, getCurrentTime) +import Import + +data Time (m :: Type -> Type) k where + Now :: Time m UTCTime + +now :: (Has Time sig m) => m UTCTime +now = send Now + +newtype TimeIOC m a = TimeIOC {runTimeIO :: m a} + deriving (Applicative, Functor, Monad, MonadIO) + +instance + (MonadIO m, Algebra sig m) => + Algebra (Time :+: sig) (TimeIOC m) + where + alg hdl sig ctx = case sig of + L Now -> (<$ ctx) <$> liftIO getCurrentTime + R other -> TimeIOC $ alg (runTimeIO . hdl) other ctx diff --git a/src/Server/Auth.hs b/src/Server/Auth.hs index b35ad7a..2003d29 100644 --- a/src/Server/Auth.hs +++ b/src/Server/Auth.hs @@ -33,15 +33,27 @@ import Data.Swagger securityDefinitions, ) import Servant.Swagger +import qualified Data.ByteString as BS newtype ApiKey = ApiKey Text deriving (Eq, Show) +newtype RequestID = RequestID ByteString + deriving (Eq, Show) + +data RequestKey + = ByIP String RequestID + | ByApiKey ApiKey RequestID + deriving (Eq, Show) type ApiKeyAuth = AuthHandler Request ApiKey mkApiKey :: ByteString -> ApiKey mkApiKey = ApiKey . decodeUtf8 +getLastIP :: ByteString -> Maybe ByteString +getLastIP bs = + BS.split 44 bs + & lastMaybe authHandler :: ApiKeyAuth authHandler = mkAuthHandler handler @@ -57,6 +69,18 @@ authHandler = & queryString & L.lookup "api-key" & fromMaybe Nothing + extractRequestId req = + req + & requestHeaders + & L.lookup "x-request-id" + <&> RequestID + + extractRequestIP req = + req + & requestHeaders + & L.lookup "x-forwarded-for" + <&> getLastIP + & fromMaybe Nothing handler req = either throw401 pure $ do extractApiKeyHeader req <|> extractApiKeyParam req <&> mkApiKey diff --git a/src/Server/Handlers.hs b/src/Server/Handlers.hs index ed435f4..c2b26e2 100644 --- a/src/Server/Handlers.hs +++ b/src/Server/Handlers.hs @@ -10,6 +10,8 @@ import Server.Auth import Control.Lens ((.~), (?~)) import Data.Swagger import Servant.Swagger +import Data.Time +import Effects.Cache service :: AppM sig m => ServerT Service m service = @@ -27,6 +29,8 @@ stats apiKey = validateApiKey apiKey >> do validateApiKey :: (AppM sig m) => ApiKey -> m () validateApiKey (ApiKey apiKey) = do + -- example use, not really anything to write home about yet + count <- hllAdd "a" ["b"] isValidKey <- Q.isKeyEnabled apiKey if isValidKey then pure () @@ -97,3 +101,9 @@ serializeCityResult Q.CityQ {..} = elevation = cElevation, population = cPopulation } + +hourString :: UTCTime -> String +hourString = formatTime defaultTimeLocale "%Y-%m-%dT%H" + +monthString :: UTCTime -> String +monthString = formatTime defaultTimeLocale "%Y-%m-%d" diff --git a/src/Server/Run.hs b/src/Server/Run.hs index c757c79..c05bc96 100644 --- a/src/Server/Run.hs +++ b/src/Server/Run.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} @@ -10,27 +11,22 @@ import qualified Data.Pool as P import qualified Database.Pool as DB import qualified Database.PostgreSQL.Simple as PG import Effects - ( LogStdoutC (runLogStdout), + (CacheError, runCacheWithConnection, LogStdoutC (runLogStdout), reinterpretLog, runDatabaseWithConnection, + runTimeIO ) import Import - ( Applicative (pure), - IO, - Proxy (..), - either, - ($), - (&), - ) import qualified Network.Wai.Handler.Warp as Warp -import Servant (Application, HasServer (hoistServerWithContext), serveWithContext, throwError) +import Servant (Application, HasServer (hoistServerWithContext), ServerError (errBody), err500, serveWithContext, throwError) import Server.Handlers (service) import Server.Types (proxyService) import Server.Auth (ApiKeyAuth, authContext) +import qualified Database.Redis as R -- | Build a wai app with a connection pool -application :: P.Pool PG.Connection -> Application -application pool = +application :: R.Connection -> P.Pool PG.Connection -> Application +application cacheConn pool = serveWithContext proxyService authContext $ hoistServerWithContext proxyService @@ -40,17 +36,33 @@ application pool = where transform handler = do res <- P.withResource pool runEffects - either Servant.throwError pure res + either Servant.throwError pure (handleError res) where runEffects conn = handler + & runTimeIO + & runCacheWithConnection cacheConn & runDatabaseWithConnection conn & reinterpretLog renderLogMessage & runLogStdout - & runError + & runError @ServerError + & runError @CacheError & runM +-- TODO: we can probably have a runCacheSafe interpreter somewhere +-- that allows us to handle this deeper in, in the handlers; +-- though maybe it's preferrable to handle it here? +-- At the moment, it is, since we don't really care about specific +-- cache errors in each call site. +handleError :: Either CacheError (Either ServerError a) -> Either ServerError a +handleError err = + case err of + Left _cacheError -> Servant.throwError err500 {errBody = "Cache error"} + Right s -> s + +-- | Given config, start the app start :: AppConfig -> IO () start cfg = do pool <- DB.initPool (appDatabaseUrl cfg) - Warp.run (appPort cfg) (application pool) + redis <- R.checkedConnect R.defaultConnectInfo + Warp.run (appPort cfg) (application redis pool) diff --git a/src/Server/Types.hs b/src/Server/Types.hs index 995de60..77a25de 100644 --- a/src/Server/Types.hs +++ b/src/Server/Types.hs @@ -13,7 +13,7 @@ import Import import Config (LogMessage) import Control.Carrier.Error.Either (Throw) import Data.Time (Day) -import Effects (Database, Log) +import Effects (Cache, Database, Log, Time) import Data.Aeson ( Options (fieldLabelModifier), ToJSON (toJSON), @@ -49,7 +49,9 @@ type Service = SwaggerAPI :<|> ApiRoutes type AppM sig m = ( Has (Log LogMessage) sig m, Has (Throw ServerError) sig m, - Has Database sig m + Has Database sig m, + Has Time sig m, + Has Cache sig m ) proxyService :: Proxy Service From b9560ed23d0ea7678c672e8392fd9cdcd2817eba Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Fri, 22 Jan 2021 11:33:45 -0500 Subject: [PATCH 2/7] More incomplete strides towards rate limiting: DONE: * Add new type to identify a request by either IP or Api Key * Handle IPs or Api Keys differently when it comes to usage limiting TODO: * Store quota in DB * Actually initialize redis * Don't look for IPs in dev/test? i.e. how to treat "anon" access, or lack of request ids? Right now, a dev/test request will never default to IP auth, and it will get a fake request id assigned. Maybe it should default to the remote addr method? * Set informative headers --- src/Server/Auth.hs | 115 +++++++++++++++++++++++++++++------------ src/Server/Handlers.hs | 68 ++++++++++++++++++------ src/Server/Run.hs | 1 + src/Server/Types.hs | 2 +- 4 files changed, 136 insertions(+), 50 deletions(-) diff --git a/src/Server/Auth.hs b/src/Server/Auth.hs index 2003d29..a0956d6 100644 --- a/src/Server/Auth.hs +++ b/src/Server/Auth.hs @@ -33,7 +33,7 @@ import Data.Swagger securityDefinitions, ) import Servant.Swagger -import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as B8 newtype ApiKey = ApiKey Text deriving (Eq, Show) @@ -41,54 +41,103 @@ newtype ApiKey = ApiKey Text newtype RequestID = RequestID ByteString deriving (Eq, Show) +newtype IPAddress = IPAddress ByteString + deriving (Eq, Show) + +-- | How a request is identified for rate-limiting purposes. data RequestKey - = ByIP String RequestID + = ByIP IPAddress RequestID | ByApiKey ApiKey RequestID deriving (Eq, Show) -type ApiKeyAuth = AuthHandler Request ApiKey -mkApiKey :: ByteString -> ApiKey -mkApiKey = ApiKey . decodeUtf8 +-- | Authentication (identification) by either Api Key or IP +type ApiKeyAuth = AuthHandler Request RequestKey -getLastIP :: ByteString -> Maybe ByteString -getLastIP bs = - BS.split 44 bs - & lastMaybe +-- | Given a handler, try to extract an identifier: either an API Key or +-- IP Address. A little bit of Maybe blindness here: we don't really +-- care to publicize that "anonymous" requests can be made, since the only +-- way an IP address can't be found is outside of a production env. authHandler :: ApiKeyAuth authHandler = mkAuthHandler handler where - maybeToEither e = maybe (Left e) Right throw401 msg = throwError $ err401 {errBody = msg} - extractApiKeyHeader req = - req - & requestHeaders - & L.lookup "x-geocode-city-api-key" - extractApiKeyParam req = - req - & queryString - & L.lookup "api-key" - & fromMaybe Nothing - extractRequestId req = - req - & requestHeaders - & L.lookup "x-request-id" - <&> RequestID - - extractRequestIP req = - req - & requestHeaders - & L.lookup "x-forwarded-for" - <&> getLastIP - & fromMaybe Nothing handler req = either throw401 pure $ do - extractApiKeyHeader req <|> extractApiKeyParam req - <&> mkApiKey + authWithApiKey req <|> authWithIP req & maybeToEither "Missing API key header (X-Geocode-City-Api-Key) or query param (api-key)" authContext :: Context (ApiKeyAuth ': '[]) authContext = authHandler :. EmptyContext + +--- +--- HELPERS +--- +maybeToEither :: a -> Maybe b -> Either a b +maybeToEither e = maybe (Left e) Right + +-- | Make an ApiKey +mkApiKey :: ByteString -> ApiKey +mkApiKey = ApiKey . decodeUtf8 + +-- TODO: maybe generate a random UUID? This is only really +-- useful for dev/test, when it's annoying to set `x-request-id` +-- by hand +mkRequestId :: Maybe ByteString -> Maybe RequestID +mkRequestId Nothing = Just . RequestID $ "fake-id" +mkRequestId rid = RequestID <$> rid + +getLastIP :: ByteString -> Maybe IPAddress +getLastIP bs = + B8.split ',' bs + & lastMaybe + <&> IPAddress + +-- | Find API Key in @X-Geocode-City-Api-Key@ header +extractApiKeyHeader :: Request -> Maybe ByteString +extractApiKeyHeader req = + req + & requestHeaders + & L.lookup "x-geocode-city-api-key" + +-- | Find API Key in @api-key@ querystring parameter +extractApiKeyParam :: Request -> Maybe ByteString +extractApiKeyParam req = + req + & queryString + & L.lookup "api-key" + & fromMaybe Nothing + +-- Both x-request-id and x-forwarded-for are Heroku-isms; +-- they're not reliable (or set!) in other environments +-- but we only use them to identify requests (not authenticate/authorize.) +-- | Extract request ID from `x-request-id` header. Default to +extractRequestId :: Request -> Maybe RequestID +extractRequestId req = + req + & requestHeaders + & L.lookup "x-request-id" + & mkRequestId + +extractRequestIP :: Request -> Maybe IPAddress +extractRequestIP req = + req + & requestHeaders + & L.lookup "x-forwarded-for" + <&> getLastIP + & fromMaybe Nothing +authWithApiKey :: Request -> Maybe RequestKey +authWithApiKey req = do + apiKey <- mkApiKey <$> (extractApiKeyHeader req <|> extractApiKeyParam req) + requestID <- extractRequestId req + pure $ ByApiKey apiKey requestID + +authWithIP :: Request -> Maybe RequestKey +authWithIP req = do + ip <- extractRequestIP req + requestID <- extractRequestId req + pure $ ByIP ip requestID + --- --- Swagger instances --- diff --git a/src/Server/Handlers.hs b/src/Server/Handlers.hs index c2b26e2..7e89b77 100644 --- a/src/Server/Handlers.hs +++ b/src/Server/Handlers.hs @@ -12,6 +12,7 @@ import Data.Swagger import Servant.Swagger import Data.Time import Effects.Cache +import Effects.Time (now) service :: AppM sig m => ServerT Service m service = @@ -21,40 +22,75 @@ service = :<|> search :<|> reverseGeocode -stats :: (AppM sig m) => ApiKey -> m Stats +stats :: (AppM sig m) => RequestKey -> m Stats stats apiKey = validateApiKey apiKey >> do count <- Q.cityCount update <- Q.latestUpdate return $ Stats update count -validateApiKey :: (AppM sig m) => ApiKey -> m () -validateApiKey (ApiKey apiKey) = do - -- example use, not really anything to write home about yet - count <- hllAdd "a" ["b"] - isValidKey <- Q.isKeyEnabled apiKey - if isValidKey - then pure () - else throwError err403 {errBody = "Invalid API Key."} - -autoComplete :: (AppM sig m) => ApiKey -> Text -> Maybe Int -> m [CityAutocomplete] +autoComplete :: (AppM sig m) => RequestKey -> Text -> Maybe Int -> m [CityAutocomplete] autoComplete apiKey q limit = validateApiKey apiKey >> do results <- Q.cityAutoComplete q limit return $ map serializeAutocompleteResult results -- | Search city by name -search :: (AppM sig m) => ApiKey -> Text -> Maybe Int -> m [City] +search :: (AppM sig m) => RequestKey -> Text -> Maybe Int -> m [City] search apiKey q limit = validateApiKey apiKey >> do results <- Q.citySearch q limit return $ map serializeCityResult results -- | Search city by (lat, lng) -reverseGeocode :: (AppM sig m) => ApiKey -> Latitude -> Longitude -> Maybe Int -> m [City] +reverseGeocode :: (AppM sig m) => RequestKey -> Latitude -> Longitude -> Maybe Int -> m [City] reverseGeocode apiKey lat lng limit = validateApiKey apiKey >> do results <- Q.reverseSearch (lng, lat) limit return $ map serializeCityResult results + +--- +--- Rate Limiting +--- + +-- TODO: return requests left in quota, rate limit reset timestamp + +-- | Rate limit based on api key: must be valid and under allocated quota in the current (UTC) month +validateApiKey :: (AppM sig m) => RequestKey -> m () +validateApiKey (ByApiKey (ApiKey apiKey) (RequestID requestId)) = do + currentTime <- now + let cacheKey = "requests:" <> encodeUtf8 apiKey <> encodeUtf8 (monthString currentTime) + _added <- hllAdd cacheKey [requestId] + count <- hllCount [cacheKey] + -- TODO: retrieve quota from DB + isValidKey <- Q.isKeyEnabled apiKey + if isValidKey + then + if count < 100000 then + pure () + else + throwError err429 {errBody = "Monthly request limit exceeded for API Key."} + else throwError err403 {errBody = "Invalid API Key."} + +-- | Rate-limit based on (real) IP: must be under 1000 requests in the current (UTC) day +validateApiKey (ByIP (IPAddress ipAddress) (RequestID requestId)) = do + currentTime <- now + let cacheKey = "requests:" <> ipAddress <> encodeUtf8 (dayString currentTime) + _added <- hllAdd cacheKey [requestId] + count <- hllCount [cacheKey] + if count < 1000 then + pure () + else + throwError err429 {errBody = "Daily request limit exceeded for IP Address (try using an API Key)."} + + +err429 :: ServerError +err429 = + ServerError { errHTTPCode = 429 + , errReasonPhrase = "Too Many Requests" + , errBody = "" + , errHeaders = [] + } + --- --- Swagger --- @@ -102,8 +138,8 @@ serializeCityResult Q.CityQ {..} = population = cPopulation } -hourString :: UTCTime -> String -hourString = formatTime defaultTimeLocale "%Y-%m-%dT%H" +dayString :: UTCTime -> String +dayString = formatTime defaultTimeLocale "%Y-%m-%d" monthString :: UTCTime -> String -monthString = formatTime defaultTimeLocale "%Y-%m-%d" +monthString = formatTime defaultTimeLocale "%Y-%m" diff --git a/src/Server/Run.hs b/src/Server/Run.hs index c05bc96..fe5df40 100644 --- a/src/Server/Run.hs +++ b/src/Server/Run.hs @@ -64,5 +64,6 @@ handleError err = start :: AppConfig -> IO () start cfg = do pool <- DB.initPool (appDatabaseUrl cfg) + -- TODO: get REDIS_URL from config, parse (and fail if parsing fails) redis <- R.checkedConnect R.defaultConnectInfo Warp.run (appPort cfg) (application redis pool) diff --git a/src/Server/Types.hs b/src/Server/Types.hs index 77a25de..69dfe74 100644 --- a/src/Server/Types.hs +++ b/src/Server/Types.hs @@ -62,7 +62,7 @@ proxyApi = Proxy -- | Auth type for api keys -- from: https://docs.servant.dev/en/stable/tutorial/Authentication.html#generalized-authentication -type instance AuthServerData (AuthProtect "api-key") = ApiKey +type instance AuthServerData (AuthProtect "api-key") = RequestKey --- --- REQUEST TYPES From 07b04c04737c6a598c79a6d099cdd569f504704a Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Sat, 23 Jan 2021 15:35:16 -0500 Subject: [PATCH 3/7] Solidify request ID and IP extraction * No fake request IDs: either you provide it, or you don't * Consider cases when the x-f-f header only includes one IP, and strip leading spaces (we could also split with a regex!) * Pave the way for "Dev Mode" auth: if no credentials were obtained, see if we flat out deny, or allow. --- src/Server/Auth.hs | 60 +++++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/src/Server/Auth.hs b/src/Server/Auth.hs index a0956d6..3df256e 100644 --- a/src/Server/Auth.hs +++ b/src/Server/Auth.hs @@ -34,6 +34,7 @@ import Data.Swagger ) import Servant.Swagger import qualified Data.ByteString.Char8 as B8 +import Data.Char (isSpace) newtype ApiKey = ApiKey Text deriving (Eq, Show) @@ -80,18 +81,30 @@ maybeToEither e = maybe (Left e) Right mkApiKey :: ByteString -> ApiKey mkApiKey = ApiKey . decodeUtf8 --- TODO: maybe generate a random UUID? This is only really --- useful for dev/test, when it's annoying to set `x-request-id` --- by hand -mkRequestId :: Maybe ByteString -> Maybe RequestID -mkRequestId Nothing = Just . RequestID $ "fake-id" -mkRequestId rid = RequestID <$> rid - -getLastIP :: ByteString -> Maybe IPAddress -getLastIP bs = - B8.split ',' bs - & lastMaybe - <&> IPAddress +-- | Given a bytestring containing potentially many ip addresses +-- separated by commas, get the last (or only) one. +-- +-- HEROKU SANS PROXY SPECIFIC: +-- Note that getting the last IP, in the case of multiple addresses +-- being present (due to spoofing or proxies) isn't _fully_ reliable: +-- Heroku appends the IP _it_ saw as connecting to their origin to +-- the `x-forwarded-for` header, which means the last one is likely +-- the real one even in the presence of spoofing; but if a proxy sits +-- in front of Heroku (e.g. Fastly,) then that one will end up in the +-- last position, as is canonical for the header: +-- https://stackoverflow.com/a/37061471 +-- https://en.wikipedia.org/wiki/X-Forwarded-For +-- in our case though, when most traffic will fall within the only-one-IP +-- or many-IPs-but-likely-spoofing scenario, this is good enough. +mkIpAddress :: ByteString -> Maybe IPAddress +mkIpAddress bs = do + let addresses = B8.split ',' bs + case addresses of + [ip] -> pure . IPAddress $ ip + l@(_:_ips) -> do + lastIP <- lastMaybe l + pure . IPAddress $ B8.dropWhile isSpace lastIP + _ -> Nothing -- | Find API Key in @X-Geocode-City-Api-Key@ header extractApiKeyHeader :: Request -> Maybe ByteString @@ -108,30 +121,39 @@ extractApiKeyParam req = & L.lookup "api-key" & fromMaybe Nothing --- Both x-request-id and x-forwarded-for are Heroku-isms; --- they're not reliable (or set!) in other environments --- but we only use them to identify requests (not authenticate/authorize.) --- | Extract request ID from `x-request-id` header. Default to + +-- | Extract request ID from `x-request-id` header: +-- https://devcenter.heroku.com/articles/http-request-id extractRequestId :: Request -> Maybe RequestID extractRequestId req = req & requestHeaders & L.lookup "x-request-id" - & mkRequestId + <&> RequestID +-- | The request IP is the "real" IP as populated by Heroku: +-- https://devcenter.heroku.com/articles/http-routing#heroku-headers +-- there's also [remoteHost](https://hackage.haskell.org/package/wai-3.2.3/docs/Network-Wai.html#v:remoteHost) +-- but that's unusable in production. extractRequestIP :: Request -> Maybe IPAddress extractRequestIP req = req & requestHeaders & L.lookup "x-forwarded-for" - <&> getLastIP + <&> mkIpAddress & fromMaybe Nothing + +-- | Attempt to extract the api key from the request: either from our custom header, or the querystring. authWithApiKey :: Request -> Maybe RequestKey authWithApiKey req = do apiKey <- mkApiKey <$> (extractApiKeyHeader req <|> extractApiKeyParam req) requestID <- extractRequestId req pure $ ByApiKey apiKey requestID +-- | Identify a request by IP. This is a very Heroku-centric approach: we rely +-- on the x-request-id and x-forwarded-for headers; we _could_ use +-- wai's `remoteHost` as a fallback, but it's unrealistic in the current +-- deployment environment: it's either populated upstream, or not. authWithIP :: Request -> Maybe RequestKey authWithIP req = do ip <- extractRequestIP req @@ -158,4 +180,4 @@ instance HasSwagger api => HasSwagger (AuthProtect "api-key" :> api) where mkSec = id securityScheme = SecurityScheme type_ (Just desc) type_ = SecuritySchemeApiKey (ApiKeyParams "api-key" ApiKeyQuery) - desc = "JSON Web Token-based API key (can also be provided in the X-Geocode-City-Api-Key header)" + desc = "JSON Web Token-based API key (can also be provided in the X-Geocode-City-Api-Key header.) If omitted, the client IP will be used for rate limiting." From 136da898a2664495ee016c8ce952c3cf2b449914 Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Sat, 23 Jan 2021 15:59:00 -0500 Subject: [PATCH 4/7] Allow "unlimited" anon access, if configured By default, this isn't allowed in production. --- src/Config.hs | 15 +++++++++++++++ src/Server/Auth.hs | 19 ++++++++++++++----- src/Server/Handlers.hs | 10 +++++++--- src/Server/Run.hs | 22 +++++++++++++--------- 4 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/Config.hs b/src/Config.hs index d482ac6..8e58e19 100644 --- a/src/Config.hs +++ b/src/Config.hs @@ -13,6 +13,9 @@ import System.Envy defOption, gFromEnvCustom, ) +import qualified Data.Pool as P +import qualified Database.PostgreSQL.Simple as PG +import qualified Database.Redis as R data Environment = Development @@ -20,6 +23,10 @@ data Environment | Production deriving stock (Eq, Show, Enum, Read) +data AnonAccess + = AlwaysDenyAnon + | AlwaysAllowAnon + deriving stock (Eq, Show) instance Var Environment where toVar = show fromVar = readMaybe @@ -40,6 +47,14 @@ instance FromEnv AppConfig where -- drop the `app*` prefix that e.g. Heroku will add: fromEnv = gFromEnvCustom defOption {dropPrefixCount = 3} +-- opaque "env" to carry/specify runtime dependencies. +data AppContext = AppContext + { ctxRedisConnection :: !R.Connection, + ctxDatabasePool :: P.Pool PG.Connection, + ctxAnonAccess :: !AnonAccess + } + +-- | Default app config. Override with environment variables. defaultConfig :: AppConfig defaultConfig = AppConfig diff --git a/src/Server/Auth.hs b/src/Server/Auth.hs index 3df256e..91cb0d8 100644 --- a/src/Server/Auth.hs +++ b/src/Server/Auth.hs @@ -35,6 +35,7 @@ import Data.Swagger import Servant.Swagger import qualified Data.ByteString.Char8 as B8 import Data.Char (isSpace) +import Config (AnonAccess (..)) newtype ApiKey = ApiKey Text deriving (Eq, Show) @@ -49,6 +50,7 @@ newtype IPAddress = IPAddress ByteString data RequestKey = ByIP IPAddress RequestID | ByApiKey ApiKey RequestID + | WithUnlimitedAccess deriving (Eq, Show) -- | Authentication (identification) by either Api Key or IP @@ -58,17 +60,17 @@ type ApiKeyAuth = AuthHandler Request RequestKey -- IP Address. A little bit of Maybe blindness here: we don't really -- care to publicize that "anonymous" requests can be made, since the only -- way an IP address can't be found is outside of a production env. -authHandler :: ApiKeyAuth -authHandler = +authHandler :: AnonAccess -> ApiKeyAuth +authHandler anonCriterion = mkAuthHandler handler where throw401 msg = throwError $ err401 {errBody = msg} handler req = either throw401 pure $ do - authWithApiKey req <|> authWithIP req + authWithApiKey req <|> authWithIP req <|> authAnon anonCriterion & maybeToEither "Missing API key header (X-Geocode-City-Api-Key) or query param (api-key)" -authContext :: Context (ApiKeyAuth ': '[]) -authContext = authHandler :. EmptyContext +authContext :: AnonAccess -> Context (ApiKeyAuth ': '[]) +authContext aa = authHandler aa :. EmptyContext --- @@ -160,6 +162,13 @@ authWithIP req = do requestID <- extractRequestId req pure $ ByIP ip requestID +-- | Given an @AnonAccess@ criterion, give them "unkeyed" access, +-- or deny. Useful for dev/test, or if you want to deploy +-- without api key/rate limiting. +authAnon :: AnonAccess -> Maybe RequestKey +authAnon AlwaysAllowAnon = Just WithUnlimitedAccess +authAnon AlwaysDenyAnon = Nothing + --- --- Swagger instances --- diff --git a/src/Server/Handlers.hs b/src/Server/Handlers.hs index 7e89b77..7bafc0c 100644 --- a/src/Server/Handlers.hs +++ b/src/Server/Handlers.hs @@ -8,11 +8,11 @@ import Database.Queries as Q import Control.Carrier.Error.Either (throwError) import Server.Auth import Control.Lens ((.~), (?~)) -import Data.Swagger +import Data.Swagger hiding (Info) import Servant.Swagger import Data.Time -import Effects.Cache -import Effects.Time (now) +import Config (LogMessage (..)) +import Effects (hllAdd, hllCount, log, now) service :: AppM sig m => ServerT Service m service = @@ -82,6 +82,10 @@ validateApiKey (ByIP (IPAddress ipAddress) (RequestID requestId)) = do else throwError err429 {errBody = "Daily request limit exceeded for IP Address (try using an API Key)."} +-- | If given "unlimited access", don't do any rate limiting. +validateApiKey WithUnlimitedAccess = do + log $ Info "Request without rate limiting" + pure () err429 :: ServerError err429 = diff --git a/src/Server/Run.hs b/src/Server/Run.hs index fe5df40..d811fa4 100644 --- a/src/Server/Run.hs +++ b/src/Server/Run.hs @@ -4,12 +4,11 @@ module Server.Run where -import Config (AppConfig (..), renderLogMessage) +import Config (AnonAccess (..), AppConfig (..), AppContext (..), Environment (Production), renderLogMessage) import Control.Carrier.Error.Either (runError) import Control.Carrier.Lift (runM) import qualified Data.Pool as P import qualified Database.Pool as DB -import qualified Database.PostgreSQL.Simple as PG import Effects (CacheError, runCacheWithConnection, LogStdoutC (runLogStdout), reinterpretLog, @@ -25,9 +24,9 @@ import Server.Auth (ApiKeyAuth, authContext) import qualified Database.Redis as R -- | Build a wai app with a connection pool -application :: R.Connection -> P.Pool PG.Connection -> Application -application cacheConn pool = - serveWithContext proxyService authContext $ +application :: AppContext -> Application +application appCtx = + serveWithContext proxyService (authContext (ctxAnonAccess appCtx)) $ hoistServerWithContext proxyService (Proxy :: Proxy '[ApiKeyAuth]) @@ -35,13 +34,13 @@ application cacheConn pool = service where transform handler = do - res <- P.withResource pool runEffects + res <- P.withResource (ctxDatabasePool appCtx) runEffects either Servant.throwError pure (handleError res) where runEffects conn = handler & runTimeIO - & runCacheWithConnection cacheConn + & runCacheWithConnection (ctxRedisConnection appCtx) & runDatabaseWithConnection conn & reinterpretLog renderLogMessage & runLogStdout @@ -65,5 +64,10 @@ start :: AppConfig -> IO () start cfg = do pool <- DB.initPool (appDatabaseUrl cfg) -- TODO: get REDIS_URL from config, parse (and fail if parsing fails) - redis <- R.checkedConnect R.defaultConnectInfo - Warp.run (appPort cfg) (application redis pool) + redis <- R.checkedConnect R.defaultConnectInfo + let env = AppContext { + ctxRedisConnection = redis + , ctxDatabasePool = pool + , ctxAnonAccess = if Production == appDeployEnv cfg then AlwaysDenyAnon else AlwaysAllowAnon + } + Warp.run (appPort cfg) (application env) From 03eaf3ec9460c65116e20515c768f0fbb5ff91e5 Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Sat, 23 Jan 2021 15:59:36 -0500 Subject: [PATCH 5/7] Remove outdated TODO --- src/Database/Migrations.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Database/Migrations.hs b/src/Database/Migrations.hs index 6f83b57..8dd580b 100644 --- a/src/Database/Migrations.hs +++ b/src/Database/Migrations.hs @@ -5,8 +5,6 @@ import Database.PostgreSQL.Simple (connectPostgreSQL, withTransaction) import Database.PostgreSQL.Simple.Migration (MigrationCommand (..), MigrationContext (..), MigrationResult (..), runMigration) import Import --- TODO(luis) we may want to take a Bool parameter to send --- in `MigrationContext`: right now it defaults to verbose. runMigrations' :: Bool -> FilePath -> DatabaseUrl -> IO (Either String String) runMigrations' isVerbose migrationsDir (DatabaseUrl conStr) = do con <- connectPostgreSQL $ encodeUtf8 conStr From 8e3b369207d93deea63ce2cc16a58938a8f7ed01 Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Sat, 23 Jan 2021 16:09:35 -0500 Subject: [PATCH 6/7] Get REDIS_URL from environment. --- src/Config.hs | 11 +++++++++-- src/Server/Run.hs | 9 +++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/Config.hs b/src/Config.hs index 8e58e19..2eac53b 100644 --- a/src/Config.hs +++ b/src/Config.hs @@ -35,11 +35,15 @@ newtype DatabaseUrl = DatabaseUrl Text deriving newtype (Eq, Show) deriving (Var) via Text +newtype RedisUrl = RedisUrl String + deriving newtype (Eq, Show) + deriving (Var) via String -- | Configuration as it comes from the environment; flat, static. data AppConfig = AppConfig { appPort :: !Int, appDeployEnv :: !Environment, - appDatabaseUrl :: !DatabaseUrl + appDatabaseUrl :: !DatabaseUrl, + appRedisUrl :: !RedisUrl } deriving stock (Eq, Show, Generic) @@ -60,7 +64,10 @@ defaultConfig = AppConfig { appPort = 3000, appDeployEnv = Development, - appDatabaseUrl = DatabaseUrl "postgresql://localhost/geocode_city_dev?user=luis" + appDatabaseUrl = DatabaseUrl "postgresql://localhost/geocode_city_dev?user=luis", + -- the underlying lib can parse the right stuff here: + -- https://hackage.haskell.org/package/hedis-0.14.1/docs/Database-Redis.html#v:parseConnectInfo + appRedisUrl = RedisUrl "redis://" } -- | Log levels diff --git a/src/Server/Run.hs b/src/Server/Run.hs index d811fa4..1a37e36 100644 --- a/src/Server/Run.hs +++ b/src/Server/Run.hs @@ -4,7 +4,7 @@ module Server.Run where -import Config (AnonAccess (..), AppConfig (..), AppContext (..), Environment (Production), renderLogMessage) +import Config (AnonAccess (..), AppConfig (..), AppContext (..), Environment (Production), RedisUrl (..), renderLogMessage) import Control.Carrier.Error.Either (runError) import Control.Carrier.Lift (runM) import qualified Data.Pool as P @@ -48,7 +48,7 @@ application appCtx = & runError @CacheError & runM --- TODO: we can probably have a runCacheSafe interpreter somewhere +-- FIXME: we can probably have a runCacheSafe interpreter somewhere -- that allows us to handle this deeper in, in the handlers; -- though maybe it's preferrable to handle it here? -- At the moment, it is, since we don't really care about specific @@ -63,8 +63,9 @@ handleError err = start :: AppConfig -> IO () start cfg = do pool <- DB.initPool (appDatabaseUrl cfg) - -- TODO: get REDIS_URL from config, parse (and fail if parsing fails) - redis <- R.checkedConnect R.defaultConnectInfo + redis <- case R.parseConnectInfo (appRedisUrl cfg & un) of + Left e -> fail e + Right connectInfo -> R.connect connectInfo let env = AppContext { ctxRedisConnection = redis , ctxDatabasePool = pool From 63dcc667babfecab11073f4dddaaa249c7e4cd7d Mon Sep 17 00:00:00 2001 From: Luis Borjas Reyes Date: Sat, 23 Jan 2021 18:45:02 -0500 Subject: [PATCH 7/7] Incorporate rate limiting in request handlers Also, introduces the `monthly_quota` nullable column for api keys, defaults to 100,000. --- migrations/202101231600_api_quotas.sql | 2 + src/Database/Queries.hs | 11 +-- src/Server/Handlers.hs | 120 +++++++++++++++++-------- src/Server/Types.hs | 36 ++++++-- 4 files changed, 118 insertions(+), 51 deletions(-) create mode 100644 migrations/202101231600_api_quotas.sql diff --git a/migrations/202101231600_api_quotas.sql b/migrations/202101231600_api_quotas.sql new file mode 100644 index 0000000..d239bc6 --- /dev/null +++ b/migrations/202101231600_api_quotas.sql @@ -0,0 +1,2 @@ +alter table account.api_key + add column if not exists monthly_quota bigint default 100000; diff --git a/src/Database/Queries.hs b/src/Database/Queries.hs index 7179a7a..e6e059d 100644 --- a/src/Database/Queries.hs +++ b/src/Database/Queries.hs @@ -56,11 +56,12 @@ latestUpdate = do updatedAts <- query_ "select max(modification) from raw.geonames" pure $ fromOnly =<< listToMaybe updatedAts --- | Given an API Key, find out if it exists and is enabled. -isKeyEnabled :: Has Database sig m => Text -> m Bool -isKeyEnabled key = do - exists <- query "select is_enabled from account.api_key where key = ?" (Only key) - pure $ maybe False fromOnly (listToMaybe exists) +-- | Given an API Key, find out if it exists and is enabled; +-- return status and current quota. +findApiKey :: Has Database sig m => Text -> m (Bool, Maybe Integer) +findApiKey key = do + exists <- query "select is_enabled, monthly_quota from account.api_key where key = ?" (Only key) + pure $ fromMaybe (False, Just 0) (listToMaybe exists) -- | Fast query for name autocomplete: biased towards more populous cities, diff --git a/src/Server/Handlers.hs b/src/Server/Handlers.hs index 7bafc0c..edeb96a 100644 --- a/src/Server/Handlers.hs +++ b/src/Server/Handlers.hs @@ -13,6 +13,8 @@ import Servant.Swagger import Data.Time import Config (LogMessage (..)) import Effects (hllAdd, hllCount, log, now) +import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) +import qualified Network.HTTP.Types as N service :: AppM sig m => ServerT Service m service = @@ -22,70 +24,89 @@ service = :<|> search :<|> reverseGeocode -stats :: (AppM sig m) => RequestKey -> m Stats -stats apiKey = validateApiKey apiKey >> do +stats :: (AppM sig m) => RequestKey -> m (RateLimited Stats) +stats apiKey = do + rateLimitInfo <- checkUsage apiKey count <- Q.cityCount update <- Q.latestUpdate - return $ Stats update count + return $ addRateLimitHeaders rateLimitInfo $ Stats update count -autoComplete :: (AppM sig m) => RequestKey -> Text -> Maybe Int -> m [CityAutocomplete] -autoComplete apiKey q limit = - validateApiKey apiKey >> do - results <- Q.cityAutoComplete q limit - return $ map serializeAutocompleteResult results +autoComplete :: (AppM sig m) => RequestKey -> Text -> Maybe Int -> m (RateLimited [CityAutocomplete]) +autoComplete apiKey q limit = do + rateLimitInfo <- checkUsage apiKey + results <- Q.cityAutoComplete q limit + return $ addRateLimitHeaders rateLimitInfo $ map serializeAutocompleteResult results -- | Search city by name -search :: (AppM sig m) => RequestKey -> Text -> Maybe Int -> m [City] -search apiKey q limit = - validateApiKey apiKey >> do - results <- Q.citySearch q limit - return $ map serializeCityResult results +search :: (AppM sig m) => RequestKey -> Text -> Maybe Int -> m (RateLimited [City]) +search apiKey q limit = do + rateLimitInfo <- checkUsage apiKey + results <- Q.citySearch q limit + return $ addRateLimitHeaders rateLimitInfo $ map serializeCityResult results -- | Search city by (lat, lng) -reverseGeocode :: (AppM sig m) => RequestKey -> Latitude -> Longitude -> Maybe Int -> m [City] -reverseGeocode apiKey lat lng limit = - validateApiKey apiKey >> do - results <- Q.reverseSearch (lng, lat) limit - return $ map serializeCityResult results +reverseGeocode :: (AppM sig m) => RequestKey -> Latitude -> Longitude -> Maybe Int -> m (RateLimited [City]) +reverseGeocode apiKey lat lng limit = do + rateLimitInfo <- checkUsage apiKey + results <- Q.reverseSearch (lng, lat) limit + return $ addRateLimitHeaders rateLimitInfo $ map serializeCityResult results --- --- Rate Limiting --- - --- TODO: return requests left in quota, rate limit reset timestamp - -- | Rate limit based on api key: must be valid and under allocated quota in the current (UTC) month -validateApiKey :: (AppM sig m) => RequestKey -> m () -validateApiKey (ByApiKey (ApiKey apiKey) (RequestID requestId)) = do +-- note that a given key may not have any quota set: we currently interpret that +-- scenario as "unlimited" and return headers indicating that, just like `WithUnlimitedAccess` +checkUsage :: (AppM sig m) => RequestKey -> m RateLimitInfo +checkUsage (ByApiKey (ApiKey apiKey) (RequestID requestId)) = do currentTime <- now - let cacheKey = "requests:" <> encodeUtf8 apiKey <> encodeUtf8 (monthString currentTime) + let cacheKey = "requests:" <> encodeUtf8 apiKey <> ":" <> encodeUtf8 (monthString currentTime) _added <- hllAdd cacheKey [requestId] count <- hllCount [cacheKey] - -- TODO: retrieve quota from DB - isValidKey <- Q.isKeyEnabled apiKey + (isValidKey, quota) <- Q.findApiKey apiKey + let limit = fromMaybe 0 quota + remaining = max 0 (limit - count) + resetsAt = case quota of + Just _ -> nextMonthStart currentTime + Nothing -> currentTime + rateLimitInfo = RateLimitInfo limit remaining (resetsAt & utcTimeToPOSIXSeconds) if isValidKey then - if count < 100000 then - pure () + -- no quota means unlimited (!) + if maybe True (count <) quota then + pure rateLimitInfo else - throwError err429 {errBody = "Monthly request limit exceeded for API Key."} + throwError + err429 + { errBody = "Monthly request limit exceeded for API Key.", + errHeaders = rateLimitHeaders rateLimitInfo + } else throwError err403 {errBody = "Invalid API Key."} -- | Rate-limit based on (real) IP: must be under 1000 requests in the current (UTC) day -validateApiKey (ByIP (IPAddress ipAddress) (RequestID requestId)) = do +checkUsage (ByIP (IPAddress ipAddress) (RequestID requestId)) = do currentTime <- now - let cacheKey = "requests:" <> ipAddress <> encodeUtf8 (dayString currentTime) + let cacheKey = "requests:" <> ipAddress <> ":" <> encodeUtf8 (dayString currentTime) _added <- hllAdd cacheKey [requestId] count <- hllCount [cacheKey] - if count < 1000 then - pure () + let limit = 1000 + remaining = max 0 (limit - count) + resetsAt = nextDayStart currentTime + rateLimitInfo = RateLimitInfo limit remaining (resetsAt & utcTimeToPOSIXSeconds) + if count < limit then + pure rateLimitInfo else - throwError err429 {errBody = "Daily request limit exceeded for IP Address (try using an API Key)."} + throwError + err429 + { errBody = "Daily request limit exceeded for IP Address (try using an API Key).", + errHeaders = rateLimitHeaders rateLimitInfo + } -- | If given "unlimited access", don't do any rate limiting. -validateApiKey WithUnlimitedAccess = do - log $ Info "Request without rate limiting" - pure () +checkUsage WithUnlimitedAccess = do + log $ Info "Request without rate limiting!" + currentTime <- now + pure $ RateLimitInfo 0 0 (currentTime & utcTimeToPOSIXSeconds) err429 :: ServerError err429 = @@ -95,6 +116,21 @@ err429 = , errHeaders = [] } +rateLimitHeaders :: RateLimitInfo -> [N.Header] +rateLimitHeaders RateLimitInfo {..} = + [ ("X-RateLimit-Limit", encodeUtf8 . showString $ rateLimitTotal), + ("X-RateLimit-Remaining", encodeUtf8 . showString $ rateLimitRemaining), + ("X-RateLimit-Resets", encodeUtf8 . showString $ rateLimitResets) + ] + where + showString :: Show a => a -> String + showString = show + +addRateLimitHeaders :: RateLimitInfo -> a -> RateLimited a +addRateLimitHeaders RateLimitInfo {..} = + addHeader rateLimitTotal + . addHeader rateLimitRemaining + . addHeader rateLimitResets --- --- Swagger --- @@ -147,3 +183,13 @@ dayString = formatTime defaultTimeLocale "%Y-%m-%d" monthString :: UTCTime -> String monthString = formatTime defaultTimeLocale "%Y-%m" + +nextDayStart :: UTCTime -> UTCTime +nextDayStart (UTCTime day _time) = UTCTime (addDays 1 day) 0 + +nextMonthStart :: UTCTime -> UTCTime +nextMonthStart (UTCTime day _time) = + UTCTime monthStart 0 + where + (y, m, _d) = toGregorian $ addGregorianMonthsClip 1 day + monthStart = fromGregorian y m 1 diff --git a/src/Server/Types.hs b/src/Server/Types.hs index 69dfe74..beb3d20 100644 --- a/src/Server/Types.hs +++ b/src/Server/Types.hs @@ -20,28 +20,29 @@ import Data.Aeson defaultOptions, genericToJSON, ) -import Servant (AuthProtect, FromHttpApiData (parseUrlPiece), Get, JSON, QueryParam, QueryParam', Required, ServerError, Strict, (:<|>), type (:>)) -import Servant.Server.Experimental.Auth -import Server.Auth -import Data.Swagger hiding (SchemaOptions (fieldLabelModifier)) +import Data.Swagger (Swagger, ToParamSchema, ToSchema) +import Data.Time.Clock.POSIX (POSIXTime) +import Servant (AuthProtect, FromHttpApiData (parseUrlPiece), Get, Header, Headers, JSON, QueryParam, QueryParam', Required, ServerError, Strict, (:<|>), type (:>)) +import Servant.Server.Experimental.Auth (AuthServerData) +import Server.Auth (RequestKey) type StrictParam = QueryParam' '[Required, Strict] type ApiKeyProtect = AuthProtect "api-key" type ApiRoutes = - ApiKeyProtect :> "stats" :> Get '[JSON] Stats + ApiKeyProtect :> "stats" :> Get '[JSON] (RateLimited Stats) :<|> ApiKeyProtect :> "autocomplete" :> StrictParam "q" Text :> QueryParam "limit" Int - :> Get '[JSON] [CityAutocomplete] + :> Get '[JSON] (RateLimited [CityAutocomplete]) :<|> ApiKeyProtect :> "search" :> StrictParam "name" Text :> QueryParam "limit" Int - :> Get '[JSON] [City] + :> Get '[JSON] (RateLimited [City]) :<|> ApiKeyProtect :> "locationSearch" :> StrictParam "lat" Latitude :> StrictParam "lng" Longitude :> QueryParam "limit" Int - :> Get '[JSON] [City] + :> Get '[JSON] (RateLimited [City]) type SwaggerAPI = "swagger.json" :> Get '[JSON] Swagger @@ -60,10 +61,27 @@ proxyService = Proxy proxyApi :: Proxy ApiRoutes proxyApi = Proxy +--- +--- INTERNAL TYPES +--- -- | Auth type for api keys -- from: https://docs.servant.dev/en/stable/tutorial/Authentication.html#generalized-authentication type instance AuthServerData (AuthProtect "api-key") = RequestKey - +type RateLimited a = + Headers + '[ Header "X-RateLimit-Limit" Integer, + Header "X-RateLimit-Remaining" Integer, + Header "X-RateLimit-Resets" POSIXTime + ] + a + +-- Inspired by Github: +-- https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting +data RateLimitInfo = RateLimitInfo + { rateLimitTotal :: Integer, + rateLimitRemaining :: Integer, + rateLimitResets :: POSIXTime + } --- --- REQUEST TYPES ---