diff --git a/README.md b/README.md index 08394fe..f8e90c1 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ Lamb comes with a few commands. Prefix them with a `:` ## Internals -Lamb treats each λ expression as a binary tree. Variable binding and reduction are all simple operations on that tree. All the magic happens inside [`nodes.py`](./lamb/nodes.py). +Lamb treats each λ expression as a binary tree. Variable binding and reduction are all simple operations on that tree. All this magic happens in [`nodes.py`](./lamb/nodes.py). **Highlights:** - `TreeWalker` is the iterator we (usually) use to traverse our tree. It walks the "perimeter" of the tree, visiting some nodes multiple times. @@ -80,16 +80,17 @@ Lamb treats each λ expression as a binary tree. Variable binding and reduction - Prettier colors - Prevent macro-chaining recursion - step-by-step reduction + - Full-reduce option (expand all macros) - Show a warning when a free variable is created - PyPi package ## Todo: + - Optimization: clone only if absolutely necessary + - Better class mutation: when is a node no longer valid? + - Loop detection - Command-line options (load a file, run a set of commands) - $\alpha$-equivalence check - Unchurch macro: make church numerals human-readable - - Full-reduce option (expand all macros) - - Print macro content if only a macro is typed - Smart alignment in all printouts - - Syntax highlighting: parenthesis, bound variables, macros, etc - + - Syntax highlighting: parenthesis, bound variables, macros, etc \ No newline at end of file diff --git a/lamb/node.py b/lamb/node.py index 2e1edf1..bfb0e36 100644 --- a/lamb/node.py +++ b/lamb/node.py @@ -152,6 +152,14 @@ class Node: else: raise TypeError("Can only set left or right side.") + def get_side(self, side: Direction): + if side == Direction.LEFT: + return self.left + elif side == Direction.RIGHT: + return self.right + else: + raise TypeError("Can only get left or right side.") + def go_left(self): """ @@ -210,7 +218,7 @@ class EndNode(Node): raise NotImplementedError("EndNodes MUST provide a `print_value` method!") class ExpandableEndNode(EndNode): - def expand(self) -> tuple[ReductionType, Node]: + def expand(self, *, macro_table = {}) -> tuple[ReductionType, Node]: raise NotImplementedError("ExpandableEndNodes MUST provide an `expand` method!") class FreeVar(EndNode): @@ -269,7 +277,7 @@ class Church(ExpandableEndNode): def print_value(self): return str(self.value) - def expand(self) -> tuple[ReductionType, Node]: + def expand(self, *, macro_table = {}) -> tuple[ReductionType, Node]: f = Bound("f") a = Bound("a") chain = a @@ -518,12 +526,57 @@ def reduce(node: Node, *, macro_table = {}) -> tuple[ReductionType, Node]: return ReductionType.FUNCTION_APPLY, out elif isinstance(n.left, ExpandableEndNode): - if isinstance(n.left, Macro): - r, n.left = n.left.expand( - macro_table = macro_table - ) - else: - r, n.left = n.left.expand() + r, n.left = n.left.expand( + macro_table = macro_table + ) return r, out + return ReductionType.NOTHING, out - return ReductionType.NOTHING, out \ No newline at end of file + + +# Expand all expandable end nodes. +def force_expand_macros(node: Node, *, macro_table = {}) -> tuple[int, Node]: + if not isinstance(node, Node): + raise TypeError(f"I can't reduce a {type(node)}") + + + out = clone(node) + ptr = out + from_side = Direction.UP + macro_expansions = 0 + + while True: + if isinstance(ptr, ExpandableEndNode): + if ptr.parent is None: + ptr = ptr.expand(macro_table = macro_table)[1] + out = ptr + ptr._set_parent(None, None) + else: + ptr.parent.set_side( + ptr.parent_side, # type: ignore + ptr.expand(macro_table = macro_table)[1] + ) + ptr = ptr.parent.get_side( + ptr.parent_side # type: ignore + ) + macro_expansions += 1 + + + if isinstance(ptr, EndNode): + from_side, ptr = ptr.go_up() + elif isinstance(ptr, Func): + if from_side == Direction.UP: + from_side, ptr = ptr.go_left() + elif from_side == Direction.LEFT: + from_side, ptr = ptr.go_up() + elif isinstance(ptr, Call): + if from_side == Direction.UP: + from_side, ptr = ptr.go_left() + elif from_side == Direction.LEFT: + from_side, ptr = ptr.go_right() + elif from_side == Direction.RIGHT: + from_side, ptr = ptr.go_up() + if ptr is node.parent: + break + + return macro_expansions, out # type: ignore \ No newline at end of file diff --git a/lamb/runner.py b/lamb/runner.py index 062d08f..3325134 100644 --- a/lamb/runner.py +++ b/lamb/runner.py @@ -108,6 +108,9 @@ class Runner: stop_reason = StopReason.MAX_EXCEEDED start_time = time.time() + full_reduce = isinstance(node, lamb.node.ExpandableEndNode) + out_text = [] + while (self.reduction_limit is None) or (i < self.reduction_limit): @@ -116,7 +119,7 @@ class Runner: print(f" Reducing... {i}", end = "\r") try: - red_type, new_node = lamb.node.reduce( + red_type, node = lamb.node.reduce( node, macro_table = self.macro_table ) @@ -124,8 +127,6 @@ class Runner: stop_reason = StopReason.INTERRUPT break - node = new_node - # If we can't reduce this expression anymore, # it's in beta-normal form. if red_type == lamb.node.ReductionType.NOTHING: @@ -137,12 +138,20 @@ class Runner: if red_type == lamb.node.ReductionType.FUNCTION_APPLY: macro_expansions += 1 + # Expand all macros if we need to + if full_reduce: + m, node = lamb.node.force_expand_macros( + node, + macro_table = self.macro_table + ) + macro_expansions += m + if i >= self.iter_update: # Clear reduction counter print(" " * round(14 + math.log10(i)), end = "\r") - out_text = [ - ("class:result_header", f"\nRuntime: "), + out_text += [ + ("class:result_header", f"Runtime: "), ("class:text", f"{time.time() - start_time:.03f} seconds"), ("class:result_header", f"\nExit reason: "), @@ -152,14 +161,19 @@ class Runner: ("class:text", f"{macro_expansions:,}"), ("class:result_header", f"\nReductions: "), - ("class:text", f"{i:,} "), + ("class:text", f"{i:,}\t"), ("class:muted", f"(Limit: {self.reduction_limit:,})") ] + if full_reduce: + out_text += [ + ("class:warn", "\nAll macros have been expanded") + ] + if (stop_reason == StopReason.BETA_NORMAL or stop_reason == StopReason.LOOP_DETECTED): out_text += [ ("class:result_header", "\n\n => "), - ("class:text", str(new_node)), # type: ignore + ("class:text", str(node)), # type: ignore ] printf(