Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ DSL Improvements: Ability to register functions #84

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions core/harambe_core/parser/expression/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
from functools import wraps
from typing import Any
from typing import Any, Callable

Func = Callable[..., Any]


class ExpressionEvaluator:
functions = {}
__builtins__ = {}

@classmethod
def register(cls, func_name: str):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
def __init__(self):
self._functions = {}

cls.functions[func_name.upper()] = wrapper
return wrapper
def __contains__(self, function_name: str) -> bool:
return (
function_name.upper() in self._functions
or function_name.upper() in self.__builtins__
)

return decorator
def __getitem__(self, function_name: str) -> Func:
return (
self._functions.get(function_name.upper())
or self.__builtins__[function_name.upper()]
)

@staticmethod
def evaluate(expression: str, obj: Any) -> Any:
def __call__(self, func_name: str, *args: Any, **kwargs: Any) -> Any:
return self[func_name](*args, **kwargs)

def evaluate(self, expression: str, obj: Any) -> Any:
expression = expression.strip()

if not expression:
Expand All @@ -31,7 +38,7 @@ def evaluate(expression: str, obj: Any) -> Any:
if not func_name:
raise ValueError("Invalid function name")

if func_name not in ExpressionEvaluator.functions:
if func_name not in self:
raise ValueError(f"Unknown function: {func_name}")

remaining = expression[len(func_name) :].strip()
Expand Down Expand Up @@ -70,11 +77,18 @@ def evaluate(expression: str, obj: Any) -> Any:
if ExpressionEvaluator._is_string_literal(arg):
evaluated_args.append(arg[1:-1])
elif "(" in arg:
evaluated_args.append(ExpressionEvaluator.evaluate(arg, obj))
evaluated_args.append(self.evaluate(arg, obj))
else:
evaluated_args.append(ExpressionEvaluator._get_field_value(arg, obj))

return ExpressionEvaluator.functions[func_name](*evaluated_args)
return self(func_name, *evaluated_args)

def define_function(self, func_name: str):
return self._wrap(func_name, self._functions)

@classmethod
def define_builtin(cls, func_name: str):
return cls._wrap(func_name, cls.__builtins__)

@staticmethod
def _is_string_literal(arg: str):
Expand Down Expand Up @@ -105,3 +119,17 @@ def _get_field_value(field_path: str, obj: Any) -> Any:
return None

return current

@staticmethod
def _wrap(
func_name: str, function_store: dict[str, Callable[..., Any]]
) -> Callable[..., Any]:
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

function_store[func_name.upper()] = wrapper
return wrapper

return decorator
17 changes: 9 additions & 8 deletions core/harambe_core/parser/expression/functions.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,44 @@
from typing import Any

from slugify import slugify as python_slugify

from harambe_core.parser.expression.evaluator import ExpressionEvaluator


@ExpressionEvaluator.register("NOOP")
@ExpressionEvaluator.define_builtin("NOOP")
def noop(*args: Any) -> Any:
return args[0] if len(args) == 1 else args


@ExpressionEvaluator.register("CONCAT")
@ExpressionEvaluator.define_builtin("CONCAT")
def concat(*args: Any, seperator: str = "") -> str:
return seperator.join(str(arg) for arg in args if arg is not None)


@ExpressionEvaluator.register("CONCAT_WS")
@ExpressionEvaluator.define_builtin("CONCAT_WS")
def concat_ws(seperator: str, *args: Any) -> str:
return concat(*args, seperator=seperator)


@ExpressionEvaluator.register("COALESCE")
@ExpressionEvaluator.define_builtin("COALESCE")
def coalesce(*args: Any) -> Any:
for arg in args:
if arg:
return arg
return None


@ExpressionEvaluator.register("SLUGIFY")
@ExpressionEvaluator.define_builtin("SLUGIFY")
def slugify(*args: Any) -> str:
text = concat_ws(" ", *args)
text = concat_ws("-", *args)
return python_slugify(text)


@ExpressionEvaluator.register("UPPER")
@ExpressionEvaluator.define_builtin("UPPER")
def upper(text: str) -> str:
return text.upper()


@ExpressionEvaluator.register("LOWER")
@ExpressionEvaluator.define_builtin("LOWER")
def lower(text: str) -> str:
return text.lower()
17 changes: 11 additions & 6 deletions core/harambe_core/parser/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Type, Self

