Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Dec 24, 2024
1 parent e59fff9 commit c07c998
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
35 changes: 23 additions & 12 deletions fakeredis/model/_acl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import fnmatch
import hashlib
from typing import Dict, Set, List, Union, Optional
from typing import Dict, Set, List, Union, Optional, Any

from fakeredis import _msgs as msgs
from ._command_info import get_commands_by_category, get_command_info
Expand Down Expand Up @@ -64,26 +64,24 @@ def reset(self):
self._selectors.clear()

@staticmethod
def _get_command_info(fields: List[bytes]):
def _get_command_info(fields: List[bytes]) -> Optional[List[Any]]:
command = fields[0].lower()
command_info = get_command_info(command)
if not command_info and len(fields) > 1:
command = command + b" " + fields[1].lower()
command_info = get_command_info(command)
return command_info

def command_allowed(self, fields: List[bytes]) -> bool:
def command_allowed(self, command_info: Optional[List[Any]], fields: List[bytes]) -> bool:
res = fields[0].lower() == b"auth" or self._commands.get(fields[0].lower(), False)
res = res or self._commands.get(b"@all", False)
command_info = self._get_command_info(fields)
if not command_info:
return res
for category in command_info[6]:
res = res or self._commands.get(category, False)
return res

def _get_keys(self, fields: List[bytes]) -> List[bytes]:
command_info = self._get_command_info(fields)
def _get_keys(self, command_info: Optional[List[Any]], fields: List[bytes]) -> List[bytes]:
if not command_info:
return []
first_key, last_key, step = command_info[3:6]
Expand All @@ -93,15 +91,24 @@ def _get_keys(self, fields: List[bytes]) -> List[bytes]:
step = step + 1
return fields[first_key : last_key + 1 : step]

def keys_not_allowed(self, fields: List[bytes]) -> List[bytes]:
def keys_not_allowed(self, command_info: Optional[List[Any]], fields: List[bytes]) -> List[bytes]:
if len(self._key_patterns) == 0:
return []
keys = self._get_keys(fields)
keys = self._get_keys(command_info, fields)
res = set()
for pat in self._key_patterns:
res = res.union(fnmatch.filter(keys, pat))
return list(set(keys) - res)

def channels_not_allowed(self, command_info: Optional[List[Any]], fields: List[bytes]) -> List[bytes]:
if len(self._key_patterns) == 0:
return []
channels = fields[1:2]
res = set()
for pat in self._channel_patterns:
res = res.union(fnmatch.filter(channels, pat))
return list(set(channels) - res)

def set_nopass(self) -> None:
self._nopass = True
self._passwords.clear()
Expand Down Expand Up @@ -343,12 +350,16 @@ def validate_command(self, username: bytes, client_info: bytes, fields: List[byt
user_acl = self._user_acl[username]
if not user_acl.enabled:
raise SimpleError("User disabled")

if not user_acl.command_allowed(fields):
command_info = UserAccessControlList._get_command_info(fields)
if not user_acl.command_allowed(command_info, fields):
self.add_log_record(b"command", b"toplevel", fields[0], username, client_info)
raise SimpleError(msgs.NO_PERMISSION_ERROR.format(username.decode(), fields[0].lower().decode()))
keys_not_allowed = user_acl.keys_not_allowed(fields)
keys_not_allowed = user_acl.keys_not_allowed(command_info, fields)
if len(keys_not_allowed) > 0:
self.add_log_record(b"key", b"toplevel", keys_not_allowed[0], username, client_info)
raise SimpleError(msgs.NO_PERMISSION_KEY_ERROR)
# todo
if "@pubsub" in command_info[6]:
channels_not_allowed = user_acl.channels_not_allowed(command_info, fields)
if len(channels_not_allowed) > 0:
self.add_log_record(b"key", b"toplevel", keys_not_allowed[0], username, client_info)
raise SimpleError(msgs.NO_PERMISSION_KEY_ERROR)
22 changes: 9 additions & 13 deletions test/test_mixins/test_acl_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def teardown():
username,
enabled=True,
reset=True,
commands=["+get", "+set", "+select"],
commands=["+get", "+set", "+select", "+publish"],
channels=["message:*"],
keys=["cache:*"],
nopass=True,
)
Expand All @@ -371,24 +372,19 @@ def teardown():
assert r.set("cache:0", 1)
assert r.get("cache:0") == b"1"

# Invalid operation
with pytest.raises(exceptions.NoPermissionError) as ctx:
r.hset("cache:0", "hkey", "hval")
r.publish("invalid-channel", "message")

assert str(ctx.value) == "User fredis-py-user has no permissions to run the 'hset' command"

# Invalid key
with pytest.raises(exceptions.NoPermissionError) as ctx:
r.get("violated_cache:0")

assert str(ctx.value) == "No permissions to access a key"
assert str(ctx.value) == "No permissions to access a channel"

r.auth("", "default")
log = r.acl_log()
assert isinstance(log, list)
assert len(log) == 2
assert len(log) == 1
assert len(r.acl_log(count=1)) == 1
assert isinstance(log[0], dict)

expected = r.acl_log(count=1)[0]
assert expected["username"] == username
log_record = r.acl_log(count=1)[0]
assert log_record["username"] == username
assert log_record["reason"] == "channel"
assert log_record["object"].lower() == "invalid-channel"

0 comments on commit c07c998

Please sign in to comment.