diff --git a/tools/schemapi/vega_expr.py b/tools/schemapi/vega_expr.py index cf88dc906..8effd0aec 100644 --- a/tools/schemapi/vega_expr.py +++ b/tools/schemapi/vega_expr.py @@ -34,13 +34,16 @@ FUNCTION_DEF_LINE: Pattern[str] = re.compile(r"") LIQUID_INCLUDE: Pattern[str] = re.compile(r"( \{% include.+%\})") -TYPE: Literal[r"type"] = "type" +TYPE: Literal[r"type"] = r"type" RAW: Literal["raw"] = "raw" SOFTBREAK: Literal["softbreak"] = "softbreak" TEXT: Literal["text"] = "text" CHILDREN: Literal["children"] = "children" RETURN_ANNOTATION = "FunctionExpression" +EXPR_ANNOTATION = "IntoExpression" +NONE: Literal[r"None"] = r"None" +STAR_ARGS: Literal["*args"] = "*args" def download_expressions_md(url: str, /) -> Path: @@ -136,14 +139,14 @@ class VegaExprNode: parameters: list[VegaExprParam] = dataclasses.field(default_factory=list) def to_signature(self) -> str: + """NOTE: 101/147 cases are all required args.""" pre_params = f"def {self.name_safe}(cls, " post_params = f", /) -> {RETURN_ANNOTATION}:" param_list = "" - if all(p.required for p in self.parameters): - # NOTE: covers 101/147 cases - param_list = ", ".join(p.name for p in self.parameters) + if self.is_overloaded(): + param_list = VegaExprParam.star_args() else: - param_list = "" + param_list = ", ".join(p.to_str() for p in self.parameters) return f"{pre_params}{param_list}{post_params}" def with_parameters(self) -> Self: @@ -336,6 +339,19 @@ class VegaExprParam: required: bool variadic: bool = False + @staticmethod + def star_args() -> LiteralString: + return f"{STAR_ARGS}: Any" + + def to_str(self) -> str: + """Return as an annotated parameter, with a default if needed.""" + if self.required: + return f"{self.name}: {EXPR_ANNOTATION}" + elif not self.variadic: + return f"{self.name}: {EXPR_ANNOTATION} = {NONE}" + else: + return self.star_args() + @classmethod def iter_params(cls, raw_texts: Iterable[str], /) -> Iterator[Self]: """Yields an ordered parameter list.""" @@ -351,7 +367,7 @@ def iter_params(cls, raw_texts: Iterable[str], /) -> Iterator[Self]: elif s.isalnum(): yield cls(s, required=is_required) elif s == "...": - yield cls("*args", required=False, variadic=True) + yield cls(STAR_ARGS, required=False, variadic=True) else: continue