from pydantic import (
Expand All @@ -20,8 +19,8 @@
from harambe_core.parser.type_phone_number import ParserTypePhoneNumber
from harambe_core.parser.type_string import ParserTypeString
from harambe_core.parser.type_url import ParserTypeUrl
from harambe_core.types import SchemaFieldType
from harambe_core.types import Schema
from harambe_core.types import SchemaFieldType


class SchemaParser:
Expand All @@ -31,10 +30,16 @@ class SchemaParser:

model: Type[BaseModel]

def __init__(self, schema: Schema):
def __init__(
self, schema: Schema, evaluator: ExpressionEvaluator | None = None
) -> None:
if "$schema" in schema:
del schema["$schema"]

if evaluator is None:
evaluator = ExpressionEvaluator()

self.evaluator = evaluator
self.schema = schema
self.field_types: dict[SchemaFieldType, Any] = {}
self.all_required_fields = self._get_all_required_fields(self.schema)
Expand Down Expand Up @@ -154,7 +159,7 @@ def _schema_to_pydantic_model(

config: ConfigDict = {"extra": "forbid", "str_strip_whitespace": True}
config.update(schema.get("__config__", {}))
base_model = base_model_factory(config, computed_fields)
base_model = base_model_factory(config, computed_fields, self.evaluator)

return create_model(model_name, __base__=base_model, **fields)

Expand Down Expand Up @@ -256,7 +261,7 @@ def is_empty(value: Any) -> bool:


def base_model_factory(
config: ConfigDict, computed_fields: dict[str, str]
config: ConfigDict, computed_fields: dict[str, str], evaluator: ExpressionEvaluator
) -> Type[BaseModel]:
class PreValidatedBaseModel(BaseModel):
model_config: ConfigDict = config
Expand Down Expand Up @@ -292,7 +297,7 @@ def trim_and_nullify(value: Any) -> Any:
@model_validator(mode="after")
def evaluate_computed_fields(self) -> Self:
for field, expression in computed_fields.items():
res = ExpressionEvaluator.evaluate(expression, self)
res = evaluator.evaluate(expression, self)
setattr(self, field, res)
return self

Expand Down
2 changes: 1 addition & 1 deletion core/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "harambe-core"
version = "0.41.0"
version = "0.42.0"
description = "Core types for harambe SDK 🐒🍌"
authors = [
{ name = "Adam Watkins", email = "[email protected]" }
Expand Down
16 changes: 15 additions & 1 deletion core/test/parser/expression/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@pytest.fixture
def evaluator():
return ExpressionEvaluator
return ExpressionEvaluator()


def test_evaluate_simple(evaluator):
Expand Down Expand Up @@ -101,3 +101,17 @@ def test_evaluate_unknown_function(evaluator):
def test_invalid_parenthesis(evaluator, expression):
with pytest.raises(SyntaxError):
evaluator.evaluate(expression, {})


def test_register_custom_function():
evaluator1 = ExpressionEvaluator()
evaluator2 = ExpressionEvaluator()

@evaluator1.define_function("CUSTOM")
def custom_func(a, b):
return a + b

assert evaluator1.evaluate("CUSTOM(a, b)", {"a": 10, "b": 20}) == 30

with pytest.raises(ValueError, match="Unknown function: CUSTOM"):
evaluator2.evaluate("CUSTOM(a, b)", {"a": 10, "b": 20})
1 change: 1 addition & 0 deletions core/test/parser/expression/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_slugify():
assert slugify("Hello-World") == "hello-world"
assert slugify("Hello World", "Another") == "hello-world-another"
assert slugify("Hello World", "Another", 2) == "hello-world-another-2"
assert slugify(1, 2, 3, "four") == "1-2-3-four"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Paxton needs to get rid of special characters like §. Does sluggify do this already?



def test_upper():
Expand Down
4 changes: 3 additions & 1 deletion sdk/harambe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
)

from harambe_core import SchemaParser, Schema
from harambe_core.parser.expression import ExpressionEvaluator


class AsyncScraper(Protocol):
Expand Down Expand Up @@ -88,14 +89,15 @@ def __init__(
context: Optional[Context] = None,
schema: Optional[Schema] = None,
deduper: Optional[DuplicateHandler] = None,
evaluator: Optional[ExpressionEvaluator] = None,
):
self.page: Page = page # type: ignore
self._id = run_id or uuid.uuid4()
self._domain = domain
self._stage = stage
self._scraper = scraper
self._context = context or {}
self._validator = SchemaParser(schema) if schema else None
self._validator = SchemaParser(schema, evaluator) if schema else None
self._saved_data: set[ScrapeResult] = set()
self._saved_cookies: List[Cookie] = []
self._saved_local_storage: List[LocalStorage] = []
Expand Down
2 changes: 1 addition & 1 deletion sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "harambe-sdk"
version = "0.41.0"
version = "0.42.0"
description = "Data extraction SDK for Playwright 🐒🍌"
authors = [
{ name = "Adam Watkins", email = "[email protected]" }
Expand Down
Loading