Skip to content

Commit

Permalink
Gracefully handle Auth API rate-limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
ysyrota committed Oct 2, 2024
1 parent bbec1e2 commit 43f1d95
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 10 deletions.
47 changes: 42 additions & 5 deletions lib/duo.c
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,47 @@ _duo_json_response(struct duo_ctx *ctx) {
return code;
}

int
_duo_https_exchange(struct duo_ctx *ctx, const char *method, const char *uri, int msecs, int *code)
{
const int max_int_digits = (241 * sizeof(int) / 100 + 1);
const int max_backoff_wait_secs = 32;
const int initial_backof_wait_secs = 1;
const int backoff_factor = 2;

static const char fmt[] = "Rate-limiting response received from server. Waiting for %d seconds before retrying.";
char msg[(sizeof fmt) + max_int_digits];
int wait_secs = initial_backof_wait_secs;

while (1) {
HTTPScode rc;
time_t retry_after;

rc = https_send(ctx->https, method, uri,
ctx->argc, ctx->argv, ctx->ikey, ctx->skey, ctx->useragent);
if (rc != HTTPS_OK)
return rc;
rc = https_recv(ctx->https, code, &ctx->body, &ctx->body_len, &retry_after, msecs);
if (retry_after != (time_t)-1)
wait_secs = retry_after - time(NULL);

if (rc != HTTPS_OK || *code != 429 || wait_secs > max_backoff_wait_secs)
return rc;

struct timespec timeout = {
.tv_sec = wait_secs,
.tv_nsec = (float)rand() / RAND_MAX * 1000000000
};

snprintf(msg, sizeof(msg), fmt, timeout.tv_sec);
if (ctx->conv_status)
ctx->conv_status(NULL, msg);
nanosleep(&timeout, NULL);
if (retry_after == (time_t)-1)
wait_secs *= backoff_factor;
}
}

static duo_code_t
duo_call(struct duo_ctx *ctx, const char *method, const char *uri, int msecs)
{
Expand All @@ -361,12 +402,8 @@ duo_call(struct duo_ctx *ctx, const char *method, const char *uri, int msecs)
}
break;
}
if ((err = https_send(ctx->https, method, uri,
ctx->argc, ctx->argv, ctx->ikey, ctx->skey, ctx->useragent)) == HTTPS_OK &&
(err = https_recv(ctx->https, &code,
&ctx->body, &ctx->body_len, msecs)) == HTTPS_OK) {
if (_duo_https_exchange(ctx, method, uri, msecs, &code) == HTTPS_OK)
break;
}
https_close(&ctx->https);
}
duo_reset(ctx);
Expand Down
98 changes: 97 additions & 1 deletion lib/https.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ struct https_ctx {

struct https_ctx ctx;

typedef enum
{
CB_NONE = 0, /* First callback*/
CB_KEY, /* Last was key */
CB_VAL /* Last was value */
} callback_status_t;

struct https_request {
BIO *cbio;
BIO *body;
Expand All @@ -69,6 +76,14 @@ struct https_request {

int sigpipe_ignored;
struct sigaction old_sigpipe;

time_t retry_after;

char *value;
size_t value_size;
char* key; /* current header name */
size_t key_size; /* size of header name */
callback_status_t last_cb;
};

static int
Expand All @@ -79,15 +94,92 @@ __on_body(http_parser *p, const char *buf, size_t len)
return (BIO_write(req->body, buf, len) != len);
}

time_t
_parse_retry_after(const char *header_value)
{
if (header_value == NULL) {
return (time_t)-1;
}

/* Try to parse as an integer (delay in seconds) */
char *endptr;
long delay_seconds = strtol(header_value, &endptr, 10);
if (*endptr == '\0') {
return time(NULL) + delay_seconds;
}

/* Try to parse as a date */
struct tm tm;
memset(&tm, 0, sizeof(struct tm));
if (strptime(header_value, "%a, %d %b %Y %H:%M:%S %Z", &tm) != NULL) {
return mktime(&tm);
}

return (time_t)-1;
}

static int
__on_message_complete(http_parser *p)
{
struct https_request *req = (struct https_request *)p->data;

req->retry_after = _parse_retry_after(req->value);

free(req->value);
req->value = NULL;
req->value_size = 0;
free(req->key);
req->key = NULL;
req->key_size = 0;
req->last_cb = CB_NONE;

req->done = 1;
return (0);
}

static const char retry_after_header[] = "Retry-After";
static const char x_retry_after_header[] = "X-Retry-After";

static int
__on_header_field(http_parser* p, const char* at, size_t length)
{
struct https_request *client = p->data;

if (client->last_cb == CB_VAL)
client->key_size = 0;

client->key = realloc(client->key, client->key_size + length + 1);
memcpy(client->key + client->key_size, at, length);
client->key_size += length;
client->key[client->key_size] = 0;

client->last_cb = CB_KEY;

return 0;
}

static int
__on_header_value(http_parser* p, const char* at, size_t length)
{
struct https_request *client = p->data;

if (strcasecmp(client->key, retry_after_header) == 0
|| strcasecmp(client->key, x_retry_after_header) == 0)
{
if (client->last_cb != CB_VAL)
client->value_size = 0;

client->value = realloc(client->value, client->value_size + length + 1);
memcpy(client->value + client->value_size, at, length);
client->value_size += length;
client->value[client->value_size] = 0;
}

client->last_cb = CB_VAL;

return 0;
}

static const char *
_SSL_strerror(void)
{
Expand Down Expand Up @@ -470,6 +562,8 @@ https_init(const char *cafile, const char *http_proxy)
/* Set HTTP parser callbacks */
ctx.parse_settings.on_body = __on_body;
ctx.parse_settings.on_message_complete = __on_message_complete;
ctx.parse_settings.on_header_field = __on_header_field;
ctx.parse_settings.on_header_value = __on_header_value;

return (0);
}
Expand Down Expand Up @@ -740,7 +834,7 @@ https_send(struct https_request *req, const char *method, const char *uri,

HTTPScode
https_recv(struct https_request *req, int *code, const char **body, int *len,
int msecs)
time_t *retry_after, int msecs)
{
int n, err;

Expand All @@ -765,6 +859,8 @@ https_recv(struct https_request *req, int *code, const char **body, int *len,
}
*len = BIO_get_mem_data(req->body, (char **)body);
*code = req->parser->status_code;
if (retry_after)
*retry_after = req->retry_after;

return (HTTPS_OK);
}
Expand Down
1 change: 1 addition & 0 deletions lib/https.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ HTTPScode https_recv(
int *code,
const char **body,
int *length,
time_t *retry_after,
int msecs
);

Expand Down
30 changes: 29 additions & 1 deletion tests/common_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import os
import subprocess
import time
import unittest
import sys

Expand Down Expand Up @@ -325,6 +326,33 @@ def test_preauth_allow_bad_response(self):
"preauth-allow-bad_response", "JSON missing valid 'status'"
)

def test_preauth_allow_retry_after(self):
start_time = time.time()
self.check_preauth_state(
"retry-after-3-preauth-allow", "preauth-allowed", prefix="Skipped"
)
execution_time = time.time() - start_time
# 3.x seconds executed twice
self.assertGreater(execution_time, 6)

def test_preauth_allow_retry_after_date(self):
start_time = time.time()
self.check_preauth_state(
"retry-after-date-preauth-allow", "preauth-allowed", prefix="Skipped"
)
execution_time = time.time() - start_time
# 3.x seconds executed twice
self.assertGreater(execution_time, 6)

def test_preauth_allow_rate_limited(self):
start_time = time.time()
self.check_preauth_state(
"rate-limited-preauth-allow", "preauth-allowed", prefix="Skipped"
)
execution_time = time.time() - start_time
# 1.x seconds + 2.x seconds executed twice
self.assertGreater(execution_time, 6)

class Hosts(CommonTestCase):
def run(self, result=None):
with MockDuo(NORMAL_CERT):
Expand Down Expand Up @@ -538,7 +566,7 @@ def test_configuration_with_extra_space(self):
)

class Interactive(CommonTestCase):
PROMPT_REGEX = ".* or option \(1-4\): $"
PROMPT_REGEX = ".* or option \\(1-4\\): $"
PROMPT_TEXT = [
"Duo login for foobar",
"Choose or lose:",
Expand Down
34 changes: 33 additions & 1 deletion tests/mockduo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class MockDuoHandler(BaseHTTPRequestHandler):
server_version = "MockDuo/1.0"
protocol_version = "HTTP/1.1"

def __init__(self, *args, **kwargs):
self._rl_req_clock = 0
self._rl_req_num = 0
super().__init__(*args, **kwargs)

def _verify_sig(self):
authz = base64.b64decode(self.headers["Authorization"].split()[1]).decode(
"utf-8"
Expand Down Expand Up @@ -119,9 +124,12 @@ def _get_tx_response(self, txid, is_async):
time.sleep(int(secs))
return rsp

def _send(self, code, buf=b""):
def _send(self, code, buf=b"", headers={}):
self.send_response(code)
self.send_header("Content-length", str(len(buf)))
if headers:
for key, value in headers.items():
self.send_header(key, value)
if buf:
self.send_header("Content-type", "application/json")
self.end_headers()
Expand Down Expand Up @@ -230,6 +238,30 @@ def do_POST(self):
ret["response"] = {"result": "enroll", "status": "please enroll"}
elif self.args["user"] == "bad-json":
buf = b""
elif self.args["user"] == "retry-after-3-preauth-allow":
if self._rl_req_num == 0:
self._rl_req_num = 1
return self._send(429, headers={"X-Retry-After": "3"})
else:
self._rl_req_num = 0
ret["response"] = {"result": "allow", "status": "preauth-allowed"}
elif self.args["user"] == "retry-after-date-preauth-allow":
if self._rl_req_num == 0:
self._rl_req_num = 1
timestr = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime(time.time()+3))
return self._send(429, headers={"Retry-After": timestr})
else:
self._rl_req_num = 0
ret["response"] = {"result": "allow", "status": "preauth-allowed"}
elif self.args["user"] == "rate-limited-preauth-allow":
if self._rl_req_num in [0,1]:
self._rl_req_num += 1
return self._send(429)
elif self._rl_req_num == 2:
self._rl_req_num = 0
ret["response"] = {"result": "allow", "status": "preauth-allowed"}
else:
return self._send(500, "Wrong timeout")
else:
ret["response"] = {
"result": "auth",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_login_duo.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_help_output(self):
def test_version_output(self):
"""Check version output"""
result = login_duo(["-v"])
self.assertRegex(result["stderr"][0], "login_duo \d+\.\d+.\d+")
self.assertRegex(result["stderr"][0], "login_duo \\d+\\.\\d+.\\d+")


class TestLoginDuoEnv(CommonSuites.Env):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pam_duo.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def pam_duo_interactive(args, env={}, timeout=2):
return process


def pam_duo(args, env={}, timeout=2):
def pam_duo(args, env={}, timeout=10):
pam_duo_path = [os.path.join(TESTDIR, "testpam.py")]
# we don't want to accidentally grab these from the calling environment
excluded_keys = ["SSH_CONNECTION", "FALLBACK", "UID", "http_proxy", "TIMEOUT"]
Expand Down

0 comments on commit 43f1d95

Please sign in to comment.