Source code for fluent_codegen.remove_unused_assignments

"""
Utility to remove unused assignments from a Function body.

Uses ``rewriting_traverse`` to walk the codegen AST.  Only handles a single
function scope — raises ``AssertionError`` if a nested ``Scope`` (e.g. inner
function or class) is encountered.
"""

from __future__ import annotations

from collections.abc import Callable, Sequence

from .codegen import (
    Assignment,
    Block,
    CodeGenAst,
    CodeGenAstList,
    CodeGenAstType,
    Function,
    Name,
    Scope,
)


def _collect_blocks(func: Function) -> list[Block]:
    """Return every Block reachable from *func*.body, checking for nested scopes."""
    blocks: list[Block] = []
    seen: set[int] = set()

    def _walk(obj: object) -> None:
        obj_id = id(obj)
        if obj_id in seen:
            return
        seen.add(obj_id)

        if isinstance(obj, Scope) and obj is not func:
            if type(obj) is not Scope:
                raise AssertionError(f"remove_unused_assignments does not handle nested Scopes ({type(obj).__name__})")
            return

        if isinstance(obj, Block):
            blocks.append(obj)

        # Recurse into attributes, finding Block children generically.
        if isinstance(obj, (CodeGenAst, CodeGenAstList)):
            for value in obj.__dict__.values():
                _walk_value(value)

    def _walk_value(value: object) -> None:
        if isinstance(value, (CodeGenAst, CodeGenAstList, Scope)):
            _walk(value)
        elif isinstance(value, (list, tuple)):
            for item in value:  # type: ignore[reportUnknownVariableType]
                _walk_value(item)  # type: ignore[reportUnknownVariableType]

    # Start from statements inside the body (not the Function itself, which is
    # a Scope we need to skip).
    for stmt in func.body.statements:
        _walk_value(stmt)
    # Also include the body block itself.
    blocks.insert(0, func.body)
    return blocks


def traverse(
    node: CodeGenAstType | Sequence[CodeGenAstType],
    func: Callable[[CodeGenAstType], CodeGenAstType],
    *,
    _visited: set[int] | None = None,
    exclude_assignment_target: bool = False,
):
    """
    Apply 'func' to node and all sub CodeGenAst nodes.

    Discovers child nodes by introspecting instance attributes rather than
    relying on a manually-maintained list.  A *visited* set (keyed by
    ``id()``) prevents infinite recursion through circular references
    (e.g. Block.scope → Function → body → Block).

    if `include_func` is not None, it should be a callable that
    decides whether or not to visit a node and it's descendants
    """
    if _visited is None:
        _visited = set()
    node_id = id(node)
    if node_id in _visited:
        return
    _visited.add(node_id)
    if isinstance(node, (CodeGenAst, CodeGenAstList)):
        func(node)
        if exclude_assignment_target and isinstance(node, Assignment):
            exclude_keys = ("target",)
        else:
            exclude_keys = ()

        for key, value in node.__dict__.items():
            if key in exclude_keys:
                continue
            traverse(value, func, _visited=_visited, exclude_assignment_target=exclude_assignment_target)
    elif isinstance(node, (list, tuple)):
        for i in node:
            traverse(i, func, _visited=_visited, exclude_assignment_target=exclude_assignment_target)
    elif isinstance(node, dict):
        for v in node.values():  # type: ignore[reportUnknownVariableType]
            traverse(v, func, _visited=_visited, exclude_assignment_target=exclude_assignment_target)  # type: ignore[reportUnknownVariableType]


def _collect_name_references(func: Function) -> set[str]:
    """Return the set of all name strings referenced as ``Name`` expressions,
    apart from those in LHS of an Assignment
    """
    names: set[str] = set()

    def _visitor(node: CodeGenAstType) -> CodeGenAstType:
        if isinstance(node, Name):
            names.add(node.name)
        return node

    traverse(func.body, _visitor, exclude_assignment_target=True)
    return names


def _assigned_names(blocks: list[Block]) -> set[str]:
    """Return all names that appear on the LHS of an ``Assignment``."""
    result: set[str] = set()
    for block in blocks:
        for stmt in block.statements:
            if isinstance(stmt, Assignment):
                result.update(stmt.names)
    return result


def _remove_once(func: Function, blocks: list[Block]) -> bool:
    """Remove unused assignment statements.  Returns True if anything changed."""
    referenced = _collect_name_references(func)
    assigned = _assigned_names(blocks)
    unused = assigned - referenced
    if not unused:
        return False

    any_changed = False
    for block in blocks:
        original_length = len(block.statements)
        new_statements = [
            stmt for stmt in block.statements if not (isinstance(stmt, Assignment) and set(stmt.names) <= unused)
        ]
        if len(new_statements) < original_length:
            block.statements = new_statements
            any_changed = True

    return any_changed


[docs] def remove_unused_assignments(func: Function) -> None: """ Remove statements that assign to a name which is never read. Works recursively: if removing ``x = y`` makes ``y`` unused, it will be removed too. Raises ``AssertionError`` on nested scopes. """ blocks = _collect_blocks(func) while _remove_once(func, blocks): pass