diff --git a/compiler.py b/compiler.py index fab688b5..893283b6 100644 --- a/compiler.py +++ b/compiler.py @@ -18,6 +18,7 @@ Hole, Int, List, + MatchCase, MatchFunction, Object, Record, @@ -60,9 +61,218 @@ def decl(self) -> str: return f"struct object* {self.name}({args})" +class MatchKind: + def compile(self, arg: str) -> str: + raise NotImplementedError + + +class AcceptAny(MatchKind): + def compile(self, arg: str) -> str: + return "true" + + +class IsNumber(MatchKind): + def compile(self, arg: str) -> str: + return f"is_num({arg})" + + +class IsHole(MatchKind): + def compile(self, arg: str) -> str: + return f"is_hole({arg})" + + +class IsString(MatchKind): + def compile(self, arg: str) -> str: + return f"is_string({arg})" + + +class IsVariant(MatchKind): + def compile(self, arg: str) -> str: + return f"is_variant({arg})" + + +class IsList(MatchKind): + pass + + +class IsRecord(MatchKind): + pass + + +@dataclasses.dataclass +class NumberHasValue(MatchKind): + value: int + + def compile(self, arg: str) -> str: + return f"is_num_equal_word({arg}, {self.value})" + + +def coerce_int(object: Object) -> int: + assert isinstance(object, Int) + return object.value + + +@dataclasses.dataclass +class StringHasValue(MatchKind): + value: str + + def compile(self, arg: str) -> str: + if len(self.value) < 8: + return f"({arg} == mksmallstring({json.dumps(self.value)}, {len(self.value)}))" + return f"string_equal_cstr_len({arg}, {json.dumps(self.value)}, {len(self.value)})" + + +def coerce_string(object: Object) -> str: + assert isinstance(object, String) + return object.value + + +@dataclasses.dataclass +class VariantHasTag(MatchKind): + tag: str + + def compile(self, arg: str) -> str: + return f"(variant_tag({arg}) == Tag_{self.tag})" + + +@dataclasses.dataclass(frozen=True) +class CondExpr(Object): + arg: Var # Actually, probably this one isn't needed?? + condition: MatchKind + body: Object + + +@dataclasses.dataclass(frozen=True) +class MatchExpr(Object): + arg: Object # Maybe not needed? + cases: typing.List[CondExpr] + fallthrough_case: Where | None + + +@dataclasses.dataclass(frozen=True) +class VariantValueExpr(Object): + variant: Object + + +def group_cases( + cases: typing.List[MatchCase], keyof: object, is_fallthrough: object +) -> tuple[typing.List[typing.List[MatchCase]], MatchCase | None]: + print("ungrouped cases") + print(cases) + groups = {} + fallthrough = None + for case in cases: + if is_fallthrough(case): + fallthrough = case + # nothing can match after the var + break + else: + if keyof(case) in groups: + groups[keyof(case)].append(case) + else: + groups[keyof(case)] = [case] + + print("grouped cases") + print(groups) + return list(groups.values()), fallthrough + + +def typename(case: MatchCase) -> str: + return type(case.pattern).__name__ + + +def pattern_is_var(case: MatchCase) -> bool: + return isinstance(case.pattern, Var) + + +def let(name: Var, value: Object, body: Object) -> Where: + return Where(body, Assign(name, value)) + + +def compile_match_function(match_fn: MatchFunction) -> Function: + fn_arg = Var(gensym("fn_arg")) + match_arg = Var(gensym("match")) + cases, fallthrough_case = compile_ungrouped_match_cases(match_arg, match_fn.cases, typename, pattern_is_var) + return Function(fn_arg, let(match_arg, fn_arg, MatchExpr(match_arg, cases, fallthrough_case))) + + +def compile_ungrouped_match_cases( + arg: Var, cases: typing.List[MatchCase], group_key: object, is_fallthrough: object +) -> tuple[typing.List[CondExpr], Where | None]: + grouped, fallthrough_case = group_cases(cases, group_key, is_fallthrough) + return [expand_group(arg, group, fallthrough_case) for group in grouped], compile_var_case(arg, fallthrough_case) + + +def compile_var_case(arg: Var, case: MatchCase | None) -> Where | None: + if case: + assert isinstance(case.pattern, Var) + return Where(case.body, Assign(case.pattern, arg)) + return None + + +def compile_int_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + cases = [CondExpr(arg, NumberHasValue(coerce_int(case.pattern)), case.body) for case in group] + return MatchExpr(arg, cases, compile_var_case(arg, fallthrough_case)) + + +def compile_string_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + cases = [CondExpr(arg, StringHasValue(coerce_string(case.pattern)), case.body) for case in group] + return MatchExpr(arg, cases, compile_var_case(arg, fallthrough_case)) + + +def compile_variant_cases(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + def case_tag(case: MatchCase): + assert isinstance(case.pattern, Variant) + return case.pattern.tag + + grouped_by_variant, _ = group_cases(group, case_tag, lambda x: False) + cond_exprs = [] + for group in grouped_by_variant: + lifted_matches = [MatchCase(case.pattern.value, case.body) for case in group] + print("lifted_matches", repr(lifted_matches)) + inner_arg = Var(gensym("variant_match")) + expanded_cases, inner_fallthrough_case = compile_ungrouped_match_cases( + inner_arg, lifted_matches, typename, pattern_is_var + ) + match_expr = let(inner_arg, VariantValueExpr(arg), MatchExpr(inner_arg, expanded_cases, inner_fallthrough_case)) + cond_exprs.append(CondExpr(arg, VariantHasTag(group[0].pattern.tag), match_expr)) + + return MatchExpr(arg, cond_exprs, compile_var_case(arg, fallthrough_case)) + + +def expand_group(arg: Var, group: typing.List[MatchCase], fallthrough_case: MatchCase | None): + if not group: + assert fallthrough_case + return compile_var_case(arg, fallthrough_case) + canonical_case = group[0] + if isinstance(canonical_case.pattern, Int): + return CondExpr(arg, IsNumber(), compile_int_cases(arg, group, fallthrough_case)) + if isinstance(canonical_case.pattern, Hole): + # throwing away subsequent holes + return CondExpr(arg, IsHole(), canonical_case.body) + if isinstance(canonical_case.pattern, Var): + raise Exception("saw a var") + if isinstance(canonical_case.pattern, Variant): + return CondExpr(arg, IsVariant(), compile_variant_cases(arg, group, fallthrough_case)) + if isinstance(canonical_case.pattern, String): + return CondExpr(arg, IsString(), compile_string_cases(arg, group, fallthrough_case)) + # if isinstance(canonical_case.pattern, List): + # if isinstance(canonical_case.pattern, Record): + raise NotImplementedError("expand_group", canonical_case.pattern) + + +gensym_counter = 0 + + +def gensym(stem: str = "tmp") -> str: + global gensym_counter + gensym_counter += 1 + return f"{stem}_{gensym_counter-1}" + + class Compiler: def __init__(self, main_fn: CompiledFunction) -> None: - self.gensym_counter: int = 0 + # self.gensym_counter: int = 0 self.functions: typing.List[CompiledFunction] = [main_fn] self.function: CompiledFunction = main_fn self.record_keys: Dict[str, int] = {} @@ -105,8 +315,7 @@ def variant_tag(self, key: str) -> int: return result def gensym(self, stem: str = "tmp") -> str: - self.gensym_counter += 1 - return f"{stem}_{self.gensym_counter-1}" + return gensym(stem) def _emit(self, line: str) -> None: self.function.code.append(line) @@ -152,7 +361,7 @@ def compile_assign(self, env: Env, exp: Assign) -> Env: return {**env, name: value} if isinstance(exp.value, MatchFunction): # Named match function - value = self.compile_match_function(env, exp.value, name) + value = self.compile_function(env, compile_match_function(exp.value), name) return {**env, name: value} value = self.compile(env, exp.value) return {**env, name: value} @@ -262,29 +471,6 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En return updates raise NotImplementedError("try_match", pattern) - def compile_match_function(self, env: Env, exp: MatchFunction, name: Optional[str]) -> str: - arg = self.gensym() - fn = self.make_compiled_function(arg, exp, name) - self.functions.append(fn) - cur = self.function - self.function = fn - funcenv = self.compile_function_env(fn, name) - for i, case in enumerate(exp.cases): - fallthrough = f"case_{i+1}" if i < len(exp.cases) - 1 else "no_match" - env_updates = self.try_match(funcenv, arg, case.pattern, fallthrough) - case_result = self.compile({**funcenv, **env_updates}, case.body) - self._emit(f"return {case_result};") - self._emit(f"{fallthrough}:;") - self._emit(r'fprintf(stderr, "no matching cases\n");') - self._emit("abort();") - # Pacify the C compiler - self._emit("return NULL;") - self.function = cur - if not fn.fields: - # TODO(max): Closure over freevars but only consts - return self._const_closure(fn) - return self.make_closure(env, fn) - def make_closure(self, env: Env, fn: CompiledFunction) -> str: name = self._mktemp(f"mkclosure(heap, {fn.name}, {len(fn.fields)})") for i, field in enumerate(fn.fields): @@ -448,8 +634,38 @@ def compile(self, env: Env, exp: Object) -> str: return self.compile_function(env, exp, name=None) if isinstance(exp, MatchFunction): # Anonymous match function - return self.compile_match_function(env, exp, name=None) - raise NotImplementedError(f"exp {type(exp)} {exp}") + return self.compile_function(env, compile_match_function(exp), name=None) + if isinstance(exp, MatchExpr): + return self.compile_match_expr(env, exp) + if isinstance(exp, VariantValueExpr): + value = self.compile(env, exp.variant) + return self._mktemp(f"variant_value({value});") + raise NotImplementedError(f"exp {type(exp)} {exp!r}") + + def compile_match_expr(self, env: Env, match_expr: MatchExpr) -> str: + arg = self.compile(env, match_expr.arg) + result = self.gensym("result") + done = self.gensym("done") + self._emit(f"struct object* {result} = NULL;") + for cond in match_expr.cases: + if isinstance(cond.condition, VariantHasTag): + self.variant_tag(cond.condition.tag) + fallthrough = self.gensym("case") + c_cond = cond.condition.compile(arg) + self._emit(f"if (!{c_cond}) goto {fallthrough};") + case_result = self.compile(env, cond.body) + self._emit(f"{result} = {case_result};") + self._emit(f"goto {done};") + self._emit(f"{fallthrough}:;") + if match_expr.fallthrough_case: + c_name = self.compile(env, match_expr.fallthrough_case) + self._emit(f"{result} = {c_name};") + self._emit(f"goto {done};") + else: + self._emit(r'fprintf(stderr, "no matching cases\n");') + self._emit("abort();") + self._emit(f"{done}:;") + return result def compile_to_string(program: Object, debug: bool) -> str: diff --git a/scrapscript.py b/scrapscript.py index e6418e50..48e527b7 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -1312,7 +1312,9 @@ def free_in(exp: Object) -> Set[str]: if isinstance(exp, Closure): # TODO(max): Should this remove the set of keys in the closure env? return free_in(exp.func) - raise NotImplementedError(("free_in", type(exp))) + # :'( + return set() + # raise NotImplementedError(("free_in", type(exp))) def improve_closure(closure: Closure) -> Closure: