Skip to content

Commit

Permalink
Merge pull request #4 from sampingantech/fix/use-asynccasbin
Browse files Browse the repository at this point in the history
Fix/use asynccasbin
  • Loading branch information
sakti authored Jul 26, 2021
2 parents a18b8e4 + 07d66cf commit 57bbfa1
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 91 deletions.
8 changes: 0 additions & 8 deletions casbin_databases_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from databases import Database
from sqlalchemy import Table

from casbin_databases_adapter.utils import to_sync


class Filter:
ptype: List[str] = []
Expand All @@ -26,7 +24,6 @@ def __init__(self, db: Database, table: Table, filtered=False):
self.table: Table = table
self.filtered: bool = filtered

@to_sync()
async def load_policy(self, model: Model):
query = self.table.select()
rows = await self.db.fetch_all(query)
Expand All @@ -35,7 +32,6 @@ async def load_policy(self, model: Model):
line = [v for k, v in row.items() if k in self.cols and v is not None]
persist.load_policy_line(", ".join(line), model)

@to_sync()
async def save_policy(self, model: Model):
await self.db.execute(self.table.delete())
query = self.table.insert()
Expand All @@ -54,12 +50,10 @@ async def save_policy(self, model: Model):
await self.db.execute_many(query, values)
return True

@to_sync()
async def add_policy(self, sec, p_type, rule):
row = self._policy_to_dict(p_type, rule)
await self.db.execute(self.table.insert(), row)

@to_sync()
async def remove_policy(self, sec, p_type, rule):
query = self.table.delete().where(self.table.columns.ptype == p_type)
for i, value in enumerate(rule):
Expand All @@ -69,7 +63,6 @@ async def remove_policy(self, sec, p_type, rule):

return True if result > 0 else False

@to_sync()
async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
query = self.table.delete().where(self.table.columns.ptype == ptype)
if not (0 <= field_index <= 5):
Expand All @@ -82,7 +75,6 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
result = await self.db.execute(query)
return True if result else False

@to_sync()
async def load_filtered_policy(self, model: Model, filter_: Filter) -> None:
query = self.table.select().order_by(self.table.columns.id)
for att, value in filter_.__dict__.items():
Expand Down
58 changes: 0 additions & 58 deletions casbin_databases_adapter/utils.py

This file was deleted.

4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
casbin>=0.8.1
SQLAlchemy>=1.2.18
databases>=0.2.6
databases>=0.2.6
asynccasbin>=1.1.7
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,6 @@ async def enforcer(
db: Database, setup_policies, casbin_rule_table: Table, model_conf_path
) -> Enforcer:
adapter = DatabasesAdapter(db, table=casbin_rule_table)
return Enforcer(model_conf_path, adapter)
enforcer = Enforcer(model_conf_path, adapter)
await enforcer.load_policy()
return enforcer
44 changes: 22 additions & 22 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from casbin_databases_adapter.adapter import Filter


def test_load_policy(db: Database, enforcer: Enforcer):
async def test_load_policy(db: Database, enforcer: Enforcer):

assert enforcer.enforce("alice", "data1", "read") == True
assert enforcer.enforce("bob", "data2", "write") == True
Expand All @@ -16,60 +16,60 @@ def test_load_policy(db: Database, enforcer: Enforcer):
assert enforcer.enforce("bob", "data2", "read") == False


def test_add_policy(db: Database, enforcer: Enforcer):
async def test_add_policy(db: Database, enforcer: Enforcer):
assert not enforcer.enforce("eve", "data3", "read")
result = enforcer.add_permission_for_user("eve", "data3", "read")
result = await enforcer.add_permission_for_user("eve", "data3", "read")
assert result
assert enforcer.enforce("eve", "data3", "read")


def test_save_policy(db: Database, enforcer: Enforcer):
async def test_save_policy(db: Database, enforcer: Enforcer):
assert not enforcer.enforce("alice", "data4", "read")

model: Model = enforcer.get_model()
model.clear_policy()

