diff --git a/lamb/node.py b/lamb/node.py index 752a5bf..6b74005 100644 --- a/lamb/node.py +++ b/lamb/node.py @@ -36,7 +36,6 @@ class ReductionError(Exception): def __init__(self, msg: str): self.msg = msg - class TreeWalker: """ An iterator that walks the "outline" of a tree @@ -348,7 +347,6 @@ class History(ExpandableEndNode): def copy(self): return History(runner = self.runner) - bound_counter = 0 class Bound(EndNode): def __init__(self, name: str, *, forced_id = None, runner = None): @@ -590,7 +588,6 @@ def call_func(fn: Func, arg: Node): n.parent.set_side(n.parent_side, clone(arg)) # type: ignore return fn.left - # Do a single reduction step def reduce(node: Node) -> tuple[ReductionType, Node]: if not isinstance(node, Node): @@ -617,12 +614,19 @@ def reduce(node: Node) -> tuple[ReductionType, Node]: return ReductionType.NOTHING, out +def expand(node: Node, *, force_all = False) -> tuple[int, Node]: + """ + Expands expandable nodes in the given tree. + + If force_all is false, this only expands + ExpandableEndnodes that have "always_expand" set to True. + + If force_all is True, this expands ALL + ExpandableEndnodes. + """ -# Expand all expandable end nodes. -def finalize_macros(node: Node, *, force = False) -> tuple[int, Node]: if not isinstance(node, Node): - raise TypeError(f"I can't reduce a {type(node)}") - + raise TypeError(f"I don't know what to do with a {type(node)}") out = clone(node) ptr = out @@ -632,7 +636,7 @@ def finalize_macros(node: Node, *, force = False) -> tuple[int, Node]: while True: if ( isinstance(ptr, ExpandableEndNode) and - (force or ptr.always_expand) + (force_all or ptr.always_expand) ): if ptr.parent is None: ptr = ptr.expand()[1] @@ -649,6 +653,7 @@ def finalize_macros(node: Node, *, force = False) -> tuple[int, Node]: macro_expansions += 1 + # Tree walk logic if isinstance(ptr, EndNode): from_side, ptr = ptr.go_up() elif isinstance(ptr, Func): @@ -666,4 +671,4 @@ def finalize_macros(node: Node, *, force = False) -> tuple[int, Node]: if ptr is node.parent: break - return macro_expansions, out # type: ignore \ No newline at end of file + return macro_expansions, out \ No newline at end of file diff --git a/lamb/runner.py b/lamb/runner.py index 81120b7..4bbdebf 100644 --- a/lamb/runner.py +++ b/lamb/runner.py @@ -136,8 +136,8 @@ class Runner: ("class:warn", "All macros will be expanded"), ("class:warn", "\n") ] - m, node = lamb.node.finalize_macros(node, force = True) - macro_expansions += m + m, node = lamb.node.expand(node, force_all = only_macro) + macro_expansions += m for i in status["free_variables"]: @@ -173,12 +173,8 @@ class Runner: if red_type == lamb.node.ReductionType.FUNCTION_APPLY: macro_expansions += 1 - # Expand all remaining macros - m, node = lamb.node.finalize_macros(node, force = only_macro) - macro_expansions += m - if k >= self.iter_update: - # Clear reduction counter + # Clear reduction counter if it was printed print(" " * round(14 + math.log10(k)), end = "\r") out_text += [ @@ -202,7 +198,7 @@ class Runner: ("class:text", str(node)), # type: ignore ] - self.history.append(lamb.node.finalize_macros(node, force = True)[1]) + self.history.append(lamb.node.expand(node, force_all = True)[1]) printf(