diff --git a/core/harambe_core/parser/expression/evaluator.py b/core/harambe_core/parser/expression/evaluator.py index 33f0b40..0a9d03f 100644 --- a/core/harambe_core/parser/expression/evaluator.py +++ b/core/harambe_core/parser/expression/evaluator.py @@ -1,24 +1,31 @@ from functools import wraps -from typing import Any +from typing import Any, Callable + +Func = Callable[..., Any] class ExpressionEvaluator: - functions = {} + __builtins__ = {} - @classmethod - def register(cls, func_name: str): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) + def __init__(self): + self._functions = {} - cls.functions[func_name.upper()] = wrapper - return wrapper + def __contains__(self, function_name: str) -> bool: + return ( + function_name.upper() in self._functions + or function_name.upper() in self.__builtins__ + ) - return decorator + def __getitem__(self, function_name: str) -> Func: + return ( + self._functions.get(function_name.upper()) + or self.__builtins__[function_name.upper()] + ) - @staticmethod - def evaluate(expression: str, obj: Any) -> Any: + def __call__(self, func_name: str, *args: Any, **kwargs: Any) -> Any: + return self[func_name](*args, **kwargs) + + def evaluate(self, expression: str, obj: Any) -> Any: expression = expression.strip() if not expression: @@ -31,7 +38,7 @@ def evaluate(expression: str, obj: Any) -> Any: if not func_name: raise ValueError("Invalid function name") - if func_name not in ExpressionEvaluator.functions: + if func_name not in self: raise ValueError(f"Unknown function: {func_name}") remaining = expression[len(func_name) :].strip() @@ -70,11 +77,18 @@ def evaluate(expression: str, obj: Any) -> Any: if ExpressionEvaluator._is_string_literal(arg): evaluated_args.append(arg[1:-1]) elif "(" in arg: - evaluated_args.append(ExpressionEvaluator.evaluate(arg, obj)) + evaluated_args.append(self.evaluate(arg, obj)) else: evaluated_args.append(ExpressionEvaluator._get_field_value(arg, obj)) - return ExpressionEvaluator.functions[func_name](*evaluated_args) + return self(func_name, *evaluated_args) + + def define_function(self, func_name: str): + return self._wrap(func_name, self._functions) + + @classmethod + def define_builtin(cls, func_name: str): + return cls._wrap(func_name, cls.__builtins__) @staticmethod def _is_string_literal(arg: str): @@ -105,3 +119,17 @@ def _get_field_value(field_path: str, obj: Any) -> Any: return None return current + + @staticmethod + def _wrap( + func_name: str, function_store: dict[str, Callable[..., Any]] + ) -> Callable[..., Any]: + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + function_store[func_name.upper()] = wrapper + return wrapper + + return decorator diff --git a/core/harambe_core/parser/expression/functions.py b/core/harambe_core/parser/expression/functions.py index 4483842..5f6377a 100644 --- a/core/harambe_core/parser/expression/functions.py +++ b/core/harambe_core/parser/expression/functions.py @@ -1,25 +1,26 @@ from typing import Any + from slugify import slugify as python_slugify from harambe_core.parser.expression.evaluator import ExpressionEvaluator -@ExpressionEvaluator.register("NOOP") +@ExpressionEvaluator.define_builtin("NOOP") def noop(*args: Any) -> Any: return args[0] if len(args) == 1 else args -@ExpressionEvaluator.register("CONCAT") +@ExpressionEvaluator.define_builtin("CONCAT") def concat(*args: Any, seperator: str = "") -> str: return seperator.join(str(arg) for arg in args if arg is not None) -@ExpressionEvaluator.register("CONCAT_WS") +@ExpressionEvaluator.define_builtin("CONCAT_WS") def concat_ws(seperator: str, *args: Any) -> str: return concat(*args, seperator=seperator) -@ExpressionEvaluator.register("COALESCE") +@ExpressionEvaluator.define_builtin("COALESCE") def coalesce(*args: Any) -> Any: for arg in args: if arg: @@ -27,17 +28,17 @@ def coalesce(*args: Any) -> Any: return None -@ExpressionEvaluator.register("SLUGIFY") +@ExpressionEvaluator.define_builtin("SLUGIFY") def slugify(*args: Any) -> str: - text = concat_ws(" ", *args) + text = concat_ws("-", *args) return python_slugify(text) -@ExpressionEvaluator.register("UPPER") +@ExpressionEvaluator.define_builtin("UPPER") def upper(text: str) -> str: return text.upper() -@ExpressionEvaluator.register("LOWER") +@ExpressionEvaluator.define_builtin("LOWER") def lower(text: str) -> str: return text.lower() diff --git a/core/harambe_core/parser/parser.py b/core/harambe_core/parser/parser.py index 9820143..9973829 100644 --- a/core/harambe_core/parser/parser.py +++ b/core/harambe_core/parser/parser.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from typing import Any, List, Optional, Type, Self from pydantic import ( @@ -20,8 +19,8 @@ from harambe_core.parser.type_phone_number import ParserTypePhoneNumber from harambe_core.parser.type_string import ParserTypeString from harambe_core.parser.type_url import ParserTypeUrl -from harambe_core.types import SchemaFieldType from harambe_core.types import Schema +from harambe_core.types import SchemaFieldType class SchemaParser: @@ -31,10 +30,16 @@ class SchemaParser: model: Type[BaseModel] - def __init__(self, schema: Schema): + def __init__( + self, schema: Schema, evaluator: ExpressionEvaluator | None = None + ) -> None: if "$schema" in schema: del schema["$schema"] + if evaluator is None: + evaluator = ExpressionEvaluator() + + self.evaluator = evaluator self.schema = schema self.field_types: dict[SchemaFieldType, Any] = {} self.all_required_fields = self._get_all_required_fields(self.schema) @@ -154,7 +159,7 @@ def _schema_to_pydantic_model( config: ConfigDict = {"extra": "forbid", "str_strip_whitespace": True} config.update(schema.get("__config__", {})) - base_model = base_model_factory(config, computed_fields) + base_model = base_model_factory(config, computed_fields, self.evaluator) return create_model(model_name, __base__=base_model, **fields) @@ -256,7 +261,7 @@ def is_empty(value: Any) -> bool: def base_model_factory( - config: ConfigDict, computed_fields: dict[str, str] + config: ConfigDict, computed_fields: dict[str, str], evaluator: ExpressionEvaluator ) -> Type[BaseModel]: class PreValidatedBaseModel(BaseModel): model_config: ConfigDict = config @@ -292,7 +297,7 @@ def trim_and_nullify(value: Any) -> Any: @model_validator(mode="after") def evaluate_computed_fields(self) -> Self: for field, expression in computed_fields.items(): - res = ExpressionEvaluator.evaluate(expression, self) + res = evaluator.evaluate(expression, self) setattr(self, field, res) return self diff --git a/core/pyproject.toml b/core/pyproject.toml index 80288bd..9cd2e17 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "harambe-core" -version = "0.41.0" +version = "0.42.0" description = "Core types for harambe SDK 🐒🍌" authors = [ { name = "Adam Watkins", email = "adam@reworkd.ai" } diff --git a/core/test/parser/expression/test_expression.py b/core/test/parser/expression/test_expression.py index dee3fa5..ba8a41e 100644 --- a/core/test/parser/expression/test_expression.py +++ b/core/test/parser/expression/test_expression.py @@ -5,7 +5,7 @@ @pytest.fixture def evaluator(): - return ExpressionEvaluator + return ExpressionEvaluator() def test_evaluate_simple(evaluator): @@ -101,3 +101,17 @@ def test_evaluate_unknown_function(evaluator): def test_invalid_parenthesis(evaluator, expression): with pytest.raises(SyntaxError): evaluator.evaluate(expression, {}) + + +def test_register_custom_function(): + evaluator1 = ExpressionEvaluator() + evaluator2 = ExpressionEvaluator() + + @evaluator1.define_function("CUSTOM") + def custom_func(a, b): + return a + b + + assert evaluator1.evaluate("CUSTOM(a, b)", {"a": 10, "b": 20}) == 30 + + with pytest.raises(ValueError, match="Unknown function: CUSTOM"): + evaluator2.evaluate("CUSTOM(a, b)", {"a": 10, "b": 20}) diff --git a/core/test/parser/expression/test_functions.py b/core/test/parser/expression/test_functions.py index 22b69af..b6e118d 100644 --- a/core/test/parser/expression/test_functions.py +++ b/core/test/parser/expression/test_functions.py @@ -38,6 +38,7 @@ def test_slugify(): assert slugify("Hello-World") == "hello-world" assert slugify("Hello World", "Another") == "hello-world-another" assert slugify("Hello World", "Another", 2) == "hello-world-another-2" + assert slugify(1, 2, 3, "four") == "1-2-3-four" def test_upper(): diff --git a/sdk/harambe/core.py b/sdk/harambe/core.py index 1f9f728..8b2ee93 100644 --- a/sdk/harambe/core.py +++ b/sdk/harambe/core.py @@ -57,6 +57,7 @@ ) from harambe_core import SchemaParser, Schema +from harambe_core.parser.expression import ExpressionEvaluator class AsyncScraper(Protocol): @@ -88,6 +89,7 @@ def __init__( context: Optional[Context] = None, schema: Optional[Schema] = None, deduper: Optional[DuplicateHandler] = None, + evaluator: Optional[ExpressionEvaluator] = None, ): self.page: Page = page # type: ignore self._id = run_id or uuid.uuid4() @@ -95,7 +97,7 @@ def __init__( self._stage = stage self._scraper = scraper self._context = context or {} - self._validator = SchemaParser(schema) if schema else None + self._validator = SchemaParser(schema, evaluator) if schema else None self._saved_data: set[ScrapeResult] = set() self._saved_cookies: List[Cookie] = [] self._saved_local_storage: List[LocalStorage] = [] diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 089bff7..c1abc8a 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "harambe-sdk" -version = "0.41.0" +version = "0.42.0" description = "Data extraction SDK for Playwright 🐒🍌" authors = [ { name = "Adam Watkins", email = "adam@reworkd.ai" }