model.add_policy("p", "p", ["alice", "data4", "read"])
adapter: Adapter = enforcer.get_adapter()
adapter.save_policy(model)
await adapter.save_policy(model)
assert enforcer.enforce("alice", "data4", "read")


def test_remove_policy(db: Database, enforcer: Enforcer):
async def test_remove_policy(db: Database, enforcer: Enforcer):
assert not (enforcer.enforce("alice", "data5", "read"))
enforcer.add_permission_for_user("alice", "data5", "read")
await enforcer.add_permission_for_user("alice", "data5", "read")
assert enforcer.enforce("alice", "data5", "read")
enforcer.delete_permission_for_user("alice", "data5", "read")
await enforcer.delete_permission_for_user("alice", "data5", "read")
assert not (enforcer.enforce("alice", "data5", "read"))


def test_remove_filtered_policy(db: Database, enforcer: Enforcer):
async def test_remove_filtered_policy(db: Database, enforcer: Enforcer):

assert enforcer.enforce("alice", "data1", "read")
enforcer.remove_filtered_policy(1, "data1")
await enforcer.remove_filtered_policy(1, "data1")
assert not (enforcer.enforce("alice", "data1", "read"))

assert enforcer.enforce("bob", "data2", "write")
assert enforcer.enforce("alice", "data2", "read")
assert enforcer.enforce("alice", "data2", "write")

enforcer.remove_filtered_policy(1, "data2", "read")
await enforcer.remove_filtered_policy(1, "data2", "read")

assert enforcer.enforce("bob", "data2", "write")
assert not (enforcer.enforce("alice", "data2", "read"))
assert enforcer.enforce("alice", "data2", "write")

enforcer.remove_filtered_policy(2, "write")
await enforcer.remove_filtered_policy(2, "write")

assert not (enforcer.enforce("bob", "data2", "write"))
assert not (enforcer.enforce("alice", "data2", "write"))


def test_filtered_policy(db: Database, enforcer: Enforcer):
async def test_filtered_policy(db: Database, enforcer: Enforcer):
filter = Filter()

filter.ptype = ["p"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert enforcer.enforce("alice", "data1", "read")
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -81,7 +81,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):

filter.ptype = []
filter.v0 = ["alice"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert enforcer.enforce("alice", "data1", "read")
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -94,7 +94,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):
assert not (enforcer.enforce("data2_admin", "data2", "write"))

filter.v0 = ["bob"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert not (enforcer.enforce("alice", "data1", "read"))
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -107,7 +107,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):
assert not (enforcer.enforce("data2_admin", "data2", "write"))

filter.v0 = ["data2_admin"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert enforcer.enforce("data2_admin", "data2", "read")
assert enforcer.enforce("data2_admin", "data2", "read")
assert not (enforcer.enforce("alice", "data1", "read"))
Expand All @@ -120,7 +120,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):
assert not (enforcer.enforce("bob", "data2", "write"))

filter.v0 = ["alice", "bob"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert enforcer.enforce("alice", "data1", "read")
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -134,7 +134,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):

filter.v0 = []
filter.v1 = ["data1"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert enforcer.enforce("alice", "data1", "read")
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -147,7 +147,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):
assert not (enforcer.enforce("data2_admin", "data2", "write"))

filter.v1 = ["data2"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert not (enforcer.enforce("alice", "data1", "read"))
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -161,7 +161,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):

filter.v1 = []
filter.v2 = ["read"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert enforcer.enforce("alice", "data1", "read")
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand All @@ -174,7 +174,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer):
assert not (enforcer.enforce("data2_admin", "data2", "write"))

filter.v2 = ["write"]
enforcer.load_filtered_policy(filter)
await enforcer.load_filtered_policy(filter)
assert not (enforcer.enforce("alice", "data1", "read"))
assert not (enforcer.enforce("alice", "data1", "write"))
assert not (enforcer.enforce("alice", "data2", "read"))
Expand Down

0 comments on commit 57bbfa1

Please sign in to comment.