Source code for fluent_codegen.codegen

"""
Utilities for doing Python code generation
"""

from __future__ import annotations

import builtins
import enum
import keyword
import re
import sys
import textwrap
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import ClassVar, Literal, assert_never, overload

if sys.version_info >= (3, 13):
    from typing import TypeIs  # pragma: no cover
else:
    from typing_extensions import TypeIs  # pragma: no cover

from . import ast_compat as py_ast
from .ast_compat import (
    DEFAULT_AST_ARGS,
    DEFAULT_AST_ARGS_ADD,
    DEFAULT_AST_ARGS_ARGUMENTS,
    DEFAULT_AST_ARGS_MODULE,
    CommentNode,
    unparse_with_comments,
)
from .utils import allowable_keyword_arg_name, allowable_name

#: Type alias for comment strings stored in :attr:`Block.statements`.
#: Any plain ``str`` entry in ``Block.statements`` is treated as a comment; it is
#: emitted (as ``# text``) only when the block is rendered with comments included.
Comment = str

# This module provides simple utilities for building up Python source code.
# The design originally came from fluent-compiler, so had the following aims
# and constraints:
#
# 1. Performance.
#
#    The resulting Python code should do as little as possible, especially for
#    simple cases.
#
# 2. Correctness (obviously)
#
#    In particular, we should try to make it hard to generate code that is
#    syntactically correct and therefore compiles but doesn't work. We try to
#    make it hard to generate accidental name clashes, or use variables that are
#    not defined.
#
#    Correctness also has a security implication, since the result of this code
#    might be 'exec'ed. To that end:
#     * We build up AST, rather than strings. This eliminates many
#       potential bugs caused by wrong escaping/interpolation.
#     * the `as_ast()` methods are paranoid about input, and do many asserts.
#       We do this even though other layers will usually have checked the
#       input, to allow us to reason locally when checking these methods. These
#       asserts must also have 100% code coverage.
#
# 3. Simplicity
#
#    The resulting Python code should be easy to read and understand.
#
# 4. Predictability
#
#    Since we want to test the resulting source code, we have made some design
#    decisions that aim to ensure things like function argument names are
#    consistent and so can be predicted easily.

# Outside of the original project `fluent-compiler`, this code will likely be
# useful for situations which have similar aims.

# It has since evolved further aims:

# 5. Usability
#
#    We add convenience wrappers to make it easier to see what the output
#    is going to look like.


# --- Layers

# Within this module, there are about 3 conceptual layers, though not completely
# separated out:

# 1. At the bottom layer, there are CodeGenAst classes, such as `Number`,
#    `String`, `If` etc. These are thin wrappers around Python `ast` classes.
#
# 2. Block, Scope and Expression, which provide higher level convenience.
#    In particular:
#
#    - Scope provides management of names, to avoid name clashes. (For this
#      reason the `Name` class, which is a `CodeGenAst`, actually depends
#      on `Scope`, to make it harder to create Names that clash.)
#
#    - Block makes it easier to build up a list of statements.
#
#    - Expression is a base class that provides a convenient chained method
#      syntax for creating further expressions.
#
# 3. E-objects: These add another layer of convenience, allowing you to build up
#    Python code using Python syntax rather than the chained method
#    calls on Expression which can be awkward if you know statically what operations
#    are to be done.
#

# Note that:
# - the bottom layer does not accept `E-objects` in its constructors
#
# - the middle layer accepts both Expression and E-objects to support
#   the convenience of building up expressions easily.
#
# - the top layer, E-objects, allows mixing Expression and also allows mixing
#   native Python objects which get converted to E-objects.

# ---

# Builtin functions that we should never be calling from our code generation.
# This is a defense-in-depth mechanism to stop our code generation becoming a
# code execution vulnerability. There should also be higher level code that
# ensures we are not generating calls to arbitrary Python functions. This is
# not a comprehensive list of functions we are not using, but functions we
# definitely don't need and are most likely to be used to execute remote code
# or to get around safety mechanisms.
SENSITIVE_FUNCTIONS = {
    # Compiling / evaluating arbitrary code
    "compile",
    "eval",
    "exec",
    "execfile",
    # Import and class-construction machinery
    "__import__",
    "__build_class__",
    "reload",
    # Namespace and type introspection
    "globals",
    "locals",
    "object",
    "type",
    # Files and process control
    "file",
    "open",
    "exit",
    # Other
    "apply",
}


class CodeGenAst(ABC):
    """
    Abstract base for our simplified Python AST (a wrapper, not the stdlib one).

    Concrete subclasses build genuine ``ast.*`` nodes in :meth:`as_ast`.
    """

    @abstractmethod
    def as_ast(self, *, include_comments: bool = False) -> py_ast.AST:
        """Build and return the real ``ast`` node for this object.

        Implementations that can carry comments use *include_comments* to
        decide whether to emit them.
        """

    def as_python_source(self) -> str:
        """Render this node, comments included, as Python source text."""
        tree = self.as_ast(include_comments=True)
        py_ast.fix_missing_locations(tree)
        source = unparse_with_comments(tree)
        return source


class CodeGenAstList(ABC):
    """
    Sibling of :class:`CodeGenAst` for objects that expand to several
    statements instead of a single node. Everything returned by
    :meth:`as_ast_list` must be an ``ast.stmt``.
    """

    @abstractmethod
    def as_ast_list(self, allow_empty: bool = True, *, include_comments: bool = False) -> list[py_ast.stmt]:
        """Return the statements this object generates.

        With *allow_empty* false, implementations are expected to return at
        least one statement.
        """

    def as_python_source(self) -> str:
        """Render these statements, comments included, as Python source text."""
        statements = self.as_ast_list(include_comments=True)
        module_node = py_ast.Module(body=statements, type_ignores=[], **DEFAULT_AST_ARGS_MODULE)
        py_ast.fix_missing_locations(module_node)
        return unparse_with_comments(module_node)


#: Union of the two code-generation base classes: objects producing a single
#: AST node (:class:`CodeGenAst`) and objects producing a list of statements
#: (:class:`CodeGenAstList`).
CodeGenAstType = CodeGenAst | CodeGenAstList


class Scope:
    """Track name reservations and assignments within a lexical scope.

    A scope holds three sets of names:

    * ``names``: names in use in this scope;
    * names held back (via :meth:`reserve_function_arg_name`) for later use as
      function arguments only;
    * names that have been assigned a value in this scope.

    The "in use" and "function-arg reserved" lookups also consult parent scopes.
    """

    def __init__(self, parent_scope: Scope | None = None):
        self.parent_scope = parent_scope
        # Names in use in this scope (parents are consulted separately).
        self.names: set[str] = set()
        # Names held back for later use as function arguments only.
        self._function_arg_reserved_names: set[str] = set()
        # Names assigned a value in this scope (not consulted in parents).
        self._assignments: set[str] = set()

    def is_name_in_use(self, name: str) -> bool:
        """Return whether *name* is already reserved in this scope or any parent."""
        if name in self.names:
            return True
        if self.parent_scope is None:
            return False
        return self.parent_scope.is_name_in_use(name)

    def is_name_reserved_function_arg(self, name: str) -> bool:
        """Return whether *name* is reserved for use as a function argument."""
        if name in self._function_arg_reserved_names:
            return True
        if self.parent_scope is None:
            return False
        return self.parent_scope.is_name_reserved_function_arg(name)

    def is_name_reserved(self, name: str) -> bool:
        """Return whether *name* is reserved for any purpose in this scope."""
        return self.is_name_in_use(name) or self.is_name_reserved_function_arg(name)

    def reserve_name(
        self,
        requested: str,
        function_arg: bool = False,
        is_builtin: bool = False,
    ):
        """
        Reserve a name as being in use in a scope.

        Pass function_arg=True if this is a function argument.

        Returns the name actually reserved. If *function_arg* is true and
        *requested* was previously set aside with
        :meth:`reserve_function_arg_name`, *requested* is used verbatim.
        Otherwise the name is passed through :func:`cleanup_name`, and a
        numeric suffix (``_2``, ``_3``, ...) is appended until it clashes
        with neither a reserved name nor a keyword. Keywords are not checked
        when *is_builtin* is true, because some builtins (e.g. ``None``) are
        keywords.

        Raises AssertionError if *function_arg* is true and *requested* is
        already in use.
        """

        def _add(final: str):
            self.names.add(final)
            return final

        if function_arg:
            if self.is_name_reserved_function_arg(requested):
                # Explicit raise rather than `assert`, so the check survives
                # `python -O`. Otherwise a duplicate argument name could be
                # handed out twice.
                if self.is_name_in_use(requested):
                    raise AssertionError(f"Cannot use '{requested}' as argument name as it is already in use")
                return _add(requested)
            if self.is_name_reserved(requested):
                raise AssertionError(f"Cannot use '{requested}' as argument name as it is already in use")
        cleaned = cleanup_name(requested)

        attempt = cleaned
        count = 2  # instance without suffix is regarded as 1

        # To avoid shadowing of global names in local scope, we
        # take into account parent scope when assigning names.

        def _is_name_allowed(name: str) -> bool:
            # We need to also protect against using keywords ('class', 'def' etc.)
            # i.e. count all keywords as 'used'.
            # However, some builtins are also keywords (e.g. 'None'), and so
            # if a builtin is being reserved, don't check against the keyword list
            if (not is_builtin) and keyword.iskeyword(name):
                return False
            return not self.is_name_reserved(name)

        while not _is_name_allowed(attempt):
            attempt = cleaned + "_" + str(count)
            count += 1

        return _add(attempt)

    def reserve_function_arg_name(self, name: str):
        """
        Reserve a name for *later* use as a function argument. This does not result
        in that name being considered 'in use' in the current scope, but will
        avoid the name being assigned for any use other than as a function argument.

        Raises AssertionError if *name* is already reserved for any purpose.
        """
        # To keep things simple, and the generated code predictable, we reserve
        # names for all function arguments in a separate scope, and insist on
        # the exact names
        if self.is_name_reserved(name):
            raise AssertionError(f"Can't reserve '{name}' as function arg name as it is already reserved")
        self._function_arg_reserved_names.add(name)

    def has_assignment(self, name: str) -> bool:
        """Return whether *name* has been assigned a value in this scope (parents are not checked)."""
        return name in self._assignments

    def register_assignment(self, name: str) -> None:
        """Record that *name* has been assigned a value in this scope."""
        self._assignments.add(name)

    def create_name(self, name: str) -> Name:
        """Reserve *name* (or a variant) and return a :class:`Name` expression for it."""
        reserved = self.reserve_name(name)
        return Name(reserved, self)

    def name(self, name: str) -> Name:
        """Return a :class:`Name` expression for an already-reserved name."""
        return Name(name, self)

    @cached_property
    def enames(self) -> Enames:
        """
        Returns an Enames object, which provides easy access to names as E-objects.
        """
        return Enames(self)
# Matches every character that may not appear in an ASCII identifier.
_IDENTIFIER_SANITIZER_RE = re.compile("[^a-zA-Z0-9_]")
# Matches a valid first character of an ASCII identifier.
_IDENTIFIER_START_RE = re.compile("^[a-zA-Z_]")


def cleanup_name(name: str) -> str:
    """
    Convert *name* into an allowable (ASCII-only) Python identifier.

    Every character outside ``[a-zA-Z0-9_]`` is dropped. If what remains does
    not begin with a letter or underscore (including when it is empty), it is
    prefixed with ``"n"``. Keywords are not handled here.
    """
    # Python 3 identifiers may also contain non-ASCII letters, see
    # https://docs.python.org/3/reference/lexical_analysis.html#identifiers
    # -- this sanitizer keeps ASCII only.
    sanitized = _IDENTIFIER_SANITIZER_RE.sub("", name)
    if _IDENTIFIER_START_RE.match(sanitized) is None:
        return "n" + sanitized
    return sanitized
class Statement(CodeGenAst):
    """Abstract base for code-generation nodes that produce a Python statement.

    ``Block.as_ast_list`` expects a subclass's :meth:`as_ast` to return an
    ``ast.stmt`` node.
    """
class Annotation(Statement):
    """A bare annotated name with no value, e.g. ``x: int``.

    The name is validated with ``allowable_name`` when the AST is built,
    not at construction time.
    """

    def __init__(self, name: str, annotation: Expression):
        self.name = name
        self.annotation = annotation

    def as_ast(self, *, include_comments: bool = False):
        name_is_valid = allowable_name(self.name)
        if not name_is_valid:
            raise AssertionError(f"Expected {self.name} to be a valid Python identifier")
        store_target = py_ast.Name(id=self.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        annotation_node = self.annotation.as_ast()
        return py_ast.AnnAssign(
            target=store_target,
            annotation=annotation_node,
            simple=1,
            value=None,
            **DEFAULT_AST_ARGS,
        )
type AugOp = Literal[ "+=", "-=", "*=", "/=", "//=", "%=", "**=", "@=", "<<=", ">>=", "|=", "&=", "^=", ] """Operator strings accepted by :class:`AugAssignment` and :meth:`Block.aug_assign`.""" _AUG_OP_MAP: dict[str, type[py_ast.operator]] = { "+=": py_ast.Add, "-=": py_ast.Sub, "*=": py_ast.Mult, "/=": py_ast.Div, "//=": py_ast.FloorDiv, "%=": py_ast.Mod, "**=": py_ast.Pow, "@=": py_ast.MatMult, "<<=": py_ast.LShift, ">>=": py_ast.RShift, "|=": py_ast.BitOr, "^=": py_ast.BitXor, "&=": py_ast.BitAnd, } type AugAssignTarget = Name | Attr | Subscript """Valid targets for augmented assignment (no tuples).""" def _is_aug_assign_target(value: object) -> TypeIs[AugAssignTarget]: return isinstance(value, (Name, Attr, Subscript)) class AugAssignment(Statement): """An augmented assignment statement, e.g. ``x += 1``.""" def __init__(self, target: AugAssignTarget, op: AugOp, value: Expression, /): if not _is_aug_assign_target(target): raise AssertionError( f"Invalid augmented assignment target: {type(target).__name__}. " "Only Name, Attr, and Subscript are allowed (not tuples)." ) if op not in _AUG_OP_MAP: raise AssertionError( f"Invalid augmented assignment operator: {op!r}. Must be one of: {', '.join(sorted(_AUG_OP_MAP))}" ) self.target = target self.op: AugOp = op self.value = value def as_ast(self, *, include_comments: bool = False): target_ast = _aug_target_as_store_ast(self.target) return py_ast.AugAssign( target=target_ast, op=_AUG_OP_MAP[self.op](**DEFAULT_AST_ARGS_ADD), value=self.value.as_ast(), **DEFAULT_AST_ARGS, )
class Assignment(Statement):
    """``target = value``, or ``name: hint = value`` when a type hint is supplied.

    Only a plain :class:`Name` target may carry a type hint; this is enforced
    at construction time.
    """

    def __init__(self, target: Target, value: Expression, /, *, type_hint: Expression | None = None):
        if not is_target(target):
            raise AssertionError("Invalid assignment target")
        self.names: list[str] = _target_names(target)
        self.target: Target = target
        has_hint = type_hint is not None
        if has_hint and not isinstance(target, Name):
            raise AssertionError("Type hints are only supported for simple name assignment targets")
        self.value = value
        self.type_hint = type_hint

    def as_ast(self, *, include_comments: bool = False):
        if self.type_hint is not None:
            # Only Name targets can carry a hint (checked in __init__).
            assert isinstance(self.target, Name)
            store_target = _target_as_store_ast(self.target)
            assert isinstance(store_target, py_ast.Name)
            return py_ast.AnnAssign(
                target=store_target,
                annotation=self.type_hint.as_ast(),
                simple=1,
                value=self.value.as_ast(),
                **DEFAULT_AST_ARGS,
            )
        return py_ast.Assign(
            targets=[_target_as_store_ast(self.target)],
            value=self.value.as_ast(),
            **DEFAULT_AST_ARGS,
        )
class Block(CodeGenAstList):
    """An ordered sequence of statements sharing a common :class:`Scope`.

    Entries in :attr:`statements` may be nested blocks, statements, bare
    expressions (kept for their side effects) or ``str`` comments.
    """

    def __init__(self, scope: Scope, parent_block: Block | None = None):
        self.scope = scope
        # We allow `Expression` here for things like MethodCall which
        # are bare expressions that are still useful for side effects.
        # `Comment` (str) entries are rendered as ``# text`` comments.
        self.statements: list[Block | Statement | Expression | Comment] = []
        self.parent_block = parent_block

    def as_ast_list(self, allow_empty: bool = True, *, include_comments: bool = False) -> list[py_ast.stmt]:
        """Convert the block's entries into a flat list of ``ast.stmt`` nodes.

        Nested :class:`CodeGenAstList` entries are flattened in place. Bare
        expressions are wrapped in ``Expr``. ``str`` comments become
        ``CommentNode`` entries, but only when *include_comments* is true. If
        the result is empty and *allow_empty* is false, a single ``pass``
        statement is returned instead.
        """
        retval: list[py_ast.stmt] = []
        for s in self.statements:
            if isinstance(s, str):
                # Comment
                if include_comments:
                    retval.append(CommentNode(s, **DEFAULT_AST_ARGS))  # type: ignore[reportArgumentType]
                continue
            if isinstance(s, CodeGenAstList):
                retval.extend(s.as_ast_list(allow_empty=True, include_comments=include_comments))
            elif isinstance(s, Statement):
                ast_obj = s.as_ast(include_comments=include_comments)
                assert isinstance(ast_obj, py_ast.stmt), (
                    f"Statement object returned {ast_obj!r} which is not a subclass of py_ast.stmt"
                )
                retval.append(ast_obj)
            else:
                # Things like bare function/method calls need to be wrapped
                # in `Expr` to match the way Python parses.
                retval.append(py_ast.Expr(s.as_ast(include_comments=include_comments), **DEFAULT_AST_ARGS))
        if not retval and not allow_empty:
            return [py_ast.Pass(**DEFAULT_AST_ARGS)]
        return retval

    def add_comment(self, text: str, *, wrap: int | None = None) -> None:
        """Add a ``# text`` comment line at the current position in the block.

        If *wrap* is given as an integer, long lines are wrapped at word
        boundaries so that no comment line exceeds *wrap* characters
        (including the ``#`` prefix and space).
        """
        # NOTE(review): without *wrap*, *text* is stored verbatim, so embedded
        # newlines reach the comment renderer unchanged. With *wrap*,
        # textwrap.wrap collapses them to spaces. Confirm unparse_with_comments
        # handles multi-line text before passing untrusted input here.
        if wrap is not None:
            # Account for the "# " prefix (2 chars) that will be added during rendering.
            effective_width = max(wrap - 2, 1)
            lines = textwrap.wrap(text, width=effective_width)
            if not lines:
                # Empty or whitespace-only text – preserve as a single blank comment.
                lines = [""]
            for line in lines:
                self.statements.append(line)
        else:
            self.statements.append(text)

    def add_statement(self, statement: Statement | Block | ExpressionLike) -> None:
        """Append a statement, block, or bare expression to this block.

        E-objects are converted to expressions first. A :class:`Block` with
        no parent is adopted by this block. A :class:`Block` that already
        belongs to a different parent raises AssertionError, and in that case
        this block is left unchanged.
        """
        if isinstance(statement, E):
            statement = E_to_Expression(statement)
        # Validate the parent link *before* mutating `self.statements`, so a
        # rejected block is not left behind in this block.
        if isinstance(statement, Block):
            if statement.parent_block is None:
                statement.parent_block = self
            elif statement.parent_block != self:
                raise AssertionError(
                    f"Block {statement} is already child of {statement.parent_block}, can't reassign to {self}"
                )
        self.statements.append(statement)

    def add_statements(self, statements: Sequence[Statement | Block | ExpressionLike]) -> None:
        """Append multiple statements to this block.

        This is a convenience wrapper around :meth:`add_statement`::

            block.add_statements([stmt_a, stmt_b, stmt_c])
            # equivalent to:
            block.add_statement(stmt_a)
            block.add_statement(stmt_b)
            block.add_statement(stmt_c)
        """
        for statement in statements:
            self.add_statement(statement)

    def create_import(self, module: str, as_: str | None = None) -> tuple[Import, Name]:
        """Create an ``import`` statement, reserve the resulting name, and add it to this block.

        Returns the :class:`Import` statement and the :class:`Name` it binds:
        *as_* if given, otherwise the first component of *module*.
        """
        return_name_object: Name
        if as_ is not None:
            # "import foo as bar" results in `bar` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            # NOTE(review): the dotted *module* path is not checked with
            # allowable_name on this branch -- confirm Import validates it.
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # "import foo" results in `foo` name being assigned
            # "import foo.bar" also results in `foo` being reserved.
            dotted_parts = module.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{module!r} not an allowable 'import' name")
            name_to_assign = dotted_parts[0]
            # We can't rename, so don't use `reserve_name` or `create_name`.
            # We also need to allow for multiple imports, like `import foo.bar` then `import foo.baz`
            # NOTE(review): reserve_name() may return a suffixed variant (e.g. when
            # the name is held back as a function argument); its result is ignored here.
            if not self.scope.is_name_in_use(name_to_assign):
                self.scope.reserve_name(name_to_assign)
            return_name_object = self.scope.name(name_to_assign)
        import_statement = Import(module=module, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    def create_import_from(self, *, from_: str, import_: str, as_: str | None = None) -> tuple[ImportFrom, Name]:
        """Create a ``from ... import`` statement, reserve the resulting name, and add it to this block.

        Returns the :class:`ImportFrom` statement and the :class:`Name` it
        binds: *as_* if given, otherwise *import_*.
        """
        return_name_object: Name
        if as_ is not None:
            # "from foo import bar as baz" results in `baz` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            # NOTE(review): neither *from_* nor *import_* is checked on this
            # branch -- confirm ImportFrom validates them.
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # Check the dotted bit.
            dotted_parts = from_.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{from_!r} not an allowable 'import' name")
            # Check the `import_` for clashes.
            name_to_assign = import_
            if self.scope.is_name_in_use(name_to_assign):
                raise AssertionError(f"{name_to_assign!r} is already assigned in the scope")
            # NOTE(review): create_name() may clean or suffix *import_*, so the
            # returned Name could differ from the name the import actually binds.
            return_name_object = self.scope.create_name(name_to_assign)
        import_statement = ImportFrom(from_module=from_, import_=import_, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    # Safe alternatives to Block.statements being manipulated directly:
    def create_assignment(
        self,
        target: str | Target,
        value: ExpressionLike,
        *,
        type_hint: ExpressionLike | None = None,
        allow_multiple: bool = False,
    ):
        """
        Adds an assignment of the form:

           x = value

        or more complex like:

           x[0] = value
           x, y = value

        A ``str`` *target* must name an already-reserved name. Unless
        *allow_multiple* is true, AssertionError is raised if any target name
        has already been assigned in this scope, or appears twice in *target*.
        All validation happens before any state changes, so a failed call
        registers no assignments.
        """
        if isinstance(target, str):
            if not self.scope.is_name_in_use(target):
                raise AssertionError(f"Cannot assign to unreserved name '{target}'")
            target = self.scope.name(target)

        names = _target_names(target)
        if not allow_multiple:
            seen: set[str] = set()
            for name in names:
                if name in seen or self.scope.has_assignment(name):
                    raise AssertionError(f"Have already assigned to '{name}' in this scope")
                seen.add(name)

        # Build the statement (which validates the target/type hint) before
        # registering any names as assigned.
        statement = Assignment(
            target, E_to_Expression(value), type_hint=E_to_Expression(type_hint) if type_hint is not None else None
        )
        for name in names:
            self.scope.register_assignment(name)
        self.add_statement(statement)

    @overload
    def assign(self, target: str, value: ExpressionLike, *, type_hint: ExpressionLike | None = ...) -> Name: ...

    @overload
    def assign(self, target: tuple[str, ...], value: ExpressionLike) -> tuple[Name, ...]: ...

    def assign(
        self,
        target: str | tuple[str, ...],
        value: ExpressionLike,
        *,
        type_hint: ExpressionLike | None = None,
    ) -> Name | tuple[Name, ...]:
        """
        Shortcut that reserves names and creates an assignment in one step.

        When *target* is a single ``str``, reserves the name and assigns to it,
        returning the new :class:`Name`::

            result = block.assign("x", some_expr)
            # equivalent to:
            # x = scope.create_name("x")
            # block.create_assignment(x, some_expr)

        When *target* is a ``tuple`` of ``str``, reserves each name and creates
        a tuple-unpacking assignment, returning a tuple of :class:`Name` objects::

            a, b = block.assign(("a", "b"), some_pair_expr)
            # equivalent to:
            # a = scope.create_name("a")
            # b = scope.create_name("b")
            # block.create_assignment((a, b), some_pair_expr)
        """
        if isinstance(target, str):
            name_obj = self.scope.create_name(target)
            self.create_assignment(name_obj, value, type_hint=type_hint)
            return name_obj
        else:
            if type_hint is not None:
                raise AssertionError("Can't use type hint with tuple unpacking assignment")
            name_objs = tuple(self.scope.create_name(t) for t in target)
            self.create_assignment(name_objs, value)
            return name_objs

    def aug_assign(self, target: AugAssignTarget, op: AugOp, value: ExpressionLike, /) -> None:
        """Add an augmented assignment statement to this block.

        Usage::

            x = block.assign("x", auto(0))
            block.aug_assign(x, "+=", auto(1))  # x += 1

        *target* must be a :class:`Name`, :class:`Attr`, or :class:`Subscript`
        (tuples are not valid Python augmented-assignment targets).

        *op* is one of the Python augmented-assignment operator strings:

        - ``"+="``
        - ``"-="``
        - ``"*="``
        - ``"/="``
        - ``"//="``
        - ``"%="``
        - ``"**="``
        - ``"@="``
        - ``"<<="``
        - ``">>="``
        - ``"|="``
        - ``"&="``
        - ``"^="``
        """
        self.add_statement(AugAssignment(target, op, E_to_Expression(value)))

    def create_annotation(self, name: str, annotation: ExpressionLike) -> Name:
        """
        Adds a bare type annotation of the form::

            x: int

        Reserves the name and adds the annotation statement to the block.
        """
        name_obj = self.scope.create_name(name)
        self.scope.register_assignment(name_obj.name)
        self.add_statement(Annotation(name_obj.name, E_to_Expression(annotation)))
        return name_obj

    def create_field(self, name: str, annotation: ExpressionLike, *, default: ExpressionLike | None = None) -> Name:
        """
        Create a typed field, typically used in dataclass bodies.

        If *default* is provided, creates an annotated assignment::

            x: int = 0

        Otherwise, creates a bare annotation::

            x: int
        """
        if default is not None:
            name_obj = self.scope.create_name(name)
            self.scope.register_assignment(name_obj.name)
            self.add_statement(Assignment(name_obj, E_to_Expression(default), type_hint=E_to_Expression(annotation)))
            return name_obj
        else:
            return self.create_annotation(name, annotation)

    def create_function(
        self,
        name: str,
        args: Sequence[str | FunctionArg],
        decorators: Sequence[ExpressionLike] | None = None,
        return_type: ExpressionLike | None = None,
    ) -> tuple[Function, Name]:
        """
        Reserve a name for a function, create the Function and add the function statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        func = Function(
            name_obj.name,
            args=args,
            parent_scope=self.scope,
            decorators=[E_to_Expression(d) for d in decorators] if decorators is not None else None,
            return_type=E_to_Expression(return_type) if return_type is not None else None,
        )
        self.add_statement(func)
        return func, name_obj

    def create_class(
        self,
        name: str,
        bases: Sequence[ExpressionLike] | None = None,
        decorators: Sequence[ExpressionLike] | None = None,
    ) -> tuple[Class, Name]:
        """
        Reserve a name for a class, create the Class and add the class statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        cls = Class(
            name_obj.name,
            parent_scope=self.scope,
            bases=[E_to_Expression(b) for b in bases] if bases is not None else None,
            decorators=[E_to_Expression(d) for d in decorators] if decorators is not None else None,
        )
        self.add_statement(cls)
        return cls, name_obj

    def create_return(self, value: ExpressionLike) -> None:
        """Add a ``return`` statement to this block."""
        self.add_statement(Return(E_to_Expression(value)))

    def create_break(self) -> None:
        """Add a ``break`` statement to this block."""
        self.add_statement(Break())

    def create_continue(self) -> None:
        """Add a ``continue`` statement to this block."""
        self.add_statement(Continue())

    def create_assert(self, test: ExpressionLike, msg: ExpressionLike | None = None) -> None:
        """Add an ``assert`` statement to this block."""
        self.add_statement(Assert(E_to_Expression(test), E_to_Expression(msg) if msg is not None else None))

    def create_raise(self, exc: ExpressionLike | None = None, cause: ExpressionLike | None = None) -> None:
        """Add a ``raise`` statement to this block."""
        self.add_statement(
            Raise(
                E_to_Expression(exc) if exc is not None else None,
                E_to_Expression(cause) if cause is not None else None,
            )
        )

    def create_if(self) -> If:
        """
        Create an If statement, add it to this block, and return it.

        Usage::

            if_stmt = block.create_if()
            if_block = if_stmt.add_if(condition)
            if_block.create_return(value)
        """
        if_statement = If(self.scope, parent_block=self)
        self.add_statement(if_statement)
        return if_statement

    @overload
    def create_with(self, context_expr: ExpressionLike, target: Name | str) -> tuple[With, Name]: ...

    @overload
    def create_with(self, context_expr: ExpressionLike) -> With: ...

    def create_with(self, context_expr: ExpressionLike, target: Name | str | None = None) -> With | tuple[With, Name]:
        """
        Create a With statement, add it to this block, and return it

        Usage::

            with_stmt, target = block.create_with(expr, "f")
            with_stmt.body.create_return(value)

        If target is a str, the name will be reserved.

        If target is `None`, only the with_statement will be returned.
        """
        if isinstance(target, str):
            name_obj = self.scope.create_name(target)
            target = name_obj
        with_statement = With(E_to_Expression(context_expr), target=target, parent_scope=self.scope, parent_block=self)
        self.add_statement(with_statement)
        if target is None:
            return with_statement
        return with_statement, target

    @overload
    def create_for(self, target: str, iterable: ExpressionLike) -> tuple[For, Name]: ...

    @overload
    def create_for(self, target: tuple[str, ...], iterable: ExpressionLike) -> tuple[For, tuple[Name, ...]]: ...

    @overload
    def create_for(self, target: tuple[Name, ...], iterable: ExpressionLike) -> tuple[For, tuple[Name, ...]]: ...

    @overload
    def create_for(self, target: Target, iterable: ExpressionLike) -> tuple[For, Target]: ...

    def create_for(self, target: str | tuple[str, ...] | Target, iterable: ExpressionLike) -> tuple[For, Target]:
        """
        Create a ``for`` loop, add it to this block, and return it.

        The first parameter is the loop variable. If this is a str or tuple[str]
        then these names will be reserved and Name objects created, similar to `assign`.

        The second parameter is an expression that will be iterated over.

        Usage::

            for_stmt, index = func.body.create_for("i", items)
            for_stmt.body.add_statement(some_expr)
        """
        target = _normalize_targets(self.scope, target)
        for_statement = For(target, E_to_Expression(iterable), parent_scope=self.scope, parent_block=self)
        self.add_statement(for_statement)
        return for_statement, target

    def create_try(self) -> Try:
        """
        Create a Try statement, add it to this block, and return it.

        Add ``except`` clauses via :meth:`Try.create_except`.

        Usage::

            try_stmt = block.create_try()
            try_stmt.try_block.add_statement(some_expr)
            except_block, e_name = try_stmt.create_except([my_error], "e")
            except_block.create_return(value)
        """
        try_statement = Try(
            self.scope,
            parent_block=self,
        )
        self.add_statement(try_statement)
        return try_statement
def _normalize_targets(scope: Scope, target: str | tuple[str, ...] | Target) -> Target:
    """Turn *target* into a Target, reserving a fresh name in *scope* for every ``str``.

    Tuples are normalized element by element (recursively, left to right).
    Anything else must already be a Target and is returned unchanged.
    """
    if isinstance(target, tuple):
        return tuple(_normalize_targets(scope, item) for item in target)
    if isinstance(target, str):
        return scope.create_name(target)
    assert is_target(target), f"Expected str, tuple or Target, got {type(target)}"
    return target
class Module(Block, CodeGenAst):
    """A top-level module: a :class:`Block` owning a root :class:`Scope`.

    With ``reserve_builtins=True`` (the default), every name in
    ``dir(builtins)`` is reserved up front, so later reservations get
    suffixed rather than shadowing a builtin.
    """

    def __init__(self, reserve_builtins: bool = True):
        module_scope = Scope(parent_scope=None)
        if reserve_builtins:
            for builtin_name in dir(builtins):
                module_scope.reserve_name(builtin_name, is_builtin=True)
        Block.__init__(self, module_scope)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.Module:
        body = self.as_ast_list(include_comments=include_comments)
        return py_ast.Module(body=body, type_ignores=[], **DEFAULT_AST_ARGS_MODULE)

    def as_python_source(self) -> str:
        """Render the whole module, comments included, as Python source."""
        module_node = self.as_ast(include_comments=True)
        py_ast.fix_missing_locations(module_node)
        return unparse_with_comments(module_node)

    @cached_property
    def enames(self) -> Enames:
        """
        Scope names exposed as E-objects, e.g. ``Module().enames.str`` for the
        ``str`` Name that is reserved as a builtin.
        """
        return self.scope.enames
[docs] class ArgKind(enum.Enum): """The kind of a function argument.""" POSITIONAL_ONLY = "positional_only" POSITIONAL_OR_KEYWORD = "positional_or_keyword" KEYWORD_ONLY = "keyword_only"
[docs] @dataclass(frozen=True) class FunctionArg: """A function argument with a name, kind, and optional default value.""" name: str kind: ArgKind = ArgKind.POSITIONAL_OR_KEYWORD default: Expression | None = None annotation: Expression | None = None
[docs] @classmethod def positional( cls, name: str, *, default: ExpressionLike | None = None, annotation: ExpressionLike | None = None ) -> FunctionArg: """Create a positional-only argument.""" return cls( name=name, kind=ArgKind.POSITIONAL_ONLY, default=E_to_Expression(default) if default is not None else None, annotation=E_to_Expression(annotation) if annotation is not None else None, )
[docs] @classmethod def keyword( cls, name: str, *, default: ExpressionLike | None = None, annotation: ExpressionLike | None = None ) -> FunctionArg: """Create a keyword-only argument.""" return cls( name=name, kind=ArgKind.KEYWORD_ONLY, default=E_to_Expression(default) if default is not None else None, annotation=E_to_Expression(annotation) if annotation is not None else None, )
[docs] @classmethod def standard( cls, name: str, *, default: ExpressionLike | None = None, annotation: ExpressionLike | None = None ) -> FunctionArg: """Create a positional-or-keyword argument (the Python default).""" return cls( name=name, kind=ArgKind.POSITIONAL_OR_KEYWORD, default=E_to_Expression(default) if default is not None else None, annotation=E_to_Expression(annotation) if annotation is not None else None, )
def _normalize_args(args: Sequence[str | FunctionArg]) -> list[FunctionArg]:
    """Return *args* as a list of :class:`FunctionArg`, preserving order.

    Each bare ``str`` becomes ``FunctionArg(name=...)`` (positional-or-keyword,
    no default, no annotation). Every other item is kept as-is.
    """
    return [FunctionArg(name=item) if isinstance(item, str) else item for item in args]


def _validate_arg_order(args: list[FunctionArg]) -> None:
    """Check that *args* forms a legal Python signature.

    Raises ``ValueError`` in two cases:

    * The kinds are not grouped as positional-only, then
      positional-or-keyword, then keyword-only.
    * A positional parameter without a default follows one that has a default.

    Keyword-only parameters may mix defaults freely, as Python allows.
    """
    kind_rank = {
        ArgKind.POSITIONAL_ONLY: 0,
        ArgKind.POSITIONAL_OR_KEYWORD: 1,
        ArgKind.KEYWORD_ONLY: 2,
    }
    last_rank = -1
    for arg in args:
        rank = kind_rank[arg.kind]
        if rank < last_rank:
            raise ValueError(
                f"Argument '{arg.name}' of kind {arg.kind.value} "
                f"is out of order: positional-only args must come first, "
                f"then positional-or-keyword, then keyword-only"
            )
        last_rank = rank

    # Positional-only and positional-or-keyword parameters share one defaults
    # list in Python, so the "no non-default after a default" rule spans both.
    positional_kinds = (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)
    default_seen = False
    for arg in args:
        if arg.kind not in positional_kinds:
            continue
        if arg.default is not None:
            default_seen = True
        elif default_seen:
            raise ValueError(f"Non-default argument '{arg.name}' follows default argument in positional arguments")
class Function(Scope, Statement):
    """A function definition statement that also acts as a :class:`Scope` for its body.

    Each argument name is reserved in this scope with
    ``reserve_name(..., function_arg=True)``.
    """

    def __init__(
        self,
        name: str,
        args: Sequence[str | FunctionArg] | None = None,
        parent_scope: Scope | None = None,
        decorators: Sequence[Expression] | None = None,
        return_type: Expression | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.body = Block(self)
        self.func_name = name
        self.decorators: list[Expression] = list(decorators) if decorators else []
        self.return_type: Expression | None = return_type
        self._args: list[FunctionArg] = []
        if args is not None:
            self.add_args(args)

    @property
    def args(self) -> Sequence[FunctionArg]:
        """Return the function's arguments as a read-only sequence."""
        return tuple(self._args)

    def add_args(self, args: Sequence[str | FunctionArg]) -> None:
        """Append *args* (names or :class:`FunctionArg` objects) to the signature.

        Raises ``ValueError`` if the combined signature is out of order (see
        ``_validate_arg_order``). Raises ``AssertionError`` if a new name is
        already in use in this scope or is repeated within *args*.

        All checks run before any name is reserved. A rejected call therefore
        leaves both the scope and the argument list unchanged.
        """
        normalized = _normalize_args(args)
        combined = self._args + normalized
        _validate_arg_order(combined)
        # Validate every new name first.  Reserving as we go would leave
        # earlier names reserved (but absent from ``_args``) when a later
        # name is rejected.
        pending: set[str] = set()
        for arg in normalized:
            if arg.name in pending or self.is_name_in_use(arg.name):
                raise AssertionError(f"Can't use '{arg.name}' as function argument name because it shadows other names")
            pending.add(arg.name)
        for arg in normalized:
            self.reserve_name(arg.name, function_arg=True)
        self._args = combined

    def as_ast(self, *, include_comments: bool = False) -> py_ast.stmt:
        """Return a ``py_ast.FunctionDef`` for this function.

        Raises ``AssertionError`` if the function name is not an allowable
        identifier. The body is rendered with ``allow_empty=False``.
        """
        if not allowable_name(self.func_name):
            raise AssertionError(f"Expected '{self.func_name}' to be a valid Python identifier")
        arguments = _function_arg_list_to_ast_arguments(self._args)
        return py_ast.FunctionDef(
            name=self.func_name,
            args=arguments,
            body=self.body.as_ast_list(allow_empty=False, include_comments=include_comments),
            decorator_list=[d.as_ast() for d in self.decorators],
            type_params=[],
            returns=self.return_type.as_ast() if self.return_type is not None else None,
            **DEFAULT_AST_ARGS,
        )

    def create_return(self, value: ExpressionLike):
        """Add a ``return`` statement to the function body."""
        self.body.create_return(E_to_Expression(value))
def _function_arg_list_to_ast_arguments(args: Sequence[FunctionArg]):
    """Build the ``py_ast.arguments`` node for a signature made of *args*.

    Raises ``AssertionError`` if any argument name is not an allowable
    identifier. ``*args``/``**kwargs`` are never emitted: ``vararg`` and
    ``kwarg`` are always ``None``.
    """
    for arg in args:
        if not allowable_name(arg.name):
            raise AssertionError(f"Expected '{arg.name}' to be a valid Python identifier")

    def to_ast_arg(spec: FunctionArg) -> py_ast.arg:
        annotation_ast = None if spec.annotation is None else spec.annotation.as_ast()
        return py_ast.arg(arg=spec.name, annotation=annotation_ast, **DEFAULT_AST_ARGS)

    def of_kind(*kinds: ArgKind) -> list[FunctionArg]:
        return [a for a in args if a.kind in kinds]

    keyword_only_specs = of_kind(ArgKind.KEYWORD_ONLY)

    posonly_nodes = [to_ast_arg(a) for a in of_kind(ArgKind.POSITIONAL_ONLY)]
    regular_nodes = [to_ast_arg(a) for a in of_kind(ArgKind.POSITIONAL_OR_KEYWORD)]
    kwonly_nodes = [to_ast_arg(a) for a in keyword_only_specs]

    # Python keeps one defaults list for all positional parameters,
    # right-aligned against posonlyargs + args.
    positional_defaults = [
        a.default.as_ast()
        for a in of_kind(ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)
        if a.default is not None
    ]
    # Keyword-only defaults are aligned one-to-one, with None meaning "no default".
    keyword_defaults: list[py_ast.expr | None] = [
        None if a.default is None else a.default.as_ast() for a in keyword_only_specs
    ]

    return py_ast.arguments(
        posonlyargs=posonly_nodes,
        args=regular_nodes,
        vararg=None,
        kwonlyargs=kwonly_nodes,
        kw_defaults=keyword_defaults,
        kwarg=None,
        defaults=positional_defaults,
        **DEFAULT_AST_ARGS_ARGUMENTS,
    )
class Class(Scope, Statement):
    """A ``class`` definition statement; also the :class:`Scope` of its own body."""

    def __init__(
        self,
        name: str,
        parent_scope: Scope | None = None,
        bases: Sequence[Expression] | None = None,
        decorators: Sequence[Expression] | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.body = Block(self)
        self.class_name = name
        self.bases: list[Expression] = [] if not bases else list(bases)
        self.decorators: list[Expression] = [] if not decorators else list(decorators)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.stmt:
        """Return a ``py_ast.ClassDef``.

        Raises ``AssertionError`` if the class name is not an allowable
        identifier. Keyword arguments such as ``metaclass=`` are never
        emitted.
        """
        if not allowable_name(self.class_name):
            raise AssertionError(f"Expected '{self.class_name}' to be a valid Python identifier")
        base_nodes = [base.as_ast() for base in self.bases]
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        decorator_nodes = [decorator.as_ast() for decorator in self.decorators]
        return py_ast.ClassDef(
            name=self.class_name,
            bases=base_nodes,
            keywords=[],
            body=body_nodes,
            decorator_list=decorator_nodes,
            type_params=[],
            **DEFAULT_AST_ARGS,
        )
class Return(Statement):
    """A ``return <value>`` statement."""

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self, *, include_comments: bool = False):
        """Return a ``py_ast.Return`` node; *include_comments* is unused."""
        return py_ast.Return(self.value.as_ast(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        # Balanced parentheses, matching the other statement reprs (e.g. Assert, Raise).
        return f"Return({self.value!r})"
class Break(Statement):
    """A ``break`` statement."""

    def as_ast(self, *, include_comments: bool = False):
        """Return a ``py_ast.Break`` node; *include_comments* is ignored."""
        node = py_ast.Break(**DEFAULT_AST_ARGS)
        return node
class Continue(Statement):
    """A ``continue`` statement."""

    def as_ast(self, *, include_comments: bool = False):
        """Return a ``py_ast.Continue`` node; *include_comments* is ignored."""
        node = py_ast.Continue(**DEFAULT_AST_ARGS)
        return node
[docs] class Assert(Statement): """An ``assert`` statement with an optional message.""" def __init__(self, test: Expression, msg: Expression | None = None): self.test = test self.msg = msg
[docs] def as_ast(self, *, include_comments: bool = False): msg_ast = self.msg.as_ast() if self.msg is not None else None return py_ast.Assert( test=self.test.as_ast(), msg=msg_ast, **DEFAULT_AST_ARGS, )
def __repr__(self): return f"Assert({repr(self.test)}, {repr(self.msg)})"
class Raise(Statement):
    """A ``raise`` statement.

    Supports:

    * ``raise exc``
    * ``raise exc from cause``
    * bare ``raise`` (re-raise the current exception)

    Passing *cause* without *exc* raises ``AssertionError``, because
    ``raise from cause`` is not valid Python.
    """

    def __init__(self, exc: Expression | None = None, cause: Expression | None = None):
        if exc is None and cause is not None:
            raise AssertionError("Cannot use 'cause' without 'exc'")
        self.exc = exc
        self.cause = cause

    def as_ast(self, *, include_comments: bool = False):
        """Return a ``py_ast.Raise`` node; absent parts are ``None``."""
        exc_node = None if self.exc is None else self.exc.as_ast()
        cause_node = None if self.cause is None else self.cause.as_ast()
        return py_ast.Raise(exc=exc_node, cause=cause_node, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Raise({self.exc!r}, {self.cause!r})"
class If(Statement):
    """A compound ``if``/``elif``/``else`` statement.

    The first call to :meth:`create_if_branch` creates the ``if`` branch and
    later calls create ``elif`` branches. Statements added to
    :attr:`else_block` form the ``else`` branch.
    """

    def __init__(self, parent_scope: Scope, parent_block: Block | None = None):
        # Modelled as parallel lists of branch blocks and their conditions,
        # plus one trailing else block.  Python's AST instead nests each elif
        # as an ``If`` inside the previous node's ``orelse``, so as_ast() has
        # to convert between the two shapes.
        self.if_blocks: list[Block] = []
        self.conditions: list[Expression] = []
        self._parent_block = parent_block
        self.else_block = Block(parent_scope, parent_block=self._parent_block)
        self._parent_scope = parent_scope

    def create_if_branch(self, condition: ExpressionLike) -> Block:
        """Add a branch guarded by *condition* and return its new, empty body block."""
        branch = Block(self._parent_scope, parent_block=self._parent_block)
        self.if_blocks.append(branch)
        self.conditions.append(E_to_Expression(condition))
        return branch

    def finalize(self) -> Block | Statement:
        """Return a simplified node: the else block if there are no conditions, otherwise *self*."""
        # With no conditions only the default case remains.  Handing back the
        # else block lets callers treat that case uniformly as an
        # unconditional block.
        return self if self.if_blocks else self.else_block

    def as_ast(self, *, include_comments: bool = False) -> py_ast.If:
        """Return the nested ``py_ast.If`` chain.

        Raises ``AssertionError`` if there are no branches (call
        :meth:`finalize` first in that case).
        """
        if len(self.if_blocks) == 0:
            raise AssertionError("Should have called `finalize` on If")
        root = empty_If()
        pending = root
        tail = None
        for test_expr, branch in zip(self.conditions, self.if_blocks):
            pending.test = test_expr.as_ast()
            pending.body = branch.as_ast_list(include_comments=include_comments)
            if tail is not None:
                # Each further branch becomes an ``elif``: an If nested in the
                # previous node's orelse.
                tail.orelse.append(pending)
            tail = pending
            pending = empty_If()
        if self.else_block.statements:
            assert tail is not None
            tail.orelse = self.else_block.as_ast_list(include_comments=include_comments)
        return root
class With(Statement):
    """A single-item ``with <context_expr> [as <target>]:`` statement."""

    def __init__(
        self,
        context_expr: Expression,
        target: Name | None = None,
        *,
        parent_scope: Scope,
        parent_block: Block | None = None,
    ):
        self.context_expr = context_expr
        self.target = target
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.body = Block(parent_scope, parent_block=parent_block)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.With:
        """Return a ``py_ast.With`` node whose body must be non-empty."""
        # The ``as`` target is assigned to, hence the Store context.
        target_node = (
            None
            if self.target is None
            else py_ast.Name(id=self.target.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        )
        item = py_ast.withitem(context_expr=self.context_expr.as_ast(), optional_vars=target_node)
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        return py_ast.With(items=[item], body=body_nodes, **DEFAULT_AST_ARGS)
class For(Statement):
    """A ``for <target> in <iterable>:`` loop, with an optional ``else`` clause.

    Raises ``AssertionError`` at construction if *target* is not accepted by
    ``is_target``.
    """

    def __init__(
        self,
        target: Target,
        iterable: Expression,
        *,
        parent_scope: Scope,
        parent_block: Block | None = None,
    ):
        if not is_target(target):
            raise AssertionError("Invalid for-loop target")
        self.target: Target = target
        self.iterable = iterable
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.body = Block(parent_scope, parent_block=parent_block)
        self.else_block = Block(parent_scope, parent_block=parent_block)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.For:
        """Return a ``py_ast.For``; the body must be non-empty, the else block may be empty."""
        target_node = _target_as_store_ast(self.target)
        iter_node = self.iterable.as_ast()
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        else_nodes = self.else_block.as_ast_list(allow_empty=True, include_comments=include_comments)
        return py_ast.For(target=target_node, iter=iter_node, body=body_nodes, orelse=else_nodes, **DEFAULT_AST_ARGS)
class Try(Statement):
    """A ``try``/``except``/``else``/``finally`` statement.

    Except clauses are added incrementally via :meth:`create_except`, similar
    to how :meth:`If.create_if_branch` works. ``except_blocks``,
    ``except_types`` and ``except_names`` are parallel lists, with one entry
    per clause.
    """

    def __init__(
        self,
        parent_scope: Scope,
        *,
        parent_block: Block | None = None,
    ):
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.try_block = Block(parent_scope, parent_block=parent_block)
        self.except_blocks: list[Block] = []
        self.except_types: list[list[Expression]] = []
        self.except_names: list[str | None] = []
        self.else_block = Block(parent_scope, parent_block=parent_block)
        self.finally_block = Block(parent_scope, parent_block=parent_block)

    @overload
    def create_except(
        self,
        catch_exceptions: Sequence[ExpressionLike],
        *,
        name: str | Name,
    ) -> tuple[Block, Name]: ...

    @overload
    def create_except(self, catch_exceptions: Sequence[ExpressionLike]) -> Block: ...

    def create_except(
        self,
        catch_exceptions: Sequence[ExpressionLike],
        *,
        name: str | Name | None = None,
    ) -> Block | tuple[Block, Name]:
        """
        Add an ``except`` clause and return its body block.

        *catch_exceptions* is the list of exception types to catch. A
        single-element list produces ``except Foo:``; multiple elements
        produce ``except (Foo, Bar):``.

        *name*, if given, becomes the ``as`` target (``except Foo as name:``)
        and ``(block, Name)`` is returned instead of just the block. A `str`
        is passed to the scope's ``create_name``. The clause then uses the
        name of the returned Name object, so the handler binds exactly the
        variable the caller will refer to.
        """
        block = Block(self._parent_scope, parent_block=self._parent_block)
        self.except_blocks.append(block)
        self.except_types.append([E_to_Expression(e) for e in catch_exceptions])
        # Normalise `name` to `str | None` for appending to `except_names`
        # and `name_obj` to a `Name | None` for returning.
        if isinstance(name, str):
            name_obj = self._parent_scope.create_name(name)
            # create_name's result is not guaranteed to carry the requested
            # string, so record the name it actually produced.
            name = name_obj.name
        elif isinstance(name, Name):
            name_obj = name
            name = name_obj.name
        else:
            name_obj = None
        self.except_names.append(name)
        if name_obj is not None:
            return block, name_obj
        return block

    def _handler_type_ast(self, exceptions: list[Expression]) -> py_ast.expr:
        """Return the single exception type, or a tuple of them for several."""
        if len(exceptions) == 1:
            return exceptions[0].as_ast()
        return py_ast.Tuple(
            elts=[e.as_ast() for e in exceptions],
            ctx=py_ast.Load(),
            **DEFAULT_AST_ARGS,
        )

    def as_ast(self, *, include_comments: bool = False) -> py_ast.Try:
        """Return a ``py_ast.Try`` node.

        Raises ``AssertionError`` in two cases, both of which are invalid
        Python:

        * there is neither an except clause nor a non-empty finally block;
        * the else block is non-empty but there is no except clause.
        """
        if not self.except_blocks and not self.finally_block.statements:
            raise AssertionError("Try needs at least one except clause or a non-empty finally block")
        if not self.except_blocks and self.else_block.statements:
            raise AssertionError("Try with an else block needs at least one except clause")
        return py_ast.Try(
            body=self.try_block.as_ast_list(allow_empty=False, include_comments=include_comments),
            handlers=[
                py_ast.ExceptHandler(
                    type=self._handler_type_ast(exc_types),
                    name=exc_name,
                    body=exc_block.as_ast_list(allow_empty=False, include_comments=include_comments),
                    **DEFAULT_AST_ARGS,
                )
                for exc_types, exc_name, exc_block in zip(self.except_types, self.except_names, self.except_blocks)
            ],
            orelse=self.else_block.as_ast_list(allow_empty=True, include_comments=include_comments),
            finalbody=self.finally_block.as_ast_list(allow_empty=True, include_comments=include_comments),
            **DEFAULT_AST_ARGS,
        )
class Import(Statement):
    """
    A single ``import`` statement, in one of these forms:

    - import foo
    - import foo as bar
    - import foo.bar
    - import foo.bar as baz

    Create these through `Block.create_import`. Multiple names per statement
    are deliberately unsupported; a linter can merge imports in the generated
    code afterwards if desired.
    """

    def __init__(self, module: str, as_: Name | None) -> None:
        self.module = module
        self.as_ = as_

    def as_ast(self, *, include_comments: bool = False) -> py_ast.AST:
        """Return a ``py_ast.Import`` with one alias, renamed when ``as_`` is set."""
        if self.as_ is not None:
            imported = py_ast.alias(name=self.module, asname=self.as_.name)
        else:
            # No alias needed
            imported = py_ast.alias(name=self.module)
        return py_ast.Import(names=[imported], **DEFAULT_AST_ARGS)
class ImportFrom(Statement):
    """
    A single absolute ``from ... import`` statement, in one of these forms:

    - ``from foo import bar``
    - ``from foo import bar as baz``

    Create these through `Block.create_import_from`. Multiple names per
    statement are deliberately unsupported; a linter can merge imports in
    the generated code afterwards if desired.
    """

    def __init__(self, from_module: str, import_: str, as_: Name | None) -> None:
        self.from_module = from_module
        self.import_ = import_
        self.as_ = as_

    def as_ast(self, *, include_comments: bool = False) -> py_ast.AST:
        """Return a ``py_ast.ImportFrom`` (``level=0``, i.e. absolute) with one alias."""
        if self.as_ is not None:
            imported = py_ast.alias(name=self.import_, asname=self.as_.name)
        else:
            # No alias needed
            imported = py_ast.alias(name=self.import_)
        return py_ast.ImportFrom(
            module=self.from_module,
            names=[imported],
            level=0,
            **DEFAULT_AST_ARGS,
        )
class Expression(CodeGenAst):
    """Base class for code-generation nodes that represent Python expressions.

    Besides :meth:`as_ast`, it offers chaining helpers such as
    ``x.attr("y").call([...])`` or ``x.add(1)``. Each helper wraps *self* in
    the matching node class. Operands typed ``ExpressionLike`` are first
    converted with ``E_to_Expression``.
    """

    @abstractmethod
    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr: ...

    # Conversion:

    @classmethod
    def from_e(cls, obj: Expression | E) -> Expression:
        """Convert *obj* (an Expression or an E-object) to an Expression."""
        return E_to_Expression(obj)

    # Some utilities for easy chaining:

    def attr(self, attribute: str, /) -> Attr:
        """Return an :class:`Attr` expression for ``self.<attribute>``."""
        return Attr(self, attribute)

    def call(
        self,
        args: Sequence[ExpressionLike] | None = None,
        kwargs: Mapping[str, ExpressionLike] | None = None,
    ) -> Call:
        """Return a :class:`Call` expression invoking this expression (``None`` means no arguments)."""
        positional = [] if args is None else [E_to_Expression(arg) for arg in args]
        keywords = {} if kwargs is None else {key: E_to_Expression(val) for key, val in kwargs.items()}
        return Call(self, positional, keywords)

    def method_call(
        self,
        attribute: str,
        args: Sequence[ExpressionLike] | None = None,
        kwargs: Mapping[str, ExpressionLike] | None = None,
    ) -> Call:
        """Return a :class:`Call` expression for ``self.<attribute>(...)``."""
        method = self.attr(attribute)
        return method.call(args, kwargs)

    def subscript(self, slice: ExpressionLike, /) -> Subscript:
        """Return a :class:`Subscript` expression ``self[slice]``."""
        key = E_to_Expression(slice)
        return Subscript(self, key)

    # Arithmetic operators

    def add(self, other: ExpressionLike, /) -> Add:
        """Return an :class:`Add` (``self + other``) expression."""
        rhs = E_to_Expression(other)
        return Add(self, rhs)

    def sub(self, other: ExpressionLike, /) -> Sub:
        """Return a :class:`Sub` (``self - other``) expression."""
        rhs = E_to_Expression(other)
        return Sub(self, rhs)

    def mul(self, other: ExpressionLike, /) -> Mul:
        """Return a :class:`Mul` (``self * other``) expression."""
        rhs = E_to_Expression(other)
        return Mul(self, rhs)

    def div(self, other: ExpressionLike, /) -> Div:
        """Return a :class:`Div` (``self / other``) expression."""
        rhs = E_to_Expression(other)
        return Div(self, rhs)

    def floordiv(self, other: ExpressionLike, /) -> FloorDiv:
        """Return a :class:`FloorDiv` (``self // other``) expression."""
        rhs = E_to_Expression(other)
        return FloorDiv(self, rhs)

    def mod(self, other: ExpressionLike, /) -> Mod:
        """Return a :class:`Mod` (``self % other``) expression."""
        rhs = E_to_Expression(other)
        return Mod(self, rhs)

    def pow(self, other: ExpressionLike, /) -> Pow:
        """Return a :class:`Pow` (``self ** other``) expression."""
        rhs = E_to_Expression(other)
        return Pow(self, rhs)

    def matmul(self, other: ExpressionLike, /) -> MatMul:
        """Return a :class:`MatMul` (``self @ other``) expression."""
        rhs = E_to_Expression(other)
        return MatMul(self, rhs)

    # Bitwise operators

    def bitand(self, other: ExpressionLike, /) -> BitAnd:
        """Return a :class:`BitAnd` (``self & other``) expression."""
        rhs = E_to_Expression(other)
        return BitAnd(self, rhs)

    def bitor(self, other: ExpressionLike, /) -> BitOr:
        """Return a :class:`BitOr` (``self | other``) expression."""
        rhs = E_to_Expression(other)
        return BitOr(self, rhs)

    def xor(self, other: ExpressionLike, /) -> BitXor:
        """Return a :class:`BitXor` (``self ^ other``) expression."""
        rhs = E_to_Expression(other)
        return BitXor(self, rhs)

    def lshift(self, other: ExpressionLike, /) -> LShift:
        """Return a :class:`LShift` (``self << other``) expression."""
        rhs = E_to_Expression(other)
        return LShift(self, rhs)

    def rshift(self, other: ExpressionLike, /) -> RShift:
        """Return a :class:`RShift` (``self >> other``) expression."""
        rhs = E_to_Expression(other)
        return RShift(self, rhs)

    def invert(self) -> Invert:
        """Return an :class:`Invert` (``~self``) expression."""
        return Invert(self)

    # Comparison operators

    def eq(self, other: ExpressionLike, /) -> Equals:
        """Return an :class:`Equals` (``self == other``) expression."""
        rhs = E_to_Expression(other)
        return Equals(self, rhs)

    def ne(self, other: ExpressionLike, /) -> NotEquals:
        """Return a :class:`NotEquals` (``self != other``) expression."""
        rhs = E_to_Expression(other)
        return NotEquals(self, rhs)

    def lt(self, other: ExpressionLike, /) -> Lt:
        """Return a :class:`Lt` (``self < other``) expression."""
        rhs = E_to_Expression(other)
        return Lt(self, rhs)

    def gt(self, other: ExpressionLike, /) -> Gt:
        """Return a :class:`Gt` (``self > other``) expression."""
        rhs = E_to_Expression(other)
        return Gt(self, rhs)

    def le(self, other: ExpressionLike, /) -> LtE:
        """Return a :class:`LtE` (``self <= other``) expression."""
        rhs = E_to_Expression(other)
        return LtE(self, rhs)

    def ge(self, other: ExpressionLike, /) -> GtE:
        """Return a :class:`GtE` (``self >= other``) expression."""
        rhs = E_to_Expression(other)
        return GtE(self, rhs)

    # Boolean operators

    def and_(self, other: ExpressionLike, /) -> And:
        """Return an :class:`And` (``self and other``) expression."""
        rhs = E_to_Expression(other)
        return And(self, rhs)

    def or_(self, other: ExpressionLike, /) -> Or:
        """Return an :class:`Or` (``self or other``) expression."""
        rhs = E_to_Expression(other)
        return Or(self, rhs)

    # Membership operators

    def in_(self, other: ExpressionLike, /) -> In:
        """Return an :class:`In` (``self in other``) expression."""
        rhs = E_to_Expression(other)
        return In(self, rhs)

    def not_in(self, other: ExpressionLike, /) -> NotIn:
        """Return a :class:`NotIn` (``self not in other``) expression."""
        rhs = E_to_Expression(other)
        return NotIn(self, rhs)

    # Identity operators

    def is_(self, other: ExpressionLike, /) -> Is:
        """Return an :class:`Is` (``self is other``) expression."""
        rhs = E_to_Expression(other)
        return Is(self, rhs)

    def is_not(self, other: ExpressionLike, /) -> IsNot:
        """Return an :class:`IsNot` (``self is not other``) expression."""
        rhs = E_to_Expression(other)
        return IsNot(self, rhs)

    # Unary operators

    def not_(self) -> Not:
        """Return a :class:`Not` (``not self``) expression."""
        return Not(self)

    def pos(self) -> UAdd:
        """Return a :class:`UAdd` (``+self``) expression."""
        return UAdd(self)

    def neg(self) -> USub:
        """Return a :class:`USub` (``-self``) expression."""
        return USub(self)

    # Unpacking

    def starred(self) -> Starred:
        """Return a :class:`Starred` (``*self``) unpacking expression."""
        return Starred(self)

    # Walrus / NamedExpr

    def named(self, name: Name) -> NamedExpr:
        """Return ``(name := self)``, i.e. this expression bound with the walrus operator."""
        return NamedExpr(name=name, value=self)

    # E-object:

    @cached_property
    def e(self) -> E:
        """This expression wrapped as an E-object (created once, then cached)."""
        return E(self)
class String(Expression):
    """A ``str`` literal; compares equal to any String with the same value."""

    def __init__(self, string_value: str):
        self.string_value = string_value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        return py_ast.Constant(
            self.string_value,
            kind=None,  # 3.8, indicates no prefix, needed only for tests
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        return f"String({self.string_value!r})"

    def __eq__(self, other: object):
        if not isinstance(other, String):
            return False
        return other.string_value == self.string_value
class Bool(Expression):
    """A ``True``/``False`` literal."""

    def __init__(self, value: bool):
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.Constant`` holding the boolean."""
        literal = self.value
        return py_ast.Constant(literal, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return "Bool(" + repr(self.value) + ")"
class Bytes(Expression):
    """A ``bytes`` literal."""

    def __init__(self, value: bytes):
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.Constant`` holding the bytes value."""
        literal = self.value
        return py_ast.Constant(literal, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return "Bytes(" + repr(self.value) + ")"
class Number(Expression):
    """An ``int`` or ``float`` literal."""

    def __init__(self, number: int | float):
        self.number = number

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.Constant`` holding the number."""
        literal = self.number
        return py_ast.Constant(literal, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Number({self.number!r})"
class List(Expression):
    """A ``[a, b, ...]`` list display built from the given item expressions."""

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        elements = [item.as_ast() for item in self.items]
        return py_ast.List(elts=elements, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
class Tuple(Expression):
    """A ``(a, b, ...)`` tuple display built from the given item expressions."""

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        elements = [item.as_ast() for item in self.items]
        return py_ast.Tuple(elts=elements, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
class Set(Expression):
    """A set display; with no items it is emitted as ``set([])``."""

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        if self.items:
            return py_ast.Set(elts=[item.as_ast() for item in self.items], **DEFAULT_AST_ARGS)
        # `{}` would be an empty dict literal, so an empty set is spelled set([]).
        empty_list = py_ast.List(elts=[], ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        return py_ast.Call(
            func=py_ast.Name(id="set", ctx=py_ast.Load(), **DEFAULT_AST_ARGS),
            args=[empty_list],
            keywords=[],
            **DEFAULT_AST_ARGS,
        )
class Dict(Expression):
    """A ``{k: v, ...}`` dict display built from (key, value) expression pairs."""

    def __init__(self, pairs: Sequence[tuple[Expression, Expression]]):
        self.pairs = pairs

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        key_nodes = [key.as_ast() for key, _ in self.pairs]
        value_nodes = [value.as_ast() for _, value in self.pairs]
        return py_ast.Dict(keys=key_nodes, values=value_nodes, **DEFAULT_AST_ARGS)
class StringJoinBase(Expression):
    """Base class for expressions that concatenate string-valued *parts*.

    Subclasses decide the generated form. Prefer :meth:`build` over the
    constructor, because it merges adjacent literals and removes trivial
    joins.
    """

    def __init__(self, parts: Sequence[Expression]):
        self.parts = parts

    def __repr__(self):
        inner = ", ".join(repr(part) for part in self.parts)
        return f"{self.__class__.__name__}([{inner}])"

    @classmethod
    def build(cls: type[StringJoinBase], parts: Sequence[Expression]) -> StringJoinBase | Expression:
        """
        Build a join of *parts*, returning a simpler expression where possible.

        Adjacent :class:`String` literals are merged. If no parts remain the
        result is ``String("")``, and a single remaining part is returned
        unwrapped.
        """
        merged: list[Expression] = []
        for part in parts:
            previous = merged[-1] if merged else None
            if isinstance(previous, String) and isinstance(part, String):
                merged[-1] = String(previous.string_value + part.string_value)
            else:
                merged.append(part)

        if not merged:
            return String("")
        if len(merged) == 1:
            return merged[0]
        return cls(merged)
class FStringJoin(StringJoinBase):
    """Join string parts as an f-string: literals inline, other parts as ``{...}``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        def to_fstring_value(part: Expression) -> py_ast.expr:
            # String literals go in verbatim; anything else becomes a
            # replacement field with no conversion (-1) and no format spec.
            if isinstance(part, String):
                return part.as_ast()
            return py_ast.FormattedValue(
                value=part.as_ast(),
                conversion=-1,
                format_spec=None,
                **DEFAULT_AST_ARGS,
            )

        values: list[py_ast.expr] = [to_fstring_value(part) for part in self.parts]
        return py_ast.JoinedStr(values=values, **DEFAULT_AST_ARGS)
class ConcatJoin(StringJoinBase):
    """Join string parts with left-associative ``+``: ``((a + b) + c) ...``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        accumulated = self.parts[0].as_ast()
        for next_part in self.parts[1:]:
            right_node = next_part.as_ast()
            accumulated = py_ast.BinOp(
                left=accumulated,
                op=py_ast.Add(**DEFAULT_AST_ARGS_ADD),
                right=right_node,
                **DEFAULT_AST_ARGS,
            )
        return accumulated
# For CPython, f-strings give a measurable improvement over concatenation, # so make that default StringJoin = FStringJoin
class Name(Expression):
    """A reference to a variable that is already reserved in its :class:`Scope`.

    Raises ``AssertionError`` at construction if *name* is not in use in
    *scope*. Two Names compare equal when they have the same type and name;
    the scope is not stored.
    """

    def __init__(self, name: str, scope: Scope):
        if not scope.is_name_in_use(name):
            raise AssertionError(f"Cannot refer to undefined name '{name}'")
        self.name = name

    def as_ast(self, *, include_comments: bool = False) -> py_ast.Name:
        """Return a Load-context ``py_ast.Name``; builtin names are allowed."""
        if not allowable_name(self.name, allow_builtin=True):
            raise AssertionError(f"Expected {self.name} to be a valid Python identifier")
        return py_ast.Name(id=self.name, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)

    def __eq__(self, other: object):
        if type(other) is not type(self):
            return False
        return other.name == self.name

    def __repr__(self):
        return f"Name({self.name!r})"
class Attr(Expression):
    """An attribute access expression (e.g. ``obj.attr``).

    Raises ``AssertionError`` at construction if *attribute* is not an
    allowable identifier (builtin names are permitted).
    """

    def __init__(self, value: Expression, attribute: str) -> None:
        self.value = value
        if not allowable_name(attribute, allow_builtin=True):
            raise AssertionError(f"Expected {attribute} to be a valid Python identifier")
        self.attribute = attribute

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a Load-context ``py_ast.Attribute`` node."""
        # ``ctx`` is a required field of ast.Attribute; only Python 3.13+ fills
        # in Load() when it is omitted.  Passing it explicitly keeps compile()
        # working on older versions and matches Name/Subscript/Starred here.
        return py_ast.Attribute(
            value=self.value.as_ast(),
            attr=self.attribute,
            ctx=py_ast.Load(),
            **DEFAULT_AST_ARGS,
        )
class Starred(Expression):
    """A starred (unpacking) expression such as ``*args``."""

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a Load-context ``py_ast.Starred`` node."""
        inner = self.value.as_ast()
        return py_ast.Starred(value=inner, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        return "Starred(" + repr(self.value) + ")"
class NamedExpr(Expression):
    """
    A named expression (walrus operator), e.g. ``(x := 1)``.

    *name* is the Name being bound and *value* the expression assigned to it.
    """

    def __init__(self, name: Name, value: Expression) -> None:
        self.name = name
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.NamedExpr`` whose target has a Store context.

        Raises ``AssertionError`` if the target name is not an allowable
        identifier (the same check ``Name.as_ast`` performs).
        """
        target_id = self.name.name
        if not allowable_name(target_id, allow_builtin=True):
            raise AssertionError(f"Expected {target_id} to be a valid Python identifier")
        # The walrus target is assigned to, so it needs ctx=Store (as With does
        # for its ``as`` target).  Name.as_ast() yields a Load context, which
        # makes the compiler treat the target as a read instead of a binding.
        target = py_ast.Name(id=target_id, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        return py_ast.NamedExpr(target=target, value=self.value.as_ast(), **DEFAULT_AST_ARGS)
def function_call(
    function_name: str,
    args: Sequence[Expression],
    kwargs: Mapping[str, Expression],
    scope: Scope,
) -> Expression:
    """Return a call expression ``function_name(*args, **kwargs)``.

    Raises ``AssertionError`` if *function_name* is not in use in *scope*, or
    if it is listed in ``SENSITIVE_FUNCTIONS``.
    """
    if not scope.is_name_in_use(function_name):
        raise AssertionError(f"Cannot call unknown function '{function_name}'")
    if function_name in SENSITIVE_FUNCTIONS:
        raise AssertionError(f"Disallowing call to '{function_name}'")
    callee = Name(name=function_name, scope=scope)
    return callee.call(args, kwargs)
class Call(Expression):
    """A function or method call expression: ``value(*args, **kwargs)``."""

    def __init__(
        self,
        value: Expression,
        args: Sequence[Expression],
        kwargs: Mapping[str, Expression],
    ):
        self.value = value
        self.args = list(args)
        self.kwargs = dict(kwargs)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.Call`` node.

        Raises ``AssertionError`` if a keyword name fails
        ``allowable_keyword_arg_name``.
        """
        for kw_name in self.kwargs:
            if not allowable_keyword_arg_name(kw_name):
                raise AssertionError(f"Expected {kw_name} to be a valid Fluent NamedArgument name")

        # Names like 'foo-bar' are allowed in languages such as Fluent but are
        # not Python identifiers. They are passed as
        # `my_function(**{'foo-bar': baz})` instead. Exec-ing the AST directly
        # would accept `foo-bar=` (only the parser rejects it), but the output
        # must also unparse to valid Python, so the general case is handled.
        needs_splat = any(not allowable_name(kw_name) for kw_name in self.kwargs)
        func_node = self.value.as_ast()
        positional_nodes = [arg.as_ast() for arg in self.args]

        if needs_splat:
            ordered = sorted(self.kwargs.items())
            splat_dict = py_ast.Dict(
                keys=[py_ast.Constant(key, kind=None, **DEFAULT_AST_ARGS) for key, _ in ordered],
                values=[val.as_ast() for _, val in ordered],
                **DEFAULT_AST_ARGS,
            )
            keyword_nodes = [py_ast.keyword(arg=None, value=splat_dict, **DEFAULT_AST_ARGS)]
        else:
            # Normal `my_function(foo=bar)` syntax
            keyword_nodes = [
                py_ast.keyword(arg=key, value=val.as_ast(), **DEFAULT_AST_ARGS) for key, val in self.kwargs.items()
            ]

        return py_ast.Call(
            func=func_node,
            args=positional_nodes,
            keywords=keyword_nodes,
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        return f"Call({self.value!r}, {self.args}, {self.kwargs})"
def method_call(
    obj: Expression,
    method_name: str,
    args: Sequence[Expression],
    kwargs: Mapping[str, Expression],
):
    """Return the call expression ``obj.method_name(*args, **kwargs)``."""
    bound_method = obj.attr(method_name)
    return bound_method.call(args=args, kwargs=kwargs)
class Subscript(Expression):
    """An indexing expression ``value[slice]``; *slice* may be a :class:`Slice`."""

    def __init__(self, value: Expression, slice: Expression):
        self.value = value
        self.slice = slice

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a Load-context ``py_ast.Subscript`` node."""
        container = self.value.as_ast()
        index = py_ast.subscript_slice_object(self.slice.as_ast())
        return py_ast.Subscript(value=container, slice=index, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
class Slice(Expression):
    """A slice expression such as ``0:10``, ``::2`` or ``1:-1``.

    Intended as the *slice* argument of :class:`Subscript`. Each of *start*,
    *stop* and *step* is optional, and ``None`` omits that component.
    """

    def __init__(
        self,
        start: Expression | None = None,
        stop: Expression | None = None,
        step: Expression | None = None,
    ):
        self.start = start
        self.stop = stop
        self.step = step

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        def optional_ast(part: Expression | None):
            return None if part is None else part.as_ast()

        return py_ast.Slice(
            lower=optional_ast(self.start),
            upper=optional_ast(self.stop),
            step=optional_ast(self.step),
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        components = (("start", self.start), ("stop", self.stop), ("step", self.step))
        shown = ", ".join(f"{label}={value!r}" for label, value in components if value is not None)
        return f"Slice({shown})"
class Lambda(Expression):
    """An anonymous function expression, e.g. ``lambda x: x + 1``.

    A lambda acts as a small scope with fixed arguments: each argument name is
    reserved in a fresh :class:`Scope`. ``enames`` gives E-object access to
    names in that scope.
    """

    def __init__(self, args: Sequence[str | FunctionArg], body: Expression) -> None:
        self.args = _normalize_args(args)
        self.body = body
        lambda_scope = Scope()
        for reserved in (arg.name for arg in self.args):
            lambda_scope.reserve_name(reserved)
        self.enames = Enames(scope=lambda_scope)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        arguments_ast = _function_arg_list_to_ast_arguments(self.args)
        return py_ast.Lambda(args=arguments_ast, body=self.body.as_ast(), **DEFAULT_AST_ARGS)
class Comprehension:
    """One ``for <target> in <iter>`` clause of a comprehension.

    A plain holder: *target* is the loop target and *iter* is the iterable
    expression.
    """

    def __init__(self, target: Target, iter: Expression):
        self.target, self.iter = target, iter
class _EltComprehensionBase(Expression):
    """Base class for comprehensions with a single element expression (list, set, generator).

    *element* is the produced expression. *generators* are the ``for ... in ...``
    clauses, in order. *ifs* are filter conditions that apply after all of the
    ``for`` clauses, as if written at the end of the comprehension.
    """

    def __init__(self, element: Expression, generators: Sequence[Comprehension], ifs: Sequence[Expression]) -> None:
        self.element = element
        self.generators = generators
        self.ifs = ifs

    def _make_generators(self) -> list[py_ast.comprehension]:
        """Convert the ``for`` clauses to ``ast.comprehension`` nodes."""
        # Filters are attached to the last ``for`` clause only, which is equivalent
        # to a trailing ``if``. Putting them on every clause would re-evaluate them
        # once per clause, and would let a condition reference a later clause's
        # target before it is bound (``for x in a if y ... for y in b``).
        last_index = len(self.generators) - 1
        return [
            py_ast.comprehension(
                target=_target_as_store_ast(comp.target),
                iter=comp.iter.as_ast(),
                ifs=[if_.as_ast() for if_ in self.ifs] if index == last_index else [],
                is_async=False,
            )
            for index, comp in enumerate(self.generators)
        ]
class ListComp(_EltComprehensionBase):
    """A list comprehension expression, e.g. ``[x + 1 for x in items]``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        # Pass the default location args like the other expression nodes in this
        # module, so this node carries the same attributes as its neighbours.
        return py_ast.ListComp(
            elt=self.element.as_ast(),
            generators=self._make_generators(),
            **DEFAULT_AST_ARGS,
        )
class SetComp(_EltComprehensionBase):
    """A set comprehension expression, e.g. ``{x + 1 for x in items}``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        # Pass the default location args like the other expression nodes in this
        # module, so this node carries the same attributes as its neighbours.
        return py_ast.SetComp(
            elt=self.element.as_ast(),
            generators=self._make_generators(),
            **DEFAULT_AST_ARGS,
        )
class GeneratorExpr(_EltComprehensionBase):
    """A generator expression, e.g. ``(x + 1 for x in items)``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        # Pass the default location args like the other expression nodes in this
        # module, so this node carries the same attributes as its neighbours.
        return py_ast.GeneratorExp(
            elt=self.element.as_ast(),
            generators=self._make_generators(),
            **DEFAULT_AST_ARGS,
        )
class DictComp(Expression):
    """A dict comprehension expression, e.g. ``{k: v for k, v in items}``.

    *key* and *value* are the produced entry. *generators* are the
    ``for ... in ...`` clauses, in order. *ifs* are filter conditions that apply
    after all of the ``for`` clauses.
    """

    def __init__(
        self, key: Expression, value: Expression, generators: Sequence[Comprehension], ifs: Sequence[Expression]
    ) -> None:
        self.key = key
        self.value = value
        self.generators = generators
        self.ifs = ifs

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        key_ast = self.key.as_ast()
        value_ast = self.value.as_ast()
        # Filters go on the last ``for`` clause only (equivalent to a trailing
        # ``if``). Repeating them on every clause re-evaluates them and can
        # reference a later clause's target before it is bound.
        last_index = len(self.generators) - 1
        generators = [
            py_ast.comprehension(
                target=_target_as_store_ast(comp.target),
                iter=comp.iter.as_ast(),
                ifs=[if_.as_ast() for if_ in self.ifs] if index == last_index else [],
                is_async=False,
            )
            for index, comp in enumerate(self.generators)
        ]
        return py_ast.DictComp(
            key=key_ast,
            value=value_ast,
            generators=generators,
            **DEFAULT_AST_ARGS,
        )
def list_comprehension(
    *, iterable: ExpressionLike, target: Target, element: ExpressionLike, condition: ExpressionLike | None = None
) -> ListComp:
    """Build a :class:`ListComp` equivalent to ``[element for target in iterable if condition]``.

    The ``if condition`` part is omitted when *condition* is ``None``.

    Example::

        data = auto([1, 2, 3])
        list_comprehension(
            iterable=data,
            target=(loop_var := mod.scope.create_name("item")),
            element=loop_var.e + 1,
            condition=loop_var.e > 0
        )

    generates ``[item + 1 for item in [1, 2, 3] if item > 0]``
    """
    source = E_to_Expression(iterable)
    body = E_to_Expression(element)
    filters = [] if condition is None else [E_to_Expression(condition)]
    return ListComp(body, [Comprehension(target, source)], ifs=filters)
def set_comprehension(
    *, iterable: ExpressionLike, target: Target, element: ExpressionLike, condition: ExpressionLike | None = None
) -> SetComp:
    """Build a :class:`SetComp` equivalent to ``{element for target in iterable if condition}``.

    The ``if condition`` part is omitted when *condition* is ``None``.

    Example::

        data = auto([1, 2, 3])
        set_comprehension(
            iterable=data,
            target=(loop_var := mod.scope.create_name("item")),
            element=loop_var.e + 1,
            condition=loop_var.e > 0
        )

    generates ``{item + 1 for item in [1, 2, 3] if item > 0}``
    """
    source = E_to_Expression(iterable)
    body = E_to_Expression(element)
    filters = [] if condition is None else [E_to_Expression(condition)]
    return SetComp(body, [Comprehension(target, source)], ifs=filters)
def generator_expression(
    *, iterable: ExpressionLike, target: Target, element: ExpressionLike, condition: ExpressionLike | None = None
) -> GeneratorExpr:
    """Create a :class:`GeneratorExpr` (generator expression).

    Builds ``(element for target in iterable if condition)``; the ``if`` part is
    omitted when *condition* is ``None``.

    e.g.::

        data = auto([1, 2, 3])
        my_func = mod.scope.create_name("my_func")
        my_func.e(
            generator_expression(
                iterable=data,
                target=(loop_var := mod.scope.create_name("item")),
                element=loop_var.e + 1,
                condition=loop_var.e > 0,
            )
        )

    This produces: ``my_func((item + 1 for item in [1, 2, 3] if item > 0))``
    """
    comprehension = Comprehension(target, E_to_Expression(iterable))
    ifs = [E_to_Expression(condition)] if condition is not None else []
    return GeneratorExpr(E_to_Expression(element), [comprehension], ifs=ifs)
def dict_comprehension(
    *,
    iterable: ExpressionLike,
    target: Target,
    key: ExpressionLike,
    value: ExpressionLike,
    condition: ExpressionLike | None = None,
) -> DictComp:
    """Build a :class:`DictComp` equivalent to ``{key: value for target in iterable if condition}``.

    The ``if condition`` part is omitted when *condition* is ``None``.

    Example::

        dict_comprehension(
            iterable=items,
            target=(
                (key_var := mod.scope.create_name("k")),
                (value_var := mod.scope.create_name("v")),
            ),
            key=key_var.e + "_x",
            value=value_var.e + 1,
        )

    generates ``{k + '_x': v + 1 for k, v in items}``
    """
    clause = Comprehension(target, E_to_Expression(iterable))
    key_expr = E_to_Expression(key)
    value_expr = E_to_Expression(value)
    filters = [] if condition is None else [E_to_Expression(condition)]
    return DictComp(key_expr, value_expr, [clause], ifs=filters)
def create_lambda(
    args: Sequence[str | FunctionArg], body: ExpressionLike | Callable[[Lambda], ExpressionLike]
) -> Lambda:
    """
    Create a lambda expression.

    The body can be supplied either as an expression (an :class:`Expression` or
    E-object), or as a callable that is called with a `Lambda` object as its
    only argument. The callable form gives access to the `enames` object on the
    `Lambda`::

        create_lambda(['x'], lambda self: self.enames.x + 1)

    Produces: ``lambda x: x + 1``
    """
    # E-objects define ``__call__`` (to build call expressions), so ``callable()``
    # alone cannot tell a body-building callback from an E-object body. Check for
    # expression-like values first; otherwise an E body would be *called* with
    # the temporary Lambda instead of being used as the body.
    if isinstance(body, (Expression, E)):
        body_expr = E_to_Expression(body)
    else:
        # Placeholder Lambda so the callback can reach ``enames`` for the args.
        temp_lambda = Lambda(args, body=constants.None_)
        body_expr = E_to_Expression(body(temp_lambda))
    return Lambda(
        args=args,
        body=body_expr,
    )
def named(name: Name, value: ExpressionLike) -> NamedExpr:
    """
    Wrap *value* (an Expression or E-object) in an assignment expression bound to *name*::

        x = mod.scope.create_name("x")
        value = codegen.auto(1).e + 1
        named_val = codegen.named(x, value)

    which generates::

        (x := 1 + 1)
    """
    value_expr = E_to_Expression(value)
    return NamedExpr(name=name, value=value_expr)
#: Type alias for valid assignment target expressions. #: A :class:`Name`, :class:`Attr`, or :class:`Subscript` expression, #: or a tuple of targets (for unpacking assignments). type Target = Name | Attr | Subscript | tuple[Target, ...] def is_target(value: object) -> TypeIs[Target]: """Return whether *value* is a valid assignment :data:`Target`.""" if isinstance(value, (Name, Attr, Subscript)): return True if isinstance(value, tuple): return all(is_target(element) for element in value) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] return False def _target_as_store_ast(target: Target) -> py_ast.expr: """Convert a Target expression to an AST node with Store context.""" if isinstance(target, tuple): return py_ast.Tuple( elts=[_target_as_store_ast(t) for t in target], ctx=py_ast.Store(), **DEFAULT_AST_ARGS, ) node = target.as_ast() if isinstance(node, (py_ast.Name, py_ast.Attribute, py_ast.Subscript)): node.ctx = py_ast.Store() return node raise AssertionError(f"Unexpected AST node type for target: {type(node)}") # pragma: no cover def _aug_target_as_store_ast(target: AugAssignTarget) -> py_ast.Name | py_ast.Attribute | py_ast.Subscript: node = target.as_ast() if isinstance(node, (py_ast.Name, py_ast.Attribute, py_ast.Subscript)): node.ctx = py_ast.Store() return node raise AssertionError(f"Unexpected AST node type for target: {type(node)}") # pragma: no cover def _target_names(target: Target) -> list[str]: """Collect all plain :class:`Name` identifiers from a *target*.""" if isinstance(target, Name): return [target.name] if isinstance(target, tuple): names: list[str] = [] for t in target: names.extend(_target_names(t)) return names return [] create_class_instance = function_call
class NoneExpr(Expression):
    """Expression node for the literal ``None``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        # Emitted as a constant node holding None.
        return py_ast.Constant(**DEFAULT_AST_ARGS, value=None)
class BinaryOperator(Expression):
    """Common base for expressions combining a *left* and a *right* operand."""

    def __init__(self, left: Expression, right: Expression):
        self.left, self.right = left, right
class ArithOp(BinaryOperator, ABC):
    """Binary operator emitted as an ``ast.BinOp`` node.

    Concrete subclasses select the operator by setting :attr:`op` to an AST
    operator class.
    """

    op: ClassVar[type[py_ast.operator]]

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        lhs = self.left.as_ast()
        operator_node = self.op(**DEFAULT_AST_ARGS_ADD)
        rhs = self.right.as_ast()
        return py_ast.BinOp(left=lhs, op=operator_node, right=rhs, **DEFAULT_AST_ARGS)
class Add(ArithOp):
    """Binary ``+`` (addition)."""

    op = py_ast.Add


class Sub(ArithOp):
    """Binary ``-`` (subtraction)."""

    op = py_ast.Sub


class Mul(ArithOp):
    """Binary ``*`` (multiplication)."""

    op = py_ast.Mult


class Div(ArithOp):
    """Binary ``/`` (true division)."""

    op = py_ast.Div


class FloorDiv(ArithOp):
    """Binary ``//`` (floor division)."""

    op = py_ast.FloorDiv


class Mod(ArithOp):
    """Binary ``%`` (modulo)."""

    op = py_ast.Mod


class Pow(ArithOp):
    """Binary ``**`` (exponentiation)."""

    op = py_ast.Pow


class MatMul(ArithOp):
    """Binary ``@`` (matrix multiplication)."""

    op = py_ast.MatMult


class BitAnd(ArithOp):
    """Binary ``&`` (bitwise AND)."""

    op = py_ast.BitAnd


class BitOr(ArithOp):
    """Binary ``|`` (bitwise OR)."""

    op = py_ast.BitOr


class BitXor(ArithOp):
    """Binary ``^`` (bitwise XOR)."""

    op = py_ast.BitXor


class LShift(ArithOp):
    """Binary ``<<`` (left shift)."""

    op = py_ast.LShift


class RShift(ArithOp):
    """Binary ``>>`` (right shift)."""

    op = py_ast.RShift
class CompareOp(BinaryOperator, ABC):
    """Single comparison emitted as an ``ast.Compare`` node with one operator.

    Concrete subclasses select the comparison by setting :attr:`op`.
    """

    op: ClassVar[type[py_ast.cmpop]]

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        left_ast = self.left.as_ast()
        right_ast = self.right.as_ast()
        return py_ast.Compare(
            left=left_ast,
            ops=[self.op()],
            comparators=[right_ast],
            **DEFAULT_AST_ARGS,
        )
class Equals(CompareOp):
    """Comparison ``left == right``."""

    op = py_ast.Eq


class NotEquals(CompareOp):
    """Comparison ``left != right``."""

    op = py_ast.NotEq


class Lt(CompareOp):
    """Comparison ``left < right``."""

    op = py_ast.Lt


class Gt(CompareOp):
    """Comparison ``left > right``."""

    op = py_ast.Gt


class LtE(CompareOp):
    """Comparison ``left <= right``."""

    op = py_ast.LtE


class GtE(CompareOp):
    """Comparison ``left >= right``."""

    op = py_ast.GtE


class In(CompareOp):
    """Membership test ``left in right``."""

    op = py_ast.In


class NotIn(CompareOp):
    """Negated membership test ``left not in right``."""

    op = py_ast.NotIn


class Is(CompareOp):
    """Identity test ``left is right``."""

    op = py_ast.Is


class IsNot(CompareOp):
    """Negated identity test ``left is not right``."""

    op = py_ast.IsNot
class BoolOp(BinaryOperator, ABC):
    """Short-circuit boolean operator (``and`` / ``or``) over two operands.

    Emitted as an ``ast.BoolOp`` whose values are ``[left, right]``.
    """

    op: ClassVar[type[py_ast.boolop]]

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        operator_node = self.op()
        operands = [self.left.as_ast(), self.right.as_ast()]
        return py_ast.BoolOp(op=operator_node, values=operands, **DEFAULT_AST_ARGS)
class And(BoolOp):
    """Boolean ``left and right``."""

    op = py_ast.And


class Or(BoolOp):
    """Boolean ``left or right``."""

    op = py_ast.Or
class UnaryOperator(Expression, ABC):
    """Base class for single-operand operators, emitted as ``ast.UnaryOp``.

    Concrete subclasses select the operator by setting :attr:`op`.
    """

    op: ClassVar[type[py_ast.unaryop]]

    def __init__(self, operand: Expression):
        self.operand = operand

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        operator_node = self.op(**DEFAULT_AST_ARGS_ADD)
        return py_ast.UnaryOp(op=operator_node, operand=self.operand.as_ast(), **DEFAULT_AST_ARGS)


class Not(UnaryOperator):
    """Boolean negation ``not operand``."""

    op = py_ast.Not


class UAdd(UnaryOperator):
    """Unary plus ``+operand``."""

    op = py_ast.UAdd


class USub(UnaryOperator):
    """Unary minus ``-operand``."""

    op = py_ast.USub


class Invert(UnaryOperator):
    """Bitwise inversion ``~operand``."""

    op = py_ast.Invert
def simplify(codegen_ast: CodeGenAstType, simplifier: Callable[[CodeGenAstType], CodeGenAst | None]):
    """
    Apply *simplifier* to every node of *codegen_ast*, repeating full passes
    until a pass makes no replacement.

    *simplifier* returns None to leave a node alone, or a replacement CodeGenAst
    object otherwise. Replacements happen in place (via rewriting_traverse), so
    the same *codegen_ast* object is returned.
    """
    changed = True

    def rewriter(node: CodeGenAstType) -> CodeGenAstType:
        nonlocal changed
        replacement = simplifier(node)
        if replacement is None:
            return node
        changed = True
        return replacement

    while changed:
        changed = False
        rewriting_traverse(codegen_ast, rewriter)
    return codegen_ast
def rewriting_traverse(
    node: CodeGenAstType | Sequence[CodeGenAstType],
    func: Callable[[CodeGenAstType], CodeGenAstType],
    _visited: set[int] | None = None,
):
    """
    Apply 'func' to node and to every CodeGenAst node reachable from it.

    Children are found by walking instance attributes (``__dict__`` values),
    and through lists, tuples and dict values. There is no hand-maintained list
    of child fields. When *func* returns a different object, the node is morphed
    into it in place, so parents do not need updating. Children are visited
    after their parent has been rewritten.

    A visited set of ``id()`` values stops infinite recursion through circular
    references (e.g. Block.scope → Function → body → Block).
    """
    visited = set() if _visited is None else _visited
    if id(node) in visited:
        return

    if isinstance(node, (CodeGenAst, CodeGenAstList)):
        visited.add(id(node))
        replacement = func(node)
        if replacement is not node:
            morph_into(node, replacement)
        children = node.__dict__.values()
    elif isinstance(node, (list, tuple)):
        children = node
    elif isinstance(node, dict):
        children = node.values()  # type: ignore[reportUnknownVariableType]
    else:
        return

    for child in children:  # type: ignore[reportUnknownVariableType]
        rewriting_traverse(child, func, visited)  # type: ignore[reportUnknownArgumentType]
def morph_into(item: object, new_item: object) -> None:
    """Make *item* behave exactly like *new_item* while keeping *item*'s identity.

    *item* takes *new_item*'s class and then its ``__dict__``. The dict object
    itself is shared, not copied.
    """
    for slot in ("__class__", "__dict__"):
        setattr(item, slot, getattr(new_item, slot))
def empty_If() -> py_ast.If: """ Create an empty If ast node. The `test` attribute must be added later. """ return py_ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS) # type: ignore[reportArgumentType] type SimplePythonObj = bool | str | bytes | int | float | None type Autoable = ( SimplePythonObj | Expression | E | slice[Autoable] | list[Autoable] | tuple[Autoable, ...] | set[Autoable] | dict[Autoable, Autoable] ) @overload def auto(value: bool) -> Bool: ... # type: ignore[overload-overlap] # bool before int/float is intentional @overload def auto(value: str) -> String: ... @overload def auto(value: bytes) -> Bytes: ... @overload def auto(value: int) -> Number: ... @overload def auto(value: float) -> Number: ... @overload def auto(value: None) -> NoneExpr: ... @overload def auto(value: slice[Autoable]) -> Slice: ... @overload def auto(value: Expression) -> Expression: ... @overload def auto(value: E) -> Expression: ... @overload def auto(value: list[Autoable]) -> List: ... @overload def auto(value: tuple[Autoable, ...]) -> Tuple: ... @overload def auto(value: set[Autoable]) -> Set: ... @overload def auto(value: dict[Autoable, Autoable]) -> Dict: ...
def auto(value: Autoable) -> Expression:
    """
    Create a codegen Expression from a plain Python object.

    Supports bool, str, bytes, int, float, None, slice, and (recursively) list,
    tuple, set and dict. Containers may mix plain Python objects with values that
    are already Expression or E-objects. Expressions are returned as-is, and
    E-objects are unwrapped to the Expression they hold.

    Notes:
        * ``bool`` is checked before ``int`` because ``bool`` is an ``int`` subclass.
        * ``None`` maps to the shared ``constants.None_`` instance.
        * Set elements are sorted by ``repr`` so the output order is deterministic.

    Raises:
        AssertionError: (via ``assert_never``) for unsupported value types.
    """
    if isinstance(value, bool):
        return Bool(value)
    elif isinstance(value, str):
        return String(value)
    elif isinstance(value, bytes):
        return Bytes(value)
    elif isinstance(value, (int, float)):
        return Number(value)
    elif value is None:
        return constants.None_
    elif isinstance(value, slice):
        return Slice(
            start=auto(value.start) if value.start is not None else None,
            stop=auto(value.stop) if value.stop is not None else None,
            step=auto(value.step) if value.step is not None else None,
        )
    elif isinstance(value, Expression):
        return value
    elif isinstance(value, E):
        return E_to_Expression(value)
    elif isinstance(value, list):
        return List([auto(item) for item in value])
    elif isinstance(value, tuple):
        return Tuple([auto(item) for item in value])
    elif isinstance(value, set):
        return Set([auto(item) for item in sorted(value, key=repr)])
    elif isinstance(value, dict):  # type: ignore[reportUnnecessaryIsInstance]
        return Dict([(auto(k), auto(v)) for k, v in value.items()])
    assert_never(value)
type ELike = E | Autoable """ E-objects or things used in similar contexts (Python literals) """ type ExpressionLike = Expression | E """ Expression objects, or things used in similar contexts (E-objects) """ class E: def __init__(self, expr: Expression) -> None: self._the_expr = expr def __getattr__(self, name: str) -> E: return E(self._the_expr.attr(name)) def __call__(self, *args: ELike, **kwargs: ELike) -> E: exp_args = [ELike_to_Expression(arg) for arg in args] exp_kwargs = {key: ELike_to_Expression(val) for key, val in kwargs.items()} return E(self._the_expr.call(exp_args, exp_kwargs)) # Subscript def __getitem__(self, key: ELike) -> E: return E(self._the_expr.subscript(ELike_to_Expression(key))) # Arithmetic operators def __add__(self, other: ELike) -> E: return E(self._the_expr.add(ELike_to_Expression(other))) def __radd__(self, other: ELike) -> E: return E(ELike_to_Expression(other).add(self._the_expr)) def __sub__(self, other: ELike) -> E: return E(self._the_expr.sub(ELike_to_Expression(other))) def __rsub__(self, other: ELike) -> E: return E(ELike_to_Expression(other).sub(self._the_expr)) def __mul__(self, other: ELike) -> E: return E(self._the_expr.mul(ELike_to_Expression(other))) def __rmul__(self, other: ELike) -> E: return E(ELike_to_Expression(other).mul(self._the_expr)) def __truediv__(self, other: ELike) -> E: return E(self._the_expr.div(ELike_to_Expression(other))) def __rtruediv__(self, other: ELike) -> E: return E(ELike_to_Expression(other).div(self._the_expr)) def __floordiv__(self, other: ELike) -> E: return E(self._the_expr.floordiv(ELike_to_Expression(other))) def __rfloordiv__(self, other: ELike) -> E: return E(ELike_to_Expression(other).floordiv(self._the_expr)) def __mod__(self, other: ELike) -> E: return E(self._the_expr.mod(ELike_to_Expression(other))) def __rmod__(self, other: ELike) -> E: return E(ELike_to_Expression(other).mod(self._the_expr)) def __pow__(self, other: ELike) -> E: return 
E(self._the_expr.pow(ELike_to_Expression(other))) def __rpow__(self, other: ELike) -> E: return E(ELike_to_Expression(other).pow(self._the_expr)) def __matmul__(self, other: ELike) -> E: return E(self._the_expr.matmul(ELike_to_Expression(other))) def __rmatmul__(self, other: ELike) -> E: return E(ELike_to_Expression(other).matmul(self._the_expr)) # Bitwise: def __and__(self, other: ELike) -> E: return E(self._the_expr.bitand(ELike_to_Expression(other))) def __or__(self, other: ELike) -> E: return E(self._the_expr.bitor(ELike_to_Expression(other))) def __xor__(self, other: ELike) -> E: return E(self._the_expr.xor(ELike_to_Expression(other))) def __lshift__(self, other: ELike) -> E: return E(self._the_expr.lshift(ELike_to_Expression(other))) def __rshift__(self, other: ELike) -> E: return E(self._the_expr.rshift(ELike_to_Expression(other))) def __invert__(self) -> E: return E(self._the_expr.invert()) # Comparison operators def __eq__(self, other: ELike) -> E: # type: ignore return E(self._the_expr.eq(ELike_to_Expression(other))) def __ne__(self, other: ELike) -> E: # type: ignore return E(self._the_expr.ne(ELike_to_Expression(other))) def __lt__(self, other: ELike) -> E: # type: ignore return E(self._the_expr.lt(ELike_to_Expression(other))) def __gt__(self, other: ELike) -> E: # type: ignore return E(self._the_expr.gt(ELike_to_Expression(other))) def __le__(self, other: ELike) -> E: # type: ignore return E(self._the_expr.le(ELike_to_Expression(other))) def __ge__(self, other: ELike) -> E: # type: ignore return E(self._the_expr.ge(ELike_to_Expression(other))) # Boolean operators - these cannot exists, and/or are short-circuiting # operators that cannot be overridden # Membership # It looks like we can't override __contains__ effectively # to do what we need to support `in` and `not in` # - Python coerces to bool afterwards, # - `__bool__` has to return True/False # - we can't do `not in` # Identity operators - again can't be overridden # # Unary operators def 
__pos__(self) -> E: return E(self._the_expr.pos()) def __neg__(self) -> E: return E(self._the_expr.neg()) def E_to_Expression(val: E | Expression) -> Expression: if isinstance(val, E): return val._the_expr # type: ignore[reportPrivateUsage] return val def ELike_to_Expression(val: ELike) -> Expression: if isinstance(val, E): return val._the_expr # type: ignore[reportPrivateUsage] return auto(val) class Enames: def __init__(self, scope: Scope) -> None: self.__scope = scope def __getattr__(self, name: str) -> E: return E(self.__scope.name(name))
class constants:
    """
    Shared, ready-made Expression instances for the ``None``, ``True`` and
    ``False`` literals.
    """

    None_: NoneExpr = NoneExpr()
    True_: Bool = Bool(True)
    False_: Bool = Bool(False)