Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/kbase/auth/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ class Token:
_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)}


@dataclass
class User:
""" Information about a KBase user. """
user: str
""" The username of the user associated with the token. """
customroles: list[str]
""" The Auth2 custom roles the user possesses. """
# Not seeing any other fields that are generally useful right now
# Don't really want to expose idents unless there's a very good reason


_VALID_USER_FIELDS = {f.name for f in fields(User)}


def _require_string(putative: str, name: str) -> str:
if not isinstance(putative, str) or not putative.strip():
raise ValueError(f"{name} is required and cannot be a whitespace only string")
Expand Down Expand Up @@ -98,7 +112,7 @@ async def create(
await cli.close()
raise
# TODO CLIENT look through the myriad of auth clients to see what functionality we need
# TODO CLIENT cache user using cachefor value from token
# TODO CLIENT cache valid user names using cachefor value from token
# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all reads only
return cli
Expand All @@ -114,6 +128,7 @@ def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int
if not timer:
raise ValueError("timer is required")
self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer)
self._user_cache = LRUCache(maxsize=cache_max_size, timer=timer)
self._cli = httpx.AsyncClient()

async def __aenter__(self):
Expand Down Expand Up @@ -157,3 +172,29 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) ->
# in test mode
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
return tk

async def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) -> User:
"""
Get information about a KBase user. This method caches the user;
further caching is unnecessary in most cases.

If you just need the user name get_token is potentially cheaper.

token - the token of the user to query.
on_cache_miss - a function to call if a cache miss occurs.
"""
# really similar to the above, not quite similar enough to make a shared method
_require_string(token, "token")
user = self._user_cache.get(token, default=False)
if user:
return user
if on_cache_miss:
on_cache_miss()
tk = await self.get_token(token)
res = await self._get(self._me_url, headers={"Authorization": token})
u = User(**{k: v for k, v in res.items() if k in _VALID_USER_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
# Cleanest way to do this is update the auth2 service to allow setting it
# in test mode
self._user_cache.set(token, u, ttl=tk.cachefor / 1000)
return u
46 changes: 44 additions & 2 deletions src/kbase/auth/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ class Token:
_VALID_TOKEN_FIELDS = {f.name for f in fields(Token)}


@dataclass
class User:
""" Information about a KBase user. """
user: str
""" The username of the user associated with the token. """
customroles: list[str]
""" The Auth2 custom roles the user possesses. """
# Not seeing any other fields that are generally useful right now
# Don't really want to expose idents unless there's a very good reason


_VALID_USER_FIELDS = {f.name for f in fields(User)}


def _require_string(putative: str, name: str) -> str:
if not isinstance(putative, str) or not putative.strip():
raise ValueError(f"{name} is required and cannot be a whitespace only string")
Expand Down Expand Up @@ -85,7 +99,8 @@ def create(

base_url - the base url for the authentication service, for example
https://kbase.us/services/auth
cache_max_size - the maximum size of the token and user caches.
cache_max_size - the maximum size of the token and user caches. When the cache size is
exceeded, the least recently used entries are evicted from the cache.
timer - the timer for the cache. Used for testing. Time unit must be seconds.
"""
cli = cls(base_url, cache_max_size, timer)
Expand All @@ -97,7 +112,7 @@ def create(
cli.close()
raise
# TODO CLIENT look through the myriad of auth clients to see what functionality we need
# TODO CLIENT cache user using cachefor value from token
# TODO CLIENT cache valid user names using cachefor value from token
# TODO RELIABILITY could add retries for these methods, tenacity looks useful
# should be safe since they're all reads only
return cli
Expand All @@ -113,6 +128,7 @@ def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int
if not timer:
raise ValueError("timer is required")
self._token_cache = LRUCache(maxsize=cache_max_size, timer=timer)
self._user_cache = LRUCache(maxsize=cache_max_size, timer=timer)
self._cli = httpx.Client()

def __enter__(self):
Expand Down Expand Up @@ -156,3 +172,29 @@ def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) -> Token
# in test mode
self._token_cache.set(token, tk, ttl=tk.cachefor / 1000)
return tk

def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) -> User:
"""
Get information about a KBase user. This method caches the user;
further caching is unnecessary in most cases.

If you just need the user name get_token is potentially cheaper.

token - the token of the user to query.
on_cache_miss - a function to call if a cache miss occurs.
"""
# really similar to the above, not quite similar enough to make a shared method
_require_string(token, "token")
user = self._user_cache.get(token, default=False)
if user:
return user
if on_cache_miss:
on_cache_miss()
tk = self.get_token(token)
res = self._get(self._me_url, headers={"Authorization": token})
u = User(**{k: v for k, v in res.items() if k in _VALID_USER_FIELDS})
# TODO TEST later may want to add tests that change the cachefor value.
# Cleanest way to do this is update the auth2 service to allow setting it
# in test mode
self._user_cache.set(token, u, ttl=tk.cachefor / 1000)
return u
124 changes: 124 additions & 0 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,127 @@ async def test_get_token_cache_evict_on_time(auth_users):
ttt1 = await cli.get_token(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 2
assert ttt1 == t1


@pytest.mark.asyncio
async def test_get_user_basic(auth_users):
with KBaseAuthClient.create(AUTH_URL) as cli:
u1 = cli.get_user(auth_users["user"])
u2 = cli.get_user(auth_users["user_all"])
async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli:
u3 = await cli.get_user(auth_users["user_random1"])
u4 = await cli.get_user(auth_users["user_random2"])

assert u1.user == "user"
assert u1.customroles == []

assert u2.user == "user_all"
assert u2.customroles == ["random1", "random2"]

assert u3.user == "user_random1"
assert u3.customroles == ["random1"]

assert u4.user == "user_random2"
assert u4.customroles == ["random2"]


@pytest.mark.asyncio
async def test_get_user_basic_fail(auth_users):
err = "token is required and cannot be a whitespace only string"
await _get_user_basic_fail(None, ValueError(err))
await _get_user_basic_fail(" \t ", ValueError(err))
err = "KBase auth server reported token is invalid."
await _get_user_basic_fail("superfake", InvalidTokenError(err))


async def _get_user_basic_fail(token: str, expected: Exception):
with KBaseAuthClient.create(AUTH_URL) as cli:
with pytest.raises(type(expected), match=f"^{expected.args[0]}$"):
cli.get_user(token)
async with await AsyncKBaseAuthClient.create(AUTH_URL) as cli:
with pytest.raises(type(expected), match=f"^{expected.args[0]}$"):
await cli.get_user(token)


@pytest.mark.asyncio
async def test_get_user_cache_evict_on_size(auth_users):
with KBaseAuthClient.create(AUTH_URL, cache_max_size=3) as cli:
cachemiss = Mock()
# fill the cache
u1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
u2 = cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss)
u3 = cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 3
# check userss in cache
uu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
uu2 = cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss)
uu3 = cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 3
assert uu1 == u1
assert uu2 == u2
assert uu3 == u3
# Force an eviction
cli.get_user(auth_users["user_all"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 4
# Check user was evicted
uuu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 5
assert uuu1 == u1

async with await AsyncKBaseAuthClient.create(AUTH_URL, cache_max_size=3) as cli:
cachemiss = Mock()
# fill the cache
u1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
u2 = await cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss)
u3 = await cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 3
# check users in cache
uu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
uu2 = await cli.get_user(auth_users["user_random1"], on_cache_miss=cachemiss)
uu3 = await cli.get_user(auth_users["user_random2"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 3
assert uu1 == u1
assert uu2 == u2
assert uu3 == u3
# Force an eviction
await cli.get_user(auth_users["user_all"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 4
# Check user was evicted
uuu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 5
assert uuu1 == u1


@pytest.mark.asyncio
async def test_get_user_cache_evict_on_time(auth_users):
timer = FakeTimer()
with KBaseAuthClient.create(AUTH_URL, timer=timer) as cli:
cachemiss = Mock()
u1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 1
# TODO TEST auth2 always returns 300000 ms for cachefor. Update testmode to allow
# setting different values and test here
timer.advance(299)
uu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 1
assert uu1 == u1
timer.advance(2)
uuu1 = cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 2
assert uuu1 == u1

timer = FakeTimer()
async with await AsyncKBaseAuthClient.create(AUTH_URL, timer=timer) as cli:
cachemiss = Mock()
u1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 1
# TODO TEST auth2 always returns 300000 ms for cachefor. Update testmode to allow
# setting different values and test here
timer.advance(299)
uu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 1
assert uu1 == u1
timer.advance(2)
uuu1 = await cli.get_user(auth_users["user"], on_cache_miss=cachemiss)
assert cachemiss.call_count == 2
assert uuu1 == u1