"""
Utilities for doing Python code generation
"""
from __future__ import annotations
import builtins
import enum
import keyword
import re
import sys
import textwrap
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import ClassVar, Literal, assert_never, overload
if sys.version_info >= (3, 13):
from typing import TypeIs # pragma: no cover
else:
from typing_extensions import TypeIs # pragma: no cover
from . import ast_compat as py_ast
from .ast_compat import (
DEFAULT_AST_ARGS,
DEFAULT_AST_ARGS_ADD,
DEFAULT_AST_ARGS_ARGUMENTS,
DEFAULT_AST_ARGS_MODULE,
CommentNode,
unparse_with_comments,
)
from .utils import allowable_keyword_arg_name, allowable_name
#: Type alias for comment strings stored in :attr:`Block.statements`.
#: A plain ``str`` entry in a block's statement list is treated as a comment
#: (rendered as ``# text``) rather than as a statement or expression.
Comment = str
# This module provides simple utilities for building up Python source code.
# The design originally came from fluent-compiler, so had the following aims
# and constraints:
#
# 1. Performance.
#
# The resulting Python code should do as little as possible, especially for
# simple cases.
#
# 2. Correctness (obviously)
#
# In particular, we should try to make it hard to generate code that is
# syntactically correct and therefore compiles but doesn't work. We try to
# make it hard to generate accidental name clashes, or use variables that are
# not defined.
#
# Correctness also has a security implication, since the result of this code
# might be 'exec'ed. To that end:
# * We build up AST, rather than strings. This eliminates many
# potential bugs caused by wrong escaping/interpolation.
# * the `as_ast()` methods are paranoid about input, and do many asserts.
# We do this even though other layers will usually have checked the
# input, to allow us to reason locally when checking these methods. These
# asserts must also have 100% code coverage.
#
# 3. Simplicity
#
# The resulting Python code should be easy to read and understand.
#
# 4. Predictability
#
# Since we want to test the resulting source code, we have made some design
# decisions that aim to ensure things like function argument names are
# consistent and so can be predicted easily.
# Outside of the original project `fluent-compiler`, this code will likely be
# useful for situations which have similar aims.
# It has since evolved a further aim:
# 5. Usability
#
# We add convenience wrappers to make it easier to see what the output
# is going to look like.
# --- Layers
# Within this module, there are about 3 conceptual layers, though not completely
# separated out:
# 1. At the bottom layer, there are CodeGenAst classes, such as `Number`,
# `String`, `If` etc. These are thin wrappers around Python `ast` classes.
#
# 2. Block, Scope and Expression, which provide higher level convenience.
# In particular:
#
# - Scope provides management of names, to avoid name clashes. (For this
# reason the `Name` class, which is a `CodeGenAst`, actually depends
# on `Scope`, to make it harder to create Names that clash.)
#
# - Block makes it easier to build up a list of statements.
#
# - Expression is a base class that provides a convenient chained method
# syntax for creating further expressions.
#
# 3. E-objects: These add another layer of convenience, allowing you to build up
# Python code using Python syntax rather than the chained method
# calls on Expression which can be awkward if you know statically what operations
# are to be done.
#
# Note that:
# - the bottom layer does not accept `E-objects` in its constructors
#
# - the middle layer accepts both Expression and E-objects to support
# the convenience of building up expressions easily.
#
# - the top layer, E-objects, allows mixing Expression and also allows mixing
# native Python objects which get converted to E-objects.
# ---
SENSITIVE_FUNCTIONS = {
    # Builtin functions that we should never be calling from our code
    # generation. This is a defense-in-depth mechanism to stop our code
    # generation becoming a code execution vulnerability. There should also be
    # higher level code that ensures we are not generating calls to arbitrary
    # Python functions. This is not a comprehensive list of functions we are not
    # using, but functions we definitely don't need and are most likely to be
    # used to execute remote code or to get around safety mechanisms.
    "__import__",
    "__build_class__",
    "apply",
    "compile",
    "eval",
    "exec",
    "execfile",
    "exit",
    "file",
    "globals",
    "locals",
    "open",
    "object",
    "reload",
    "type",
}
class CodeGenAst(ABC):
    """
    Abstract base for nodes of our simplified Python AST (not the real one).

    Subclasses implement `as_ast()` to produce the corresponding real `ast.*` node.
    """

    @abstractmethod
    def as_ast(self, *, include_comments: bool = False) -> py_ast.AST:
        """Build and return the real AST node for this object."""

    def as_python_source(self) -> str:
        """Render this node, with comments included, as Python source text."""
        tree = self.as_ast(include_comments=True)
        # Fill in line/column info before handing the tree to the unparser.
        py_ast.fix_missing_locations(tree)
        return unparse_with_comments(tree)
class CodeGenAstList(ABC):
    """
    Counterpart of CodeGenAst for objects that expand to a *list* of AST
    nodes. Every node in that list must be a `stmt`.
    """

    @abstractmethod
    def as_ast_list(self, allow_empty: bool = True, *, include_comments: bool = False) -> list[py_ast.stmt]:
        """Build and return the list of statement nodes for this object."""

    def as_python_source(self) -> str:
        """Render this statement list, with comments included, as Python source text."""
        body = self.as_ast_list(include_comments=True)
        # Wrap the statements in a module node so they can be unparsed as a unit.
        module_node = py_ast.Module(body=body, type_ignores=[], **DEFAULT_AST_ARGS_MODULE)
        py_ast.fix_missing_locations(module_node)
        return unparse_with_comments(module_node)
# Either flavour of code-generation object: one that yields a single AST node
# (CodeGenAst) or one that yields a list of statement nodes (CodeGenAstList).
CodeGenAstType = CodeGenAst | CodeGenAstList
[docs]
class Scope:
    """Bookkeeping for the names reserved and assigned within one lexical scope."""

    def __init__(self, parent_scope: Scope | None = None):
        self.parent_scope = parent_scope
        # Names that are in use in this scope.
        self.names: set[str] = set()
        # Names set aside for later use as function arguments only.
        self._function_arg_reserved_names: set[str] = set()
        # Names that have had a value assigned in this scope.
        self._assignments: set[str] = set()

    def is_name_in_use(self, name: str) -> bool:
        """Tell whether *name* is in use here or in any enclosing scope."""
        if name in self.names:
            return True
        parent = self.parent_scope
        return parent is not None and parent.is_name_in_use(name)

    def is_name_reserved_function_arg(self, name: str) -> bool:
        """Tell whether *name* is set aside as a function argument here or in any enclosing scope."""
        if name in self._function_arg_reserved_names:
            return True
        parent = self.parent_scope
        return parent is not None and parent.is_name_reserved_function_arg(name)

    def is_name_reserved(self, name: str) -> bool:
        """Tell whether *name* is taken for any purpose (in use, or held for a function argument)."""
        if self.is_name_in_use(name):
            return True
        return self.is_name_reserved_function_arg(name)

    def reserve_name(
        self,
        requested: str,
        function_arg: bool = False,
        is_builtin: bool = False,
    ):
        """
        Mark a name as in use in this scope and return the name actually reserved.

        With ``function_arg=True`` the exact requested name is used if it was
        previously set aside for function arguments; if it is taken for some
        other purpose an AssertionError is raised. Otherwise the name is
        cleaned up and, if necessary, given a numeric suffix (``_2``, ``_3``...)
        until it is free.
        """
        if function_arg:
            if self.is_name_reserved_function_arg(requested):
                assert not self.is_name_in_use(requested)
                self.names.add(requested)
                return requested
            if self.is_name_reserved(requested):
                raise AssertionError(f"Cannot use '{requested}' as argument name as it is already in use")

        base = cleanup_name(requested)

        def _unavailable(candidate: str) -> bool:
            # Keywords ('class', 'def', ...) count as taken. Some builtins are
            # keywords too (e.g. 'None'), so skip the keyword check when a
            # builtin is being reserved.
            if keyword.iskeyword(candidate) and not is_builtin:
                return True
            # Parent scopes are consulted as well, which avoids shadowing
            # outer names from an inner scope.
            return self.is_name_reserved(candidate)

        candidate = base
        suffix = 2  # the un-suffixed name counts as instance number 1
        while _unavailable(candidate):
            candidate = f"{base}_{suffix}"
            suffix += 1
        self.names.add(candidate)
        return candidate

    def reserve_function_arg_name(self, name: str):
        """
        Set *name* aside for *later* use as a function argument.

        The name does not become 'in use' in the current scope, but it can no
        longer be handed out for anything other than a function argument.
        """
        # For simplicity and predictable output, function argument names are
        # tracked separately and must be used exactly as given.
        if self.is_name_reserved(name):
            raise AssertionError(f"Can't reserve '{name}' as function arg name as it is already reserved")
        self._function_arg_reserved_names.add(name)

    def has_assignment(self, name: str) -> bool:
        """Tell whether a value has been assigned to *name* in this scope."""
        return name in self._assignments

    def register_assignment(self, name: str) -> None:
        """Note that a value has been assigned to *name* in this scope."""
        self._assignments.add(name)

    def create_name(self, name: str) -> Name:
        """Reserve *name* (or a free variant of it) and wrap the result in a :class:`Name`."""
        return Name(self.reserve_name(name), self)

    def name(self, name: str) -> Name:
        """Wrap *name* in a :class:`Name` without reserving anything."""
        return Name(name, self)

    @cached_property
    def enames(self) -> Enames:
        """
        An Enames object for this scope, giving easy access to names as E-objects.
        """
        return Enames(self)
# Matches every character that may not appear in an (ASCII) identifier.
_IDENTIFIER_SANITIZER_RE = re.compile("[^a-zA-Z0-9_]")
# Matches a valid first character of an (ASCII) identifier.
_IDENTIFIER_START_RE = re.compile("^[a-zA-Z_]")


def cleanup_name(name: str) -> str:
    """
    Convert *name* to an allowable identifier.

    Characters other than ASCII letters, digits and underscore are dropped;
    if what remains does not start with a letter or underscore, ``"n"`` is
    prepended.
    """
    # See https://docs.python.org/3/reference/lexical_analysis.html#identifiers
    stripped = _IDENTIFIER_SANITIZER_RE.sub("", name)
    if _IDENTIFIER_START_RE.match(stripped):
        return stripped
    return "n" + stripped
[docs]
class Statement(CodeGenAst):
    """Common base for code-generation nodes that stand for Python statements."""
[docs]
class Annotation(Statement):
    """A value-less type annotation statement, e.g. ``x: int``."""

    def __init__(self, name: str, annotation: Expression):
        self.name = name
        self.annotation = annotation

    def as_ast(self, *, include_comments: bool = False):
        """Return the ``AnnAssign`` node; raises AssertionError if the name is not a valid identifier."""
        if not allowable_name(self.name):
            raise AssertionError(f"Expected {self.name} to be a valid Python identifier")
        store_target = py_ast.Name(id=self.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        annotation_node = self.annotation.as_ast()
        return py_ast.AnnAssign(
            target=store_target,
            annotation=annotation_node,
            simple=1,
            value=None,  # bare annotation: nothing is assigned
            **DEFAULT_AST_ARGS,
        )
# The closed set of operator spellings accepted for augmented assignment.
type AugOp = Literal[
    "+=",
    "-=",
    "*=",
    "/=",
    "//=",
    "%=",
    "**=",
    "@=",
    "<<=",
    ">>=",
    "|=",
    "&=",
    "^=",
]
"""Operator strings accepted by :class:`AugAssignment` and :meth:`Block.aug_assign`."""

# Maps each operator string to the AST operator class that
# AugAssignment.as_ast() instantiates. AugAssignment.__init__ also uses the
# keys of this mapping to validate the operator it is given.
_AUG_OP_MAP: dict[str, type[py_ast.operator]] = {
    "+=": py_ast.Add,
    "-=": py_ast.Sub,
    "*=": py_ast.Mult,
    "/=": py_ast.Div,
    "//=": py_ast.FloorDiv,
    "%=": py_ast.Mod,
    "**=": py_ast.Pow,
    "@=": py_ast.MatMult,
    "<<=": py_ast.LShift,
    ">>=": py_ast.RShift,
    "|=": py_ast.BitOr,
    "^=": py_ast.BitXor,
    "&=": py_ast.BitAnd,
}

# NOTE(review): Name, Attr and Subscript are not defined in this part of the
# file; presumably they are declared later in the module -- confirm.
type AugAssignTarget = Name | Attr | Subscript
"""Valid targets for augmented assignment (no tuples)."""
def _is_aug_assign_target(value: object) -> TypeIs[AugAssignTarget]:
    """Tell whether *value* is one of the node types usable as an augmented-assignment target."""
    allowed_kinds = (Name, Attr, Subscript)
    return isinstance(value, allowed_kinds)
class AugAssignment(Statement):
    """Statement node for an augmented assignment such as ``x += 1``."""

    def __init__(self, target: AugAssignTarget, op: AugOp, value: Expression, /):
        """Validate the target and operator eagerly; raises AssertionError if either is unsupported."""
        if not _is_aug_assign_target(target):
            target_kind = type(target).__name__
            raise AssertionError(
                f"Invalid augmented assignment target: {target_kind}. "
                "Only Name, Attr, and Subscript are allowed (not tuples)."
            )
        if op not in _AUG_OP_MAP:
            known_ops = ", ".join(sorted(_AUG_OP_MAP))
            raise AssertionError(f"Invalid augmented assignment operator: {op!r}. Must be one of: {known_ops}")
        self.target = target
        self.op: AugOp = op
        self.value = value

    def as_ast(self, *, include_comments: bool = False):
        """Return the ``AugAssign`` node for this statement."""
        store_target = _aug_target_as_store_ast(self.target)
        operator_node = _AUG_OP_MAP[self.op](**DEFAULT_AST_ARGS_ADD)
        value_node = self.value.as_ast()
        return py_ast.AugAssign(
            target=store_target,
            op=operator_node,
            value=value_node,
            **DEFAULT_AST_ARGS,
        )
[docs]
class Assignment(Statement):
    """Statement node for a variable assignment, with an optional type annotation."""

    def __init__(self, target: Target, value: Expression, /, *, type_hint: Expression | None = None):
        """Raises AssertionError for an invalid target, or for a type hint on a non-Name target."""
        if not is_target(target):
            raise AssertionError("Invalid assignment target")
        self.names: list[str] = _target_names(target)
        self.target: Target = target
        if not (type_hint is None or isinstance(target, Name)):
            raise AssertionError("Type hints are only supported for simple name assignment targets")
        self.value = value
        self.type_hint = type_hint

    def as_ast(self, *, include_comments: bool = False):
        """Return an ``AnnAssign`` node when a type hint is present, otherwise a plain ``Assign``."""
        hint = self.type_hint
        if hint is not None:
            # A type hint implies a Name target; __init__ has already enforced this.
            assert isinstance(self.target, Name)
            store_target = _target_as_store_ast(self.target)
            assert isinstance(store_target, py_ast.Name)
            return py_ast.AnnAssign(
                target=store_target,
                annotation=hint.as_ast(),
                simple=1,
                value=self.value.as_ast(),
                **DEFAULT_AST_ARGS,
            )
        return py_ast.Assign(
            targets=[_target_as_store_ast(self.target)],
            value=self.value.as_ast(),
            **DEFAULT_AST_ARGS,
        )
[docs]
class Block(CodeGenAstList):
    """An ordered sequence of statements sharing a common :class:`Scope`."""

    def __init__(self, scope: Scope, parent_block: Block | None = None):
        self.scope = scope
        # We allow `Expression` here for things like MethodCall which
        # are bare expressions that are still useful for side effects.
        # `Comment` (str) entries are rendered as ``# text`` comments.
        self.statements: list[Block | Statement | Expression | Comment] = []
        self.parent_block = parent_block

    def as_ast_list(self, allow_empty: bool = True, *, include_comments: bool = False) -> list[py_ast.stmt]:
        """
        Convert the statements of this block to a list of AST statement nodes.

        Comment entries are only emitted when *include_comments* is true. If
        the result would be empty and *allow_empty* is false, a single ``pass``
        statement is returned instead.
        """
        retval: list[py_ast.stmt] = []
        for s in self.statements:
            if isinstance(s, str):
                # Comment
                if include_comments:
                    retval.append(CommentNode(s, **DEFAULT_AST_ARGS))  # type: ignore[reportArgumentType]
                continue
            if isinstance(s, CodeGenAstList):
                # Nested blocks are flattened into this block's statement list.
                retval.extend(s.as_ast_list(allow_empty=True, include_comments=include_comments))
            elif isinstance(s, Statement):
                ast_obj = s.as_ast(include_comments=include_comments)
                assert isinstance(ast_obj, py_ast.stmt), (
                    f"Statement object return {ast_obj} which is not a subclass of py_ast.stmt"
                )
                retval.append(ast_obj)
            else:
                # Things like bare function/method calls need to be wrapped
                # in `Expr` to match the way Python parses.
                retval.append(py_ast.Expr(s.as_ast(include_comments=include_comments), **DEFAULT_AST_ARGS))
        if len(retval) == 0 and not allow_empty:
            return [py_ast.Pass(**DEFAULT_AST_ARGS)]
        return retval

    def add_statement(self, statement: Statement | Block | ExpressionLike) -> None:
        """
        Append a statement, block, or bare expression to this block.

        E-objects are converted to Expression first. A Block that has no parent
        yet is adopted by this block; a Block that already belongs to a
        different parent raises AssertionError.
        """
        if isinstance(statement, E):
            statement = E_to_Expression(statement)
        self.statements.append(statement)
        if isinstance(statement, Block):
            if statement.parent_block is None:
                statement.parent_block = self
            else:
                if statement.parent_block != self:
                    raise AssertionError(
                        f"Block {statement} is already child of {statement.parent_block}, can't reassign to {self}"
                    )

    def add_statements(self, statements: Sequence[Statement | Block | ExpressionLike]) -> None:
        """Append multiple statements to this block.

        This is a convenience wrapper around :meth:`add_statement`::

            block.add_statements([stmt_a, stmt_b, stmt_c])
            # equivalent to:
            block.add_statement(stmt_a)
            block.add_statement(stmt_b)
            block.add_statement(stmt_c)
        """
        for statement in statements:
            self.add_statement(statement)

    def create_import(self, module: str, as_: str | None = None) -> tuple[Import, Name]:
        """
        Create an ``import`` statement, reserve the resulting name, and add it to this block.

        Returns the statement together with the :class:`Name` that the import
        binds. Raises AssertionError if a name is not allowable, or if the
        ``as`` name is already in use.
        """
        return_name_object: Name
        if as_ is not None:
            # "import foo as bar" results in `bar` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # "import foo" results in `foo` name being assigned
            # "import foo.bar" also results in `foo` being reserved.
            dotted_parts = module.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{module!r} not an allowable 'import' name")
            name_to_assign = dotted_parts[0]
            # We can't rename, so don't use `reserve_name` or `create_name`.
            # We also need to allow for multiple imports, like `import foo.bar` then `import foo.baz`
            if not self.scope.is_name_in_use(name_to_assign):
                self.scope.reserve_name(name_to_assign)
            return_name_object = self.scope.name(name_to_assign)
        import_statement = Import(module=module, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    def create_import_from(self, *, from_: str, import_: str, as_: str | None = None) -> tuple[ImportFrom, Name]:
        """
        Create a ``from ... import`` statement, reserve the resulting name, and add it to this block.

        Returns the statement together with the :class:`Name` that the import
        binds. Raises AssertionError if a name is not allowable or is already
        in use in the scope.
        """
        return_name_object: Name
        if as_ is not None:
            # "from foo import bar as baz" results in `baz` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # Check the dotted bit.
            dotted_parts = from_.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{from_!r} not an allowable 'import' name")
            # Check the `import_` for clashes.
            name_to_assign = import_
            if self.scope.is_name_in_use(name_to_assign):
                raise AssertionError(f"{name_to_assign!r} is already assigned in the scope")
            return_name_object = self.scope.create_name(name_to_assign)
        import_statement = ImportFrom(from_module=from_, import_=import_, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    # Safe alternatives to Block.statements being manipulated directly:

    def create_assignment(
        self,
        target: str | Target,
        value: ExpressionLike,
        *,
        type_hint: ExpressionLike | None = None,
        allow_multiple: bool = False,
    ) -> None:
        """
        Adds an assignment of the form::

            x = value

        or more complex like::

            x[0] = value
            x, y = value

        A ``str`` target must already be reserved in the scope. Assigning to
        the same name twice raises AssertionError unless *allow_multiple* is
        true.
        """
        if isinstance(target, str):
            if not self.scope.is_name_in_use(target):
                raise AssertionError(f"Cannot assign to unreserved name '{target}'")
            target = self.scope.name(target)
        names = _target_names(target)
        for name in names:
            if self.scope.has_assignment(name):
                if not allow_multiple:
                    raise AssertionError(f"Have already assigned to '{name}' in this scope")
            self.scope.register_assignment(name)
        self.add_statement(
            Assignment(
                target, E_to_Expression(value), type_hint=E_to_Expression(type_hint) if type_hint is not None else None
            )
        )

    @overload
    def assign(self, target: str, value: ExpressionLike, *, type_hint: ExpressionLike | None = ...) -> Name: ...

    @overload
    def assign(self, target: tuple[str, ...], value: ExpressionLike) -> tuple[Name, ...]: ...

    def assign(
        self,
        target: str | tuple[str, ...],
        value: ExpressionLike,
        *,
        type_hint: ExpressionLike | None = None,
    ) -> Name | tuple[Name, ...]:
        """
        Shortcut that reserves names and creates an assignment in one step.

        When *target* is a single ``str``, reserves the name and assigns to it,
        returning the new :class:`Name`::

            result = block.assign("x", some_expr)
            # equivalent to:
            #   x = scope.create_name("x")
            #   block.create_assignment(x, some_expr)

        When *target* is a ``tuple`` of ``str``, reserves each name and creates
        a tuple-unpacking assignment, returning a tuple of :class:`Name`
        objects::

            a, b = block.assign(("a", "b"), some_pair_expr)
            # equivalent to:
            #   a = scope.create_name("a")
            #   b = scope.create_name("b")
            #   block.create_assignment((a, b), some_pair_expr)

        A type hint combined with a tuple target raises AssertionError.
        """
        if isinstance(target, str):
            name_obj = self.scope.create_name(target)
            self.create_assignment(name_obj, value, type_hint=type_hint)
            return name_obj
        else:
            if type_hint is not None:
                raise AssertionError("Can't use type hint with tuple unpacking assignment")
            name_objs = tuple(self.scope.create_name(t) for t in target)
            self.create_assignment(name_objs, value)
            return name_objs

    def aug_assign(self, target: AugAssignTarget, op: AugOp, value: ExpressionLike, /) -> None:
        """Add an augmented assignment statement to this block.

        Usage::

            x = block.assign("x", auto(0))
            block.aug_assign(x, "+=", auto(1))  # x += 1

        *target* must be a :class:`Name`, :class:`Attr`, or :class:`Subscript`
        (tuples are not valid Python augmented-assignment targets).

        *op* is one of the Python augmented-assignment operator strings:

        - ``"+="``
        - ``"-="``
        - ``"*="``
        - ``"/="``
        - ``"//="``
        - ``"%="``
        - ``"**="``
        - ``"@="``
        - ``"<<="``
        - ``">>="``
        - ``"|="``
        - ``"&="``
        - ``"^="``
        """
        self.add_statement(AugAssignment(target, op, E_to_Expression(value)))

    def create_annotation(self, name: str, annotation: ExpressionLike) -> Name:
        """
        Adds a bare type annotation of the form::

            x: int

        Reserves the name and adds the annotation statement to the block.
        """
        name_obj = self.scope.create_name(name)
        self.scope.register_assignment(name_obj.name)
        self.add_statement(Annotation(name_obj.name, E_to_Expression(annotation)))
        return name_obj

    def create_field(self, name: str, annotation: ExpressionLike, *, default: ExpressionLike | None = None) -> Name:
        """
        Create a typed field, typically used in dataclass bodies.

        If *default* is provided, creates an annotated assignment::

            x: int = 0

        Otherwise, creates a bare annotation::

            x: int
        """
        if default is not None:
            name_obj = self.scope.create_name(name)
            self.scope.register_assignment(name_obj.name)
            self.add_statement(Assignment(name_obj, E_to_Expression(default), type_hint=E_to_Expression(annotation)))
            return name_obj
        else:
            return self.create_annotation(name, annotation)

    def create_function(
        self,
        name: str,
        args: Sequence[str | FunctionArg],
        decorators: Sequence[ExpressionLike] | None = None,
        return_type: ExpressionLike | None = None,
    ) -> tuple[Function, Name]:
        """
        Reserve a name for a function, create the Function and add the function statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        func = Function(
            name_obj.name,
            args=args,
            parent_scope=self.scope,
            decorators=[E_to_Expression(d) for d in decorators] if decorators is not None else None,
            return_type=E_to_Expression(return_type) if return_type is not None else None,
        )
        self.add_statement(func)
        return func, name_obj

    def create_class(
        self,
        name: str,
        bases: Sequence[ExpressionLike] | None = None,
        decorators: Sequence[ExpressionLike] | None = None,
    ) -> tuple[Class, Name]:
        """
        Reserve a name for a class, create the Class and add the class statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        cls = Class(
            name_obj.name,
            parent_scope=self.scope,
            bases=[E_to_Expression(b) for b in bases] if bases is not None else None,
            decorators=[E_to_Expression(d) for d in decorators] if decorators is not None else None,
        )
        self.add_statement(cls)
        return cls, name_obj

    def create_return(self, value: ExpressionLike) -> None:
        """Add a ``return`` statement to this block."""
        self.add_statement(Return(E_to_Expression(value)))

    def create_break(self) -> None:
        """Add a ``break`` statement to this block."""
        self.add_statement(Break())

    def create_continue(self) -> None:
        """Add a ``continue`` statement to this block."""
        self.add_statement(Continue())

    def create_assert(self, test: ExpressionLike, msg: ExpressionLike | None = None) -> None:
        """Add an ``assert`` statement to this block."""
        self.add_statement(Assert(E_to_Expression(test), E_to_Expression(msg) if msg is not None else None))

    def create_raise(self, exc: ExpressionLike | None = None, cause: ExpressionLike | None = None) -> None:
        """Add a ``raise`` statement to this block."""
        self.add_statement(
            Raise(
                E_to_Expression(exc) if exc is not None else None,
                E_to_Expression(cause) if cause is not None else None,
            )
        )

    def create_if(self) -> If:
        """
        Create an If statement, add it to this block, and return it.

        Usage::

            if_stmt = block.create_if()
            if_block = if_stmt.add_if(condition)
            if_block.create_return(value)
        """
        if_statement = If(self.scope, parent_block=self)
        self.add_statement(if_statement)
        return if_statement

    @overload
    def create_with(self, context_expr: ExpressionLike, target: Name | str) -> tuple[With, Name]: ...

    @overload
    def create_with(self, context_expr: ExpressionLike) -> With: ...

    def create_with(self, context_expr: ExpressionLike, target: Name | str | None = None) -> With | tuple[With, Name]:
        """
        Create a With statement, add it to this block, and return it.

        Usage::

            with_stmt, target = block.create_with(expr, "f")
            with_stmt.body.create_return(value)

        If target is a str, the name will be reserved.

        If target is `None`, only the with_statement will be returned.
        """
        if isinstance(target, str):
            name_obj = self.scope.create_name(target)
            target = name_obj
        with_statement = With(E_to_Expression(context_expr), target=target, parent_scope=self.scope, parent_block=self)
        self.add_statement(with_statement)
        if target is None:
            return with_statement
        return with_statement, target

    @overload
    def create_for(self, target: str, iterable: ExpressionLike) -> tuple[For, Name]: ...

    @overload
    def create_for(self, target: tuple[str, ...], iterable: ExpressionLike) -> tuple[For, tuple[Name, ...]]: ...

    @overload
    def create_for(self, target: tuple[Name, ...], iterable: ExpressionLike) -> tuple[For, tuple[Name, ...]]: ...

    @overload
    def create_for(self, target: Target, iterable: ExpressionLike) -> tuple[For, Target]: ...

    def create_for(self, target: str | tuple[str, ...] | Target, iterable: ExpressionLike) -> tuple[For, Target]:
        """
        Create a ``for`` loop, add it to this block, and return it.

        The first parameter is the loop variable. If this is a str
        or tuple[str] then these names will be reserved and Name objects
        created, similar to `assign`.

        The second parameter is an expression that will be iterated over.

        Usage::

            for_stmt, index = func.body.create_for("i", items)
            for_stmt.body.add_statement(some_expr)
        """
        target = _normalize_targets(self.scope, target)
        for_statement = For(target, E_to_Expression(iterable), parent_scope=self.scope, parent_block=self)
        self.add_statement(for_statement)
        return for_statement, target

    def create_try(self) -> Try:
        """
        Create a Try statement, add it to this block, and return it.

        Add ``except`` clauses via :meth:`Try.create_except`.

        Usage::

            try_stmt = block.create_try()
            try_stmt.try_block.add_statement(some_expr)
            except_block, e_name = try_stmt.create_except([my_error], "e")
            except_block.create_return(value)
        """
        try_statement = Try(
            self.scope,
            parent_block=self,
        )
        self.add_statement(try_statement)
        return try_statement
def _normalize_targets(scope: Scope, target: str | tuple[str, ...] | Target) -> Target:
    """
    Turn a loosely specified assignment target into a proper Target.

    A ``str`` is reserved in *scope* and replaced by the resulting Name; a
    tuple is normalized element by element (recursively); anything else must
    already be a valid target and is returned unchanged.

    Raises AssertionError if *target* is none of the above.
    """
    if isinstance(target, str):
        return scope.create_name(target)
    if isinstance(target, tuple):
        return tuple(_normalize_targets(scope, t) for t in target)
    # Raise explicitly (rather than use an `assert` statement) so the check
    # still runs under `python -O`, matching the validation in Assignment.
    if not is_target(target):
        raise AssertionError(f"Expected str, tuple or Target, got {type(target)}")
    return target
[docs]
class Module(Block, CodeGenAst):
    """The top-level block of a module, owning a root :class:`Scope` that is pre-loaded with builtins."""

    def __init__(self, reserve_builtins: bool = True):
        root_scope = Scope(parent_scope=None)
        if reserve_builtins:
            # Mark every builtin name as taken so generated names never shadow one.
            for builtin_name in dir(builtins):
                root_scope.reserve_name(builtin_name, is_builtin=True)
        super().__init__(root_scope)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.Module:
        """Return a ``Module`` AST node containing this block's statements."""
        body = self.as_ast_list(include_comments=include_comments)
        return py_ast.Module(body=body, type_ignores=[], **DEFAULT_AST_ARGS_MODULE)

    def as_python_source(self) -> str:
        """Render the module, with comments included, as Python source text."""
        tree = self.as_ast(include_comments=True)
        py_ast.fix_missing_locations(tree)
        return unparse_with_comments(tree)

    @cached_property
    def enames(self) -> Enames:
        """
        The module scope's names as E-objects,
        e.g. Module().enames.str for the `str` Name that is registered as a builtin.
        """
        return self.scope.enames
[docs]
class ArgKind(enum.Enum):
    """The kind of a function argument."""

    # Argument that can only be passed by position.
    POSITIONAL_ONLY = "positional_only"
    # Ordinary argument, passable by position or by keyword (the default kind for FunctionArg).
    POSITIONAL_OR_KEYWORD = "positional_or_keyword"
    # Argument that can only be passed by keyword.
    KEYWORD_ONLY = "keyword_only"
[docs]
@dataclass(frozen=True)
class FunctionArg:
    """Description of one function argument: its name, kind, and optional default and annotation."""

    name: str
    kind: ArgKind = ArgKind.POSITIONAL_OR_KEYWORD
    default: Expression | None = None
    annotation: Expression | None = None

    @classmethod
    def _of_kind(
        cls,
        kind: ArgKind,
        name: str,
        default: ExpressionLike | None,
        annotation: ExpressionLike | None,
    ) -> FunctionArg:
        # Shared by the public constructors below: converts the optional
        # expression-like values to Expression and builds the instance.
        return cls(
            name=name,
            kind=kind,
            default=None if default is None else E_to_Expression(default),
            annotation=None if annotation is None else E_to_Expression(annotation),
        )

    @classmethod
    def positional(
        cls, name: str, *, default: ExpressionLike | None = None, annotation: ExpressionLike | None = None
    ) -> FunctionArg:
        """Build an argument that may only be passed positionally."""
        return cls._of_kind(ArgKind.POSITIONAL_ONLY, name, default, annotation)

    @classmethod
    def keyword(
        cls, name: str, *, default: ExpressionLike | None = None, annotation: ExpressionLike | None = None
    ) -> FunctionArg:
        """Build an argument that may only be passed by keyword."""
        return cls._of_kind(ArgKind.KEYWORD_ONLY, name, default, annotation)

    @classmethod
    def standard(
        cls, name: str, *, default: ExpressionLike | None = None, annotation: ExpressionLike | None = None
    ) -> FunctionArg:
        """Build an ordinary positional-or-keyword argument (the Python default)."""
        return cls._of_kind(ArgKind.POSITIONAL_OR_KEYWORD, name, default, annotation)
def _normalize_args(args: Sequence[str | FunctionArg]) -> list[FunctionArg]:
"""Normalize a mixed list of str and FunctionArg into a list of FunctionArg."""
return [FunctionArg(name=a) if isinstance(a, str) else a for a in args]
def _validate_arg_order(args: list[FunctionArg]) -> None:
    """Check that *args* form a legal Python parameter list.

    Kinds must appear as positional-only, then positional-or-keyword, then
    keyword-only. Among the positional kinds, an argument without a default
    may not come after one with a default.

    Raises ValueError on the first violation; kind ordering is checked for the
    whole list before defaults are looked at.
    """
    # First pass: kinds must be in non-decreasing rank.
    kind_rank = {
        ArgKind.POSITIONAL_ONLY: 0,
        ArgKind.POSITIONAL_OR_KEYWORD: 1,
        ArgKind.KEYWORD_ONLY: 2,
    }
    highest_rank_so_far = -1
    for arg in args:
        rank = kind_rank[arg.kind]
        if rank < highest_rank_so_far:
            raise ValueError(
                f"Argument '{arg.name}' of kind {arg.kind.value} is out of order: "
                "positional-only args must come first, "
                "then positional-or-keyword, then keyword-only"
            )
        highest_rank_so_far = rank

    # Second pass: positional-only and positional-or-keyword arguments share a
    # single defaults list, so across both groups a default-less argument
    # cannot follow one with a default. Keyword-only arguments are exempt
    # (Python allows their defaults in any order).
    positional_kinds = (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)
    default_seen = False
    for arg in args:
        if arg.kind not in positional_kinds:
            continue
        if arg.default is not None:
            default_seen = True
            continue
        if default_seen:
            raise ValueError(f"Non-default argument '{arg.name}' follows default argument in positional arguments")
[docs]
class Function(Scope, Statement):
    """A function definition statement; it doubles as the :class:`Scope` for its own body."""

    def __init__(
        self,
        name: str,
        args: Sequence[str | FunctionArg] | None = None,
        parent_scope: Scope | None = None,
        decorators: Sequence[Expression] | None = None,
        return_type: Expression | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.func_name = name
        # The body's statements live in a block whose scope is this function.
        self.body = Block(self)
        self.decorators: list[Expression] = list(decorators) if decorators else []
        self.return_type: Expression | None = return_type
        self._args: list[FunctionArg] = []
        if args is not None:
            self.add_args(args)

    @property
    def args(self) -> Sequence[FunctionArg]:
        """The function's arguments, as an immutable snapshot."""
        return tuple(self._args)

    def add_args(self, args: Sequence[str | FunctionArg]) -> None:
        """
        Append arguments to the function, validating them as __init__ does.

        Raises ValueError if the combined argument list is badly ordered, and
        AssertionError if an argument name is already in use.
        """
        new_args = _normalize_args(args)
        all_args = self._args + new_args
        _validate_arg_order(all_args)
        for new_arg in new_args:
            arg_name = new_arg.name
            if self.is_name_in_use(arg_name):
                raise AssertionError(f"Can't use '{arg_name}' as function argument name because it shadows other names")
            self.reserve_name(arg_name, function_arg=True)
        self._args = all_args

    def as_ast(self, *, include_comments: bool = False) -> py_ast.stmt:
        """Return the ``FunctionDef`` node; raises AssertionError if the function name is not a valid identifier."""
        if not allowable_name(self.func_name):
            raise AssertionError(f"Expected '{self.func_name}' to be a valid Python identifier")
        arguments_node = _function_arg_list_to_ast_arguments(self._args)
        # An empty body is not allowed here, so it is rendered as `pass`.
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        decorator_nodes = [decorator.as_ast() for decorator in self.decorators]
        returns_node = None if self.return_type is None else self.return_type.as_ast()
        return py_ast.FunctionDef(
            name=self.func_name,
            args=arguments_node,
            body=body_nodes,
            decorator_list=decorator_nodes,
            type_params=[],
            returns=returns_node,
            **DEFAULT_AST_ARGS,
        )

    def create_return(self, value: ExpressionLike):
        """Append a ``return`` statement to the function body."""
        self.body.create_return(E_to_Expression(value))
def _function_arg_list_to_ast_arguments(args: Sequence[FunctionArg]):
    """Convert a list of :class:`FunctionArg` into an ``arguments`` AST node.

    Every argument name must pass ``allowable_name``, otherwise an
    :class:`AssertionError` is raised. ``*args`` / ``**kwargs`` style
    parameters are never produced (``vararg`` and ``kwarg`` are ``None``).
    """
    for candidate in args:
        if not allowable_name(candidate.name):
            raise AssertionError(f"Expected '{candidate.name}' to be a valid Python identifier")

    def _to_ast_arg(fa: FunctionArg) -> py_ast.arg:
        annotation_node = None if fa.annotation is None else fa.annotation.as_ast()
        return py_ast.arg(arg=fa.name, annotation=annotation_node, **DEFAULT_AST_ARGS)

    def _of_kind(kind: ArgKind) -> list[FunctionArg]:
        return [fa for fa in args if fa.kind == kind]

    pos_only_nodes = [_to_ast_arg(fa) for fa in _of_kind(ArgKind.POSITIONAL_ONLY)]
    pos_or_kw_nodes = [_to_ast_arg(fa) for fa in _of_kind(ArgKind.POSITIONAL_OR_KEYWORD)]
    kw_only_nodes = [_to_ast_arg(fa) for fa in _of_kind(ArgKind.KEYWORD_ONLY)]

    # Positional defaults are a single list shared by positional-only and
    # positional-or-keyword arguments, right-aligned against them.
    positional_kinds = (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)
    positional_defaults = [
        fa.default.as_ast() for fa in args if fa.kind in positional_kinds and fa.default is not None
    ]

    # Keyword-only defaults line up one-to-one with kwonlyargs; None marks "no default".
    keyword_defaults: list[py_ast.expr | None] = [
        None if fa.default is None else fa.default.as_ast() for fa in _of_kind(ArgKind.KEYWORD_ONLY)
    ]

    return py_ast.arguments(
        posonlyargs=pos_only_nodes,
        args=pos_or_kw_nodes,
        vararg=None,
        kwonlyargs=kw_only_nodes,
        kw_defaults=keyword_defaults,
        kwarg=None,
        defaults=positional_defaults,
        **DEFAULT_AST_ARGS_ARGUMENTS,
    )
[docs]
class Class(Scope, Statement):
    """A ``class`` statement whose instance is also the :class:`Scope` used by its body."""

    def __init__(
        self,
        name: str,
        parent_scope: Scope | None = None,
        bases: Sequence[Expression] | None = None,
        decorators: Sequence[Expression] | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        # The body block uses this class object as its scope.
        self.body = Block(self)
        self.class_name = name
        self.bases: list[Expression] = list(bases or [])
        self.decorators: list[Expression] = list(decorators or [])

    def as_ast(self, *, include_comments: bool = False) -> py_ast.stmt:
        """Build the ``ClassDef`` node; the class name must pass ``allowable_name``."""
        if not allowable_name(self.class_name):
            raise AssertionError(f"Expected '{self.class_name}' to be a valid Python identifier")
        base_nodes = [base.as_ast() for base in self.bases]
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        decorator_nodes = [decorator.as_ast() for decorator in self.decorators]
        return py_ast.ClassDef(
            name=self.class_name,
            bases=base_nodes,
            keywords=[],
            body=body_nodes,
            decorator_list=decorator_nodes,
            type_params=[],
            **DEFAULT_AST_ARGS,
        )
[docs]
class Return(Statement):
    """A ``return`` statement returning the given expression."""

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self, *, include_comments: bool = False):
        """Build the ``Return`` node wrapping the value's AST."""
        return py_ast.Return(self.value.as_ast(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        # The closing parenthesis was previously missing from the repr.
        return f"Return({self.value!r})"
[docs]
class Break(Statement):
    """The ``break`` statement."""

    def as_ast(self, *, include_comments: bool = False):
        """Build the ``Break`` node."""
        node = py_ast.Break(**DEFAULT_AST_ARGS)
        return node
[docs]
class Continue(Statement):
    """The ``continue`` statement."""

    def as_ast(self, *, include_comments: bool = False):
        """Build the ``Continue`` node."""
        node = py_ast.Continue(**DEFAULT_AST_ARGS)
        return node
[docs]
class Assert(Statement):
    """An ``assert`` statement; the failure message is optional."""

    def __init__(self, test: Expression, msg: Expression | None = None):
        self.test, self.msg = test, msg

    def as_ast(self, *, include_comments: bool = False):
        """Build the ``Assert`` node, with ``msg=None`` when no message was given."""
        message_node = None
        if self.msg is not None:
            message_node = self.msg.as_ast()
        test_node = self.test.as_ast()
        return py_ast.Assert(test=test_node, msg=message_node, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Assert({self.test!r}, {self.msg!r})"
[docs]
class Raise(Statement):
    """A ``raise`` statement.

    Three forms are supported:

    * ``raise exc``
    * ``raise exc from cause``
    * a bare ``raise`` that re-raises the current exception
    """

    def __init__(self, exc: Expression | None = None, cause: Expression | None = None):
        # ``raise from cause`` without an exception is not a valid form.
        if exc is None and cause is not None:
            raise AssertionError("Cannot use 'cause' without 'exc'")
        self.exc = exc
        self.cause = cause

    def as_ast(self, *, include_comments: bool = False):
        """Build the ``Raise`` node; absent parts are passed as ``None``."""
        exc_node = None if self.exc is None else self.exc.as_ast()
        cause_node = None if self.cause is None else self.cause.as_ast()
        return py_ast.Raise(exc=exc_node, cause=cause_node, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Raise({self.exc!r}, {self.cause!r})"
[docs]
class If(Statement):
    """A compound ``if``/``elif``/``else`` statement."""

    def __init__(self, parent_scope: Scope, parent_block: Block | None = None):
        # The statement is stored flat: parallel lists of branch blocks and
        # their conditions (if/elif/elif ...), plus one trailing else block.
        # Python's AST nests each ``elif`` inside the previous ``orelse``, so
        # ``as_ast`` has to rebuild that nesting.
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.if_blocks: list[Block] = []
        self.conditions: list[Expression] = []
        self.else_block = Block(parent_scope, parent_block=self._parent_block)

    def create_if_branch(self, condition: ExpressionLike) -> Block:
        """Append a new conditional branch and return the block for its body."""
        branch = Block(self._parent_scope, parent_block=self._parent_block)
        self.if_blocks.append(branch)
        self.conditions.append(E_to_Expression(condition))
        return branch

    def finalize(self) -> Block | Statement:
        """Return the else block when no branch was ever added, otherwise *self*."""
        if self.if_blocks:
            return self
        # No conditions at all, only the default case: callers can treat this
        # uniformly by using the unconditional else block in place of the If.
        return self.else_block

    def as_ast(self, *, include_comments: bool = False) -> py_ast.If:
        """Build the nested ``If`` node chain; requires at least one branch."""
        if not self.if_blocks:
            raise AssertionError("Should have called `finalize` on If")
        head = empty_If()
        node = head
        last_filled = None
        for test_expr, branch in zip(self.conditions, self.if_blocks):
            node.test = test_expr.as_ast()
            node.body = branch.as_ast_list(include_comments=include_comments)
            if last_filled is not None:
                # Each ``elif`` lives in the ``orelse`` of the branch before it.
                last_filled.orelse.append(node)
            last_filled, node = node, empty_If()
        if self.else_block.statements:
            assert last_filled is not None
            last_filled.orelse = self.else_block.as_ast_list(include_comments=include_comments)
        return head
[docs]
class With(Statement):
    """A ``with`` statement with a single context manager and optional ``as`` target."""

    def __init__(
        self,
        context_expr: Expression,
        target: Name | None = None,
        *,
        parent_scope: Scope,
        parent_block: Block | None = None,
    ):
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.context_expr = context_expr
        self.target = target
        self.body = Block(parent_scope, parent_block=parent_block)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.With:
        """Build the ``With`` node; the ``as`` target, if any, is a Store-context name."""
        if self.target is None:
            target_node = None
        else:
            target_node = py_ast.Name(id=self.target.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        item = py_ast.withitem(
            context_expr=self.context_expr.as_ast(),
            optional_vars=target_node,
        )
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        return py_ast.With(items=[item], body=body_nodes, **DEFAULT_AST_ARGS)
[docs]
class For(Statement):
    """A ``for`` loop; the ``else`` clause is emitted only if its block has content."""

    def __init__(
        self,
        target: Target,
        iterable: Expression,
        *,
        parent_scope: Scope,
        parent_block: Block | None = None,
    ):
        valid_target = is_target(target)
        if not valid_target:
            raise AssertionError("Invalid for-loop target")
        self.target: Target = target
        self.iterable = iterable
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        # Loop body and else clause share the same parent scope and block.
        self.body = Block(parent_scope, parent_block=parent_block)
        self.else_block = Block(parent_scope, parent_block=parent_block)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.For:
        """Build the ``For`` node; the body must not be empty, the else block may be."""
        target_node = _target_as_store_ast(self.target)
        iter_node = self.iterable.as_ast()
        body_nodes = self.body.as_ast_list(allow_empty=False, include_comments=include_comments)
        else_nodes = self.else_block.as_ast_list(allow_empty=True, include_comments=include_comments)
        return py_ast.For(
            target=target_node,
            iter=iter_node,
            body=body_nodes,
            orelse=else_nodes,
            **DEFAULT_AST_ARGS,
        )
[docs]
class Try(Statement):
    """A ``try``/``except``/``else``/``finally`` statement.

    Except clauses are added incrementally via :meth:`create_except`,
    similar to how :meth:`If.create_if_branch` works.
    """

    def __init__(
        self,
        parent_scope: Scope,
        *,
        parent_block: Block | None = None,
    ):
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.try_block = Block(parent_scope, parent_block=parent_block)
        # Parallel lists: one entry per ``except`` clause, in insertion order.
        self.except_blocks: list[Block] = []
        self.except_types: list[list[Expression]] = []
        self.except_names: list[str | None] = []
        self.else_block = Block(parent_scope, parent_block=parent_block)
        self.finally_block = Block(parent_scope, parent_block=parent_block)

    @overload
    def create_except(
        self,
        catch_exceptions: Sequence[ExpressionLike],
        *,
        name: str | Name,
    ) -> tuple[Block, Name]: ...

    @overload
    def create_except(self, catch_exceptions: Sequence[ExpressionLike]) -> Block: ...

    def create_except(
        self,
        catch_exceptions: Sequence[ExpressionLike],
        *,
        name: str | Name | None = None,
    ) -> Block | tuple[Block, Name]:
        """
        Add an ``except`` clause and return its body block.

        *catch_exceptions* is the list of exception types to catch
        (a single-element list produces ``except Foo:``, multiple
        elements produce ``except (Foo, Bar):``).

        *name*, if given, becomes the ``as`` target
        (``except Foo as name:``). If it is passed as a `str`
        it is reserved in the scope via ``create_name``. It is returned
        as a Name object, and the handler uses that object's ``name``.

        Returns the block alone when *name* is omitted, otherwise a
        ``(block, Name)`` tuple.
        """
        block = Block(self._parent_scope, parent_block=self._parent_block)
        self.except_blocks.append(block)
        self.except_types.append([E_to_Expression(e) for e in catch_exceptions])

        name_obj: Name | None
        if isinstance(name, str):
            name_obj = self._parent_scope.create_name(name)
        elif isinstance(name, Name):
            name_obj = name
        else:
            name_obj = None

        if name_obj is None:
            self.except_names.append(None)
            return block

        # Record the name carried by the Name object rather than the requested
        # string, so the ``as`` target always matches the returned Name even if
        # ``create_name`` handed back a different identifier than was requested.
        self.except_names.append(name_obj.name)
        return block, name_obj

    def _handler_type_ast(self, exceptions: list[Expression]) -> py_ast.expr:
        """Return the handler's type node: the sole exception, or a tuple of them."""
        if len(exceptions) == 1:
            return exceptions[0].as_ast()
        return py_ast.Tuple(
            elts=[e.as_ast() for e in exceptions],
            ctx=py_ast.Load(),
            **DEFAULT_AST_ARGS,
        )

    def as_ast(self, *, include_comments: bool = False) -> py_ast.Try:
        """Build the ``Try`` node; try and except bodies must be non-empty, else/finally may be empty."""
        return py_ast.Try(
            body=self.try_block.as_ast_list(allow_empty=False, include_comments=include_comments),
            handlers=[
                py_ast.ExceptHandler(
                    type=self._handler_type_ast(exc_types),
                    name=exc_name,
                    body=exc_block.as_ast_list(allow_empty=False, include_comments=include_comments),
                    **DEFAULT_AST_ARGS,
                )
                for exc_types, exc_name, exc_block in zip(self.except_types, self.except_names, self.except_blocks)
            ],
            orelse=self.else_block.as_ast_list(allow_empty=True, include_comments=include_comments),
            finalbody=self.finally_block.as_ast_list(allow_empty=True, include_comments=include_comments),
            **DEFAULT_AST_ARGS,
        )
[docs]
class Import(Statement):
    """
    Simple import statements, supporting:

    - import foo
    - import foo as bar
    - import foo.bar
    - import foo.bar as baz

    Use via `Block.create_import`

    We deliberately don't support multiple imports - these should
    be cleaned up later using a linter on the generated code, if desired.
    """

    def __init__(self, module: str, as_: Name | None) -> None:
        self.module = module
        self.as_ = as_

    def as_ast(self, *, include_comments: bool = False) -> py_ast.AST:
        """Build the ``Import`` node holding a single ``alias``."""
        # ``alias`` nodes carry location attributes (Python 3.10+), so pass the
        # same default location kwargs that the file's other located nodes get.
        # NOTE(review): assumes DEFAULT_AST_ARGS holds only location fields - confirm in ast_compat.
        if self.as_ is None:
            # No alias needed:
            alias = py_ast.alias(name=self.module, **DEFAULT_AST_ARGS)
        else:
            alias = py_ast.alias(name=self.module, asname=self.as_.name, **DEFAULT_AST_ARGS)
        return py_ast.Import(names=[alias], **DEFAULT_AST_ARGS)
[docs]
class ImportFrom(Statement):
    """
    ``from ... import`` statement, supporting:

    - ``from foo import bar``
    - ``from foo import bar as baz``

    Use via `Block.create_import_from`

    We deliberately don't support multiple imports - these should
    be cleaned up later using a linter on the generated code, if desired.
    """

    def __init__(self, from_module: str, import_: str, as_: Name | None) -> None:
        self.from_module = from_module
        self.import_ = import_
        self.as_ = as_

    def as_ast(self, *, include_comments: bool = False) -> py_ast.AST:
        """Build the absolute (``level=0``) ``ImportFrom`` node holding a single ``alias``."""
        # ``alias`` nodes carry location attributes (Python 3.10+), so pass the
        # same default location kwargs that the file's other located nodes get.
        # NOTE(review): assumes DEFAULT_AST_ARGS holds only location fields - confirm in ast_compat.
        if self.as_ is None:
            # No alias needed:
            alias = py_ast.alias(name=self.import_, **DEFAULT_AST_ARGS)
        else:
            alias = py_ast.alias(name=self.import_, asname=self.as_.name, **DEFAULT_AST_ARGS)
        return py_ast.ImportFrom(
            module=self.from_module,
            names=[alias],
            level=0,
            **DEFAULT_AST_ARGS,
        )
[docs]
class Expression(CodeGenAst):
    """Abstract base for code-generation nodes that stand for a Python expression.

    Apart from the abstract :meth:`as_ast`, the methods here are chaining
    helpers: each wraps ``self`` (and, for binary operations, the converted
    other operand) in the corresponding node class.
    """

    @abstractmethod
    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr: ...

    # Conversion:

    @classmethod
    def from_e(cls, obj: Expression | E) -> Expression:
        """Convert *obj* using :func:`E_to_Expression`."""
        converted = E_to_Expression(obj)
        return converted

    # Some utilities for easy chaining:

    def attr(self, attribute: str, /) -> Attr:
        """Build an :class:`Attr` that reads *attribute* from this expression."""
        return Attr(self, attribute)

    def call(
        self,
        args: Sequence[ExpressionLike] | None = None,
        kwargs: Mapping[str, ExpressionLike] | None = None,
    ) -> Call:
        """Build a :class:`Call` that invokes this expression with the given arguments."""
        positional = [] if args is None else [E_to_Expression(arg) for arg in args]
        keyword = {} if kwargs is None else {k: E_to_Expression(val) for k, val in kwargs.items()}
        return Call(self, positional, keyword)

    def method_call(
        self,
        attribute: str,
        args: Sequence[ExpressionLike] | None = None,
        kwargs: Mapping[str, ExpressionLike] | None = None,
    ) -> Call:
        """Build a :class:`Call` of the method *attribute* on this expression."""
        bound_method = self.attr(attribute)
        return bound_method.call(args, kwargs)

    def subscript(self, slice: ExpressionLike, /) -> Subscript:
        """Build a :class:`Subscript` that indexes this expression with *slice*."""
        index = E_to_Expression(slice)
        return Subscript(self, index)

    # Arithmetic operators

    def add(self, other: ExpressionLike, /) -> Add:
        """Build an :class:`Add` node (``self + other``)."""
        rhs = E_to_Expression(other)
        return Add(self, rhs)

    def sub(self, other: ExpressionLike, /) -> Sub:
        """Build a :class:`Sub` node (``self - other``)."""
        rhs = E_to_Expression(other)
        return Sub(self, rhs)

    def mul(self, other: ExpressionLike, /) -> Mul:
        """Build a :class:`Mul` node (``self * other``)."""
        rhs = E_to_Expression(other)
        return Mul(self, rhs)

    def div(self, other: ExpressionLike, /) -> Div:
        """Build a :class:`Div` node (``self / other``)."""
        rhs = E_to_Expression(other)
        return Div(self, rhs)

    def floordiv(self, other: ExpressionLike, /) -> FloorDiv:
        """Build a :class:`FloorDiv` node (``self // other``)."""
        rhs = E_to_Expression(other)
        return FloorDiv(self, rhs)

    def mod(self, other: ExpressionLike, /) -> Mod:
        """Build a :class:`Mod` node (``self % other``)."""
        rhs = E_to_Expression(other)
        return Mod(self, rhs)

    def pow(self, other: ExpressionLike, /) -> Pow:
        """Build a :class:`Pow` node (``self ** other``)."""
        rhs = E_to_Expression(other)
        return Pow(self, rhs)

    def matmul(self, other: ExpressionLike, /) -> MatMul:
        """Build a :class:`MatMul` node (``self @ other``)."""
        rhs = E_to_Expression(other)
        return MatMul(self, rhs)

    # Bitwise operators

    def bitand(self, other: ExpressionLike, /) -> BitAnd:
        """Build a :class:`BitAnd` node (``self & other``)."""
        rhs = E_to_Expression(other)
        return BitAnd(self, rhs)

    def bitor(self, other: ExpressionLike, /) -> BitOr:
        """Build a :class:`BitOr` node (``self | other``)."""
        rhs = E_to_Expression(other)
        return BitOr(self, rhs)

    def xor(self, other: ExpressionLike, /) -> BitXor:
        """Build a :class:`BitXor` node (``self ^ other``)."""
        rhs = E_to_Expression(other)
        return BitXor(self, rhs)

    def lshift(self, other: ExpressionLike, /) -> LShift:
        """Build an :class:`LShift` node (``self << other``)."""
        rhs = E_to_Expression(other)
        return LShift(self, rhs)

    def rshift(self, other: ExpressionLike, /) -> RShift:
        """Build an :class:`RShift` node (``self >> other``)."""
        rhs = E_to_Expression(other)
        return RShift(self, rhs)

    def invert(self) -> Invert:
        """Build an :class:`Invert` node (``~self``)."""
        return Invert(self)

    # Comparison operators

    def eq(self, other: ExpressionLike, /) -> Equals:
        """Build an :class:`Equals` node (``self == other``)."""
        rhs = E_to_Expression(other)
        return Equals(self, rhs)

    def ne(self, other: ExpressionLike, /) -> NotEquals:
        """Build a :class:`NotEquals` node (``self != other``)."""
        rhs = E_to_Expression(other)
        return NotEquals(self, rhs)

    def lt(self, other: ExpressionLike, /) -> Lt:
        """Build an :class:`Lt` node (``self < other``)."""
        rhs = E_to_Expression(other)
        return Lt(self, rhs)

    def gt(self, other: ExpressionLike, /) -> Gt:
        """Build a :class:`Gt` node (``self > other``)."""
        rhs = E_to_Expression(other)
        return Gt(self, rhs)

    def le(self, other: ExpressionLike, /) -> LtE:
        """Build an :class:`LtE` node (``self <= other``)."""
        rhs = E_to_Expression(other)
        return LtE(self, rhs)

    def ge(self, other: ExpressionLike, /) -> GtE:
        """Build a :class:`GtE` node (``self >= other``)."""
        rhs = E_to_Expression(other)
        return GtE(self, rhs)

    # Boolean operators

    def and_(self, other: ExpressionLike, /) -> And:
        """Build an :class:`And` node (``self and other``)."""
        rhs = E_to_Expression(other)
        return And(self, rhs)

    def or_(self, other: ExpressionLike, /) -> Or:
        """Build an :class:`Or` node (``self or other``)."""
        rhs = E_to_Expression(other)
        return Or(self, rhs)

    # Membership operators

    def in_(self, other: ExpressionLike, /) -> In:
        """Build an :class:`In` node (``self in other``)."""
        container = E_to_Expression(other)
        return In(self, container)

    def not_in(self, other: ExpressionLike, /) -> NotIn:
        """Build a :class:`NotIn` node (``self not in other``)."""
        container = E_to_Expression(other)
        return NotIn(self, container)

    # Identity operators

    def is_(self, other: ExpressionLike, /) -> Is:
        """Build an :class:`Is` node (``self is other``)."""
        rhs = E_to_Expression(other)
        return Is(self, rhs)

    def is_not(self, other: ExpressionLike, /) -> IsNot:
        """Build an :class:`IsNot` node (``self is not other``)."""
        rhs = E_to_Expression(other)
        return IsNot(self, rhs)

    # Unary operators

    def not_(self) -> Not:
        """Build a :class:`Not` node (``not self``)."""
        return Not(self)

    def pos(self) -> UAdd:
        """Build a :class:`UAdd` node (``+self``)."""
        return UAdd(self)

    def neg(self) -> USub:
        """Build a :class:`USub` node (``-self``)."""
        return USub(self)

    # Unpacking

    def starred(self) -> Starred:
        """Build a :class:`Starred` node (``*self``)."""
        return Starred(self)

    # Walrus / NamedExpr

    def named(self, name: Name) -> NamedExpr:
        """
        Build a :class:`NamedExpr` (walrus operator) binding this expression to *name*.
        """
        return NamedExpr(name=name, value=self)

    # E-object:

    @cached_property
    def e(self) -> E:
        """
        This expression wrapped in an E-object, for easier expression generation.

        The wrapper is created on first access and cached on the instance.
        """
        return E(self)
[docs]
class String(Expression):
    """A ``str`` literal."""

    def __init__(self, string_value: str):
        self.string_value = string_value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``Constant`` node holding the string."""
        literal = self.string_value
        return py_ast.Constant(
            literal,
            kind=None,  # 3.8, indicates no prefix, needed only for tests
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        return f"String({self.string_value!r})"

    def __eq__(self, other: object):
        # Equal only to another String carrying the same text.
        if not isinstance(other, String):
            return False
        return other.string_value == self.string_value
[docs]
class Bool(Expression):
    """A ``True`` / ``False`` literal."""

    def __init__(self, value: bool):
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``Constant`` node holding the boolean."""
        literal = self.value
        return py_ast.Constant(literal, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Bool({repr(self.value)})"
[docs]
class Bytes(Expression):
    """A ``bytes`` literal."""

    def __init__(self, value: bytes):
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``Constant`` node holding the bytes value."""
        literal = self.value
        return py_ast.Constant(literal, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Bytes({repr(self.value)})"
[docs]
class Number(Expression):
    """An ``int`` or ``float`` literal."""

    def __init__(self, number: int | float):
        self.number = number

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``Constant`` node holding the number."""
        literal = self.number
        return py_ast.Constant(literal, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Number({self.number!r})"
[docs]
class List(Expression):
    """A ``[...]`` list display."""

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a Load-context ``List`` node from the items, in order."""
        element_nodes = [item.as_ast() for item in self.items]
        return py_ast.List(elts=element_nodes, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
[docs]
class Tuple(Expression):
    """A ``(...)`` tuple display."""

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a Load-context ``Tuple`` node from the items, in order."""
        element_nodes = [item.as_ast() for item in self.items]
        return py_ast.Tuple(elts=element_nodes, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
[docs]
class Set(Expression):
    """A ``{...}`` set display; the empty set is emitted as ``set([])``."""

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``Set`` node, or a ``set([])`` call when there are no items."""
        if len(self.items) > 0:
            return py_ast.Set(elts=[item.as_ast() for item in self.items], **DEFAULT_AST_ARGS)
        # {} is a dict literal in Python, so empty sets must use set([])
        set_builtin = py_ast.Name(id="set", ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        empty_list = py_ast.List(elts=[], ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        return py_ast.Call(
            func=set_builtin,
            args=[empty_list],
            keywords=[],
            **DEFAULT_AST_ARGS,
        )
[docs]
class Dict(Expression):
    """A ``{key: value, ...}`` dict display built from (key, value) pairs."""

    def __init__(self, pairs: Sequence[tuple[Expression, Expression]]):
        self.pairs = pairs

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``Dict`` node; keys and values keep the order of the pairs."""
        key_nodes = [key.as_ast() for key, _ in self.pairs]
        value_nodes = [value.as_ast() for _, value in self.pairs]
        return py_ast.Dict(keys=key_nodes, values=value_nodes, **DEFAULT_AST_ARGS)
[docs]
class StringJoinBase(Expression):
    """Common base for expressions that join several parts into one string."""

    def __init__(self, parts: Sequence[Expression]):
        self.parts = parts

    def __repr__(self):
        rendered_parts = ", ".join(repr(p) for p in self.parts)
        return f"{self.__class__.__name__}([{rendered_parts}])"

    @classmethod
    def build(cls: type[StringJoinBase], parts: Sequence[Expression]) -> StringJoinBase | Expression:
        """
        Create a join of *parts*, collapsing to something simpler where possible.

        Neighbouring :class:`String` parts are merged into one. Zero remaining
        parts give an empty :class:`String`; a single remaining part is
        returned as-is; otherwise an instance of *cls* is returned.
        """
        merged: list[Expression] = []
        for part in parts:
            previous = merged[-1] if merged else None
            if isinstance(previous, String) and isinstance(part, String):
                # Fold adjacent literals together.
                merged[-1] = String(previous.string_value + part.string_value)
            else:
                merged.append(part)

        # See if the join can be dropped altogether.
        if not merged:
            return String("")
        if len(merged) == 1:
            return merged[0]
        return cls(merged)
[docs]
class FStringJoin(StringJoinBase):
    """Joins the parts as a single f-string."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a ``JoinedStr`` node: literals stay as-is, other parts become ``{...}`` fields."""

        def _segment(part: Expression) -> py_ast.expr:
            if isinstance(part, String):
                return part.as_ast()
            # Non-literal parts are interpolated with no conversion and no format spec.
            return py_ast.FormattedValue(
                value=part.as_ast(),
                conversion=-1,
                format_spec=None,
                **DEFAULT_AST_ARGS,
            )

        segments: list[py_ast.expr] = [_segment(part) for part in self.parts]
        return py_ast.JoinedStr(values=segments, **DEFAULT_AST_ARGS)
[docs]
class ConcatJoin(StringJoinBase):
    """Join string parts using ``+`` concatenation."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a left-nested chain of ``BinOp(+)`` nodes over the parts.

        With no parts at all an empty string constant is returned, so a
        directly constructed ``ConcatJoin([])`` no longer fails with
        ``IndexError``.
        """
        if len(self.parts) == 0:
            # Joining nothing yields the empty string, mirroring ``build()``.
            return py_ast.Constant("", kind=None, **DEFAULT_AST_ARGS)

        # Concatenate with +
        left = self.parts[0].as_ast()
        for part in self.parts[1:]:
            right = part.as_ast()
            left = py_ast.BinOp(
                left=left,
                op=py_ast.Add(**DEFAULT_AST_ARGS_ADD),
                right=right,
                **DEFAULT_AST_ARGS,
            )
        return left
# Default string-join implementation. For CPython, f-strings give a measurable
# improvement over ``+`` concatenation, so ``StringJoin`` is an alias for
# ``FStringJoin``.
StringJoin = FStringJoin
[docs]
class Name(Expression):
    """A reference to a variable; the name must already be in use in the given :class:`Scope`."""

    def __init__(self, name: str, scope: Scope):
        known = scope.is_name_in_use(name)
        if not known:
            raise AssertionError(f"Cannot refer to undefined name '{name}'")
        self.name = name

    def as_ast(self, *, include_comments: bool = False) -> py_ast.Name:
        """Build a Load-context ``Name`` node; builtins are accepted as identifiers."""
        identifier = self.name
        if not allowable_name(identifier, allow_builtin=True):
            raise AssertionError(f"Expected {identifier} to be a valid Python identifier")
        return py_ast.Name(id=identifier, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)

    def __eq__(self, other: object):
        # Exact type match is required, not just isinstance.
        if type(other) is not type(self):
            return False
        return other.name == self.name

    def __repr__(self):
        return f"Name({self.name!r})"
[docs]
class Attr(Expression):
    """Attribute access on an expression, e.g. ``obj.attr``."""

    def __init__(self, value: Expression, attribute: str) -> None:
        self.value = value
        # The attribute name is validated eagerly; builtins are accepted.
        valid = allowable_name(attribute, allow_builtin=True)
        if not valid:
            raise AssertionError(f"Expected {attribute} to be a valid Python identifier")
        self.attribute = attribute

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``Attribute`` node."""
        owner_node = self.value.as_ast()
        return py_ast.Attribute(value=owner_node, attr=self.attribute, **DEFAULT_AST_ARGS)
[docs]
class Starred(Expression):
    """An unpacking expression, e.g. ``*args``."""

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a Load-context ``Starred`` node around the wrapped expression."""
        inner_node = self.value.as_ast()
        return py_ast.Starred(value=inner_node, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Starred({repr(self.value)})"
class NamedExpr(Expression):
    """
    A named expression (walrus operator) e.g. `x := 1`
    """

    def __init__(self, name: Name, value: Expression) -> None:
        self.name = name
        self.value = value

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``NamedExpr`` node with a Store-context target."""
        # Name.as_ast() validates the identifier but yields a Load-context
        # node. The walrus target is being assigned to, so it needs Store
        # context, as the ``with`` and ``for`` targets in this module have.
        target = self.name.as_ast()
        target.ctx = py_ast.Store()
        return py_ast.NamedExpr(target=target, value=self.value.as_ast(), **DEFAULT_AST_ARGS)
[docs]
def function_call(
    function_name: str,
    args: Sequence[Expression],
    kwargs: Mapping[str, Expression],
    scope: Scope,
) -> Expression:
    """Build a call to the function named *function_name*.

    The name must be in use in *scope* and must not be listed in
    ``SENSITIVE_FUNCTIONS``; either violation raises :class:`AssertionError`.
    """
    name_known = scope.is_name_in_use(function_name)
    if not name_known:
        raise AssertionError(f"Cannot call unknown function '{function_name}'")
    if function_name in SENSITIVE_FUNCTIONS:
        raise AssertionError(f"Disallowing call to '{function_name}'")
    callee = Name(name=function_name, scope=scope)
    return callee.call(args, kwargs)
[docs]
class Call(Expression):
    """A call of a function or method, with positional and keyword arguments."""

    def __init__(
        self,
        value: Expression,
        args: Sequence[Expression],
        kwargs: Mapping[str, Expression],
    ):
        self.value = value
        # Copy the containers so the node owns its own list and dict.
        self.args = list(args)
        self.kwargs = dict(kwargs)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``Call`` node.

        Every keyword name must pass ``allowable_keyword_arg_name``. If all of
        them are also valid Python identifiers the usual ``f(name=value)`` form
        is produced; otherwise all keywords are passed as ``f(**{...})``.
        """
        for kwarg_name in self.kwargs:
            if not allowable_keyword_arg_name(kwarg_name):
                raise AssertionError(f"Expected {kwarg_name} to be a valid Fluent NamedArgument name")

        if all(allowable_name(kwarg_name) for kwarg_name in self.kwargs):
            # Normal `my_function(foo=bar)` syntax
            return py_ast.Call(
                func=self.value.as_ast(),
                args=[positional.as_ast() for positional in self.args],
                keywords=[
                    py_ast.keyword(arg=kwarg_name, value=kwarg_value.as_ast(), **DEFAULT_AST_ARGS)
                    for kwarg_name, kwarg_value in self.kwargs.items()
                ],
                **DEFAULT_AST_ARGS,
            )

        # At least one keyword (e.g. 'foo-bar') is allowed in languages like
        # Fluent but is not a Python identifier, so fall back to
        # `my_function(**{'foo-bar': baz})`.
        # (For exec-ing the AST alone this would not be needed, because only
        # the Python parser rejects `foo-bar` and building AST directly
        # bypasses it. The general case is handled so that the result can be
        # decompiled to valid Python.)
        ordered_pairs = sorted(self.kwargs.items())
        return py_ast.Call(
            func=self.value.as_ast(),
            args=[positional.as_ast() for positional in self.args],
            keywords=[
                py_ast.keyword(
                    arg=None,
                    value=py_ast.Dict(
                        keys=[py_ast.Constant(key, kind=None, **DEFAULT_AST_ARGS) for key, _ in ordered_pairs],
                        values=[kwarg_value.as_ast() for _, kwarg_value in ordered_pairs],
                        **DEFAULT_AST_ARGS,
                    ),
                    **DEFAULT_AST_ARGS,
                )
            ],
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        return f"Call({self.value!r}, {self.args}, {self.kwargs})"
[docs]
def method_call(
    obj: Expression,
    method_name: str,
    args: Sequence[Expression],
    kwargs: Mapping[str, Expression],
):
    """Build a call of the method *method_name* on *obj* with the given arguments."""
    bound_method = obj.attr(method_name)
    return bound_method.call(args=args, kwargs=kwargs)
[docs]
class Subscript(Expression):
    """An indexing expression, e.g. ``obj[key]``."""

    def __init__(self, value: Expression, slice: Expression):
        self.value, self.slice = value, slice

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build a Load-context ``Subscript`` node."""
        container_node = self.value.as_ast()
        index_node = py_ast.subscript_slice_object(self.slice.as_ast())
        return py_ast.Subscript(
            value=container_node,
            slice=index_node,
            ctx=py_ast.Load(),
            **DEFAULT_AST_ARGS,
        )
[docs]
class Slice(Expression):
    """A slice expression (e.g. ``0:10``, ``::2``, ``1:-1``).

    Intended as the *slice* argument of :class:`Subscript`.

    Each of *start*, *stop* and *step* may be omitted.
    """

    def __init__(
        self,
        start: Expression | None = None,
        stop: Expression | None = None,
        step: Expression | None = None,
    ):
        self.start, self.stop, self.step = start, stop, step

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``Slice`` node; omitted components are passed as ``None``."""

        def _optional(component: Expression | None) -> py_ast.expr | None:
            return None if component is None else component.as_ast()

        return py_ast.Slice(
            lower=_optional(self.start),
            upper=_optional(self.stop),
            step=_optional(self.step),
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        # Only the components that were given are shown.
        labelled = (("start", self.start), ("stop", self.stop), ("step", self.step))
        shown: list[str] = [f"{label}={component!r}" for label, component in labelled if component is not None]
        return f"Slice({', '.join(shown)})"
[docs]
class Lambda(Expression):
    """
    A lambda expression e.g. `lambda x: x + 1`

    *args* may mix plain names and :class:`FunctionArg` objects. They are
    normalized and then checked with the same ordering validation that
    :meth:`Function.add_args` applies, so a badly ordered argument list is
    rejected instead of being emitted with mis-aligned defaults.
    """

    def __init__(self, args: Sequence[str | FunctionArg], body: Expression) -> None:
        self.args = _normalize_args(args)
        # Positional defaults are right-aligned when converted to AST, so a
        # non-default argument following a default one would silently move the
        # default onto the wrong parameter. Reject such input up front.
        _validate_arg_order(self.args)
        self.body = body
        # A Lambda is a small Scope, with fixed arguments.
        scope = Scope()
        for arg in self.args:
            scope.reserve_name(arg.name)
        self.enames = Enames(scope=scope)

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``Lambda`` node from the argument list and body expression."""
        return py_ast.Lambda(
            args=_function_arg_list_to_ast_arguments(self.args),
            body=self.body.as_ast(),
            **DEFAULT_AST_ARGS,
        )
[docs]
class Comprehension:
    """One ``for <target> in <iter>`` clause of a comprehension."""

    def __init__(self, target: Target, iter: Expression):
        # Plain value holder: both fields are stored unchanged.
        self.target, self.iter = target, iter
class _EltComprehensionBase(Expression):
    """Base class for comprehensions with a single element expression (list, set, generator).

    *generators* are the ``for`` clauses, outermost first. *ifs* are the
    filter conditions of the comprehension as a whole.
    """

    def __init__(self, element: Expression, generators: Sequence[Comprehension], ifs: Sequence[Expression]) -> None:
        self.element = element
        self.generators = generators
        self.ifs = ifs

    def _make_generators(self) -> list[py_ast.comprehension]:
        """Build one ``comprehension`` node per ``for`` clause.

        The conditions are attached to the last (innermost) clause only.
        There every loop target is bound, and each condition is evaluated
        once per element instead of being repeated on every clause. With a
        single generator the output is unchanged.
        """
        innermost = len(self.generators) - 1
        return [
            py_ast.comprehension(
                target=_target_as_store_ast(comp.target),
                iter=comp.iter.as_ast(),
                ifs=[if_.as_ast() for if_ in self.ifs] if index == innermost else [],
                is_async=False,
            )
            for index, comp in enumerate(self.generators)
        ]
[docs]
class ListComp(_EltComprehensionBase):
    """A list comprehension expression, e.g. ``[x + 1 for x in items]``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``ListComp`` node."""
        # Pass the default location kwargs like every other expression node in
        # this module does.
        # NOTE(review): presumably DEFAULT_AST_ARGS carries lineno/col_offset - confirm in ast_compat.
        return py_ast.ListComp(
            elt=self.element.as_ast(),
            generators=self._make_generators(),
            **DEFAULT_AST_ARGS,
        )
[docs]
class SetComp(_EltComprehensionBase):
    """A set comprehension expression, e.g. ``{x + 1 for x in items}``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``SetComp`` node."""
        # Pass the default location kwargs like every other expression node in
        # this module does.
        # NOTE(review): presumably DEFAULT_AST_ARGS carries lineno/col_offset - confirm in ast_compat.
        return py_ast.SetComp(
            elt=self.element.as_ast(),
            generators=self._make_generators(),
            **DEFAULT_AST_ARGS,
        )
[docs]
class GeneratorExpr(_EltComprehensionBase):
    """A generator expression, e.g. ``(x + 1 for x in items)``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``GeneratorExp`` node."""
        # Pass the default location kwargs like every other expression node in
        # this module does.
        # NOTE(review): presumably DEFAULT_AST_ARGS carries lineno/col_offset - confirm in ast_compat.
        return py_ast.GeneratorExp(
            elt=self.element.as_ast(),
            generators=self._make_generators(),
            **DEFAULT_AST_ARGS,
        )
[docs]
class DictComp(Expression):
    """A dict comprehension expression, e.g. ``{k: v for k, v in items}``.

    *generators* are the ``for`` clauses, outermost first. *ifs* are the
    filter conditions of the comprehension as a whole.
    """

    def __init__(
        self, key: Expression, value: Expression, generators: Sequence[Comprehension], ifs: Sequence[Expression]
    ) -> None:
        self.key = key
        self.value = value
        self.generators = generators
        self.ifs = ifs

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Build the ``DictComp`` node.

        The conditions are attached to the last (innermost) ``for`` clause
        only. There every loop target is bound, and each condition is
        evaluated once per element instead of being repeated on every clause.
        With a single generator the output is unchanged.
        """
        innermost = len(self.generators) - 1
        # Pass the default location kwargs like every other expression node in
        # this module does.
        # NOTE(review): presumably DEFAULT_AST_ARGS carries lineno/col_offset - confirm in ast_compat.
        return py_ast.DictComp(
            key=self.key.as_ast(),
            value=self.value.as_ast(),
            generators=[
                py_ast.comprehension(
                    target=_target_as_store_ast(comp.target),
                    iter=comp.iter.as_ast(),
                    ifs=[if_.as_ast() for if_ in self.ifs] if index == innermost else [],
                    is_async=False,
                )
                for index, comp in enumerate(self.generators)
            ],
            **DEFAULT_AST_ARGS,
        )
[docs]
def list_comprehension(
    *, iterable: ExpressionLike, target: Target, element: ExpressionLike, condition: ExpressionLike | None = None
) -> ListComp:
    """Build a :class:`ListComp` with a single ``for`` clause and an optional filter.

    Example::

        data = auto([1, 2, 3])
        list_comprehension(
            iterable=data,
            target=(loop_var := mod.scope.create_name("item")),
            element=loop_var.e + 1,
            condition=loop_var.e > 0
        )

    This produces: ``[item + 1 for item in [1, 2, 3] if item > 0]``
    """
    clause = Comprehension(target, E_to_Expression(iterable))
    element_expr = E_to_Expression(element)
    # No condition means no ``if`` part at all.
    filters = [] if condition is None else [E_to_Expression(condition)]
    return ListComp(element_expr, [clause], ifs=filters)
[docs]
def set_comprehension(
    *, iterable: ExpressionLike, target: Target, element: ExpressionLike, condition: ExpressionLike | None = None
) -> SetComp:
    """Build a :class:`SetComp` with a single ``for`` clause and an optional filter.

    Example::

        data = auto([1, 2, 3])
        set_comprehension(
            iterable=data,
            target=(loop_var := mod.scope.create_name("item")),
            element=loop_var.e + 1,
            condition=loop_var.e > 0
        )

    This produces: ``{item + 1 for item in [1, 2, 3] if item > 0}``
    """
    clause = Comprehension(target, E_to_Expression(iterable))
    element_expr = E_to_Expression(element)
    # No condition means no ``if`` part at all.
    filters = [] if condition is None else [E_to_Expression(condition)]
    return SetComp(element_expr, [clause], ifs=filters)
[docs]
def generator_expression(
    *, iterable: ExpressionLike, target: Target, element: ExpressionLike, condition: ExpressionLike | None = None
) -> GeneratorExpr:
    """Build a :class:`GeneratorExpr` with a single ``for`` clause and an optional filter.

    Example::

        data = auto([1, 2, 3])
        my_func = mod.scope.create_name("my_func")
        my_func.e(generator_expression(
            iterable=data,
            target=(loop_var := mod.scope.create_name("item")),
            element=loop_var.e + 1,
            condition=loop_var.e > 0
        ))

    This produces: ``my_func((item + 1 for item in [1, 2, 3] if item > 0))``
    """
    clause = Comprehension(target, E_to_Expression(iterable))
    element_expr = E_to_Expression(element)
    # No condition means no ``if`` part at all.
    filters = [] if condition is None else [E_to_Expression(condition)]
    return GeneratorExpr(element_expr, [clause], ifs=filters)
[docs]
def dict_comprehension(
    *,
    iterable: ExpressionLike,
    target: Target,
    key: ExpressionLike,
    value: ExpressionLike,
    condition: ExpressionLike | None = None,
) -> DictComp:
    """Build a :class:`DictComp` with a single ``for`` clause and an optional filter.

    Example::

        dict_comprehension(
            iterable=items,
            target=(
                (key_var := mod.scope.create_name("k")),
                (value_var := mod.scope.create_name("v")),
            ),
            key=key_var.e + "_x",
            value=value_var.e + 1,
        )

    This produces: ``{k + '_x': v + 1 for k, v in items}``
    """
    clause = Comprehension(target, E_to_Expression(iterable))
    key_expr = E_to_Expression(key)
    value_expr = E_to_Expression(value)
    # No condition means no ``if`` part at all.
    filters = [] if condition is None else [E_to_Expression(condition)]
    return DictComp(key_expr, value_expr, [clause], ifs=filters)
[docs]
def create_lambda(
    args: Sequence[str | FunctionArg], body: ExpressionLike | Callable[[Lambda], ExpressionLike]
) -> Lambda:
    """
    Create a lambda expression.

    The body can be supplied by either an expression (an ``Expression`` or an
    E-object), or a callable that will be called with a `Lambda` object as its
    only argument. This makes it possible to access the `enames` object on the
    `Lambda`::

        create_lambda(['x'], lambda self: self.enames.x + 1)

    Produces: ``lambda x: x + 1``

    :param args: the lambda's parameters.
    :param body: the body expression, or a factory that receives a temporary
        `Lambda` (built with the same *args*) and returns the body.
    :returns: a new `Lambda` with the resolved body.
    """
    # E-objects define ``__call__`` (to build call expressions), so
    # ``callable()`` alone cannot tell a body factory from an expression.
    # Check for expression types first, otherwise an E-object body would be
    # invoked as if it were a factory.
    if isinstance(body, (E, Expression)) or not callable(body):
        body_expr = E_to_Expression(body)
    else:
        # The temporary Lambda exists only to give the factory access to
        # ``enames`` for the declared arguments.
        temp_lambda = Lambda(args, body=constants.None_)
        body_expr = E_to_Expression(body(temp_lambda))
    return Lambda(
        args=args,
        body=body_expr,
    )
[docs]
def named(name: Name, value: ExpressionLike) -> NamedExpr:
    """
    Build a :class:`NamedExpr` binding *name* to *value*.

    *value* may be an Expression or an E-object::

        x = mod.scope.create_name("x")
        value = codegen.auto(1).e + 1
        named_val = codegen.named(x, value)

    Produces::

        (x := 1 + 1)
    """
    value_expr = E_to_Expression(value)
    return NamedExpr(name=name, value=value_expr)
#: Type alias for valid assignment target expressions.
#: A :class:`Name`, :class:`Attr`, or :class:`Subscript` expression,
#: or a tuple of targets (for unpacking assignments).
#: The alias is recursive: tuple elements are themselves targets, so nested
#: unpacking is expressible.
type Target = Name | Attr | Subscript | tuple[Target, ...]
def is_target(value: object) -> TypeIs[Target]:
    """Tell whether *value* is a valid assignment :data:`Target`.

    A single :class:`Name`, :class:`Attr` or :class:`Subscript` qualifies, as
    does a tuple whose elements all qualify (checked recursively).
    """
    is_simple = isinstance(value, (Name, Attr, Subscript))
    if is_simple or not isinstance(value, tuple):
        return is_simple
    return all(is_target(element) for element in value)  # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType]
def _target_as_store_ast(target: Target) -> py_ast.expr:
    """Produce the AST node for *target* with its context set to ``Store``.

    Tuples are converted recursively into a ``py_ast.Tuple``. Any other
    target must render as a Name, Attribute or Subscript node, otherwise an
    ``AssertionError`` is raised.
    """
    if not isinstance(target, tuple):
        node = target.as_ast()
        if not isinstance(node, (py_ast.Name, py_ast.Attribute, py_ast.Subscript)):
            raise AssertionError(f"Unexpected AST node type for target: {type(node)}")  # pragma: no cover
        node.ctx = py_ast.Store()
        return node

    elements = [_target_as_store_ast(t) for t in target]
    return py_ast.Tuple(elts=elements, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
def _aug_target_as_store_ast(target: AugAssignTarget) -> py_ast.Name | py_ast.Attribute | py_ast.Subscript:
    """Produce the AST node for an augmented-assignment target in ``Store`` context.

    Raises ``AssertionError`` if the target does not render as a Name,
    Attribute or Subscript node.
    """
    node = target.as_ast()
    if not isinstance(node, (py_ast.Name, py_ast.Attribute, py_ast.Subscript)):
        raise AssertionError(f"Unexpected AST node type for target: {type(node)}")  # pragma: no cover
    node.ctx = py_ast.Store()
    return node
def _target_names(target: Target) -> list[str]:
    """Gather the identifiers of every plain :class:`Name` inside *target*.

    Tuples are walked recursively in order; other target kinds contribute
    no names.
    """
    if isinstance(target, Name):
        return [target.name]
    if not isinstance(target, tuple):
        return []
    return [name for element in target for name in _target_names(element)]
# Alias: ``create_class_instance`` is the very same callable as
# ``function_call`` (instantiating a class is emitted as an ordinary call).
create_class_instance = function_call
[docs]
class NoneExpr(Expression):
    """Expression representing the literal ``None``."""

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.Constant`` whose value is ``None``."""
        none_node = py_ast.Constant(value=None, **DEFAULT_AST_ARGS)
        return none_node
[docs]
class BinaryOperator(Expression):
    """Common base for operator expressions that take a left and a right operand.

    :param left: the left-hand operand expression.
    :param right: the right-hand operand expression.
    """

    def __init__(self, left: Expression, right: Expression):
        self.left, self.right = left, right
[docs]
class ArithOp(BinaryOperator, ABC):
    """Binary operator rendered as a ``py_ast.BinOp`` node.

    Subclasses set ``op`` to the operator node class to instantiate.
    """

    op: ClassVar[type[py_ast.operator]]

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return the ``py_ast.BinOp`` combining both operands with ``op``."""
        lhs = self.left.as_ast()
        operator_node = self.op(**DEFAULT_AST_ARGS_ADD)
        rhs = self.right.as_ast()
        return py_ast.BinOp(left=lhs, op=operator_node, right=rhs, **DEFAULT_AST_ARGS)
[docs]
class Add(ArithOp):
    """Addition (``+``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.Add
[docs]
class Sub(ArithOp):
    """Subtraction (``-``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.Sub
[docs]
class Mul(ArithOp):
    """Multiplication (``*``) operator."""

    # Operator node class instantiated by ArithOp.as_ast(); note the AST
    # class is spelled ``Mult``.
    op = py_ast.Mult
[docs]
class Div(ArithOp):
    """True division (``/``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.Div
[docs]
class FloorDiv(ArithOp):
    """Floor division (``//``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.FloorDiv
[docs]
class Mod(ArithOp):
    """Modulo (``%``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.Mod
[docs]
class Pow(ArithOp):
    """Exponentiation (``**``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.Pow
[docs]
class MatMul(ArithOp):
    """Matrix multiplication (``@``) operator."""

    # Operator node class instantiated by ArithOp.as_ast(); note the AST
    # class is spelled ``MatMult``.
    op = py_ast.MatMult
class BitAnd(ArithOp):
    """Bitwise AND (``&``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.BitAnd
class BitOr(ArithOp):
    """Bitwise OR (``|``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.BitOr
class BitXor(ArithOp):
    """Bitwise XOR (``^``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.BitXor
class LShift(ArithOp):
    """Left shift (``<<``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.LShift
class RShift(ArithOp):
    """Right shift (``>>``) operator."""

    # Operator node class instantiated by ArithOp.as_ast().
    op = py_ast.RShift
[docs]
class CompareOp(BinaryOperator, ABC):
    """Binary comparison rendered as a ``py_ast.Compare`` node.

    Subclasses set ``op`` to the comparison operator node class to
    instantiate.
    """

    op: ClassVar[type[py_ast.cmpop]]

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.Compare`` with a single operator and comparator."""
        lhs = self.left.as_ast()
        rhs = self.right.as_ast()
        comparison = self.op()
        return py_ast.Compare(left=lhs, comparators=[rhs], ops=[comparison], **DEFAULT_AST_ARGS)
[docs]
class Equals(CompareOp):
    """Equality (``==``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.Eq
[docs]
class NotEquals(CompareOp):
    """Inequality (``!=``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.NotEq
[docs]
class Lt(CompareOp):
    """Less-than (``<``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.Lt
[docs]
class Gt(CompareOp):
    """Greater-than (``>``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.Gt
[docs]
class LtE(CompareOp):
    """Less-than-or-equal (``<=``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.LtE
[docs]
class GtE(CompareOp):
    """Greater-than-or-equal (``>=``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.GtE
[docs]
class In(CompareOp):
    """Membership (``in``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.In
[docs]
class NotIn(CompareOp):
    """Non-membership (``not in``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.NotIn
class Is(CompareOp):
    """Identity (``is``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.Is
class IsNot(CompareOp):
    """Non-identity (``is not``) comparison."""

    # Comparison node class instantiated by CompareOp.as_ast().
    op = py_ast.IsNot
[docs]
class BoolOp(BinaryOperator, ABC):
    """Boolean operator (``and`` / ``or``) rendered as a ``py_ast.BoolOp`` node.

    Subclasses set ``op`` to the boolean operator node class to instantiate.
    """

    op: ClassVar[type[py_ast.boolop]]

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return a ``py_ast.BoolOp`` over the left and right operands."""
        operator_node = self.op()
        operands = [self.left.as_ast(), self.right.as_ast()]
        return py_ast.BoolOp(op=operator_node, values=operands, **DEFAULT_AST_ARGS)
[docs]
class And(BoolOp):
    """Logical ``and`` operator."""

    # Boolean operator node class instantiated by BoolOp.as_ast().
    op = py_ast.And
[docs]
class Or(BoolOp):
    """Logical ``or`` operator."""

    # Boolean operator node class instantiated by BoolOp.as_ast().
    op = py_ast.Or
class UnaryOperator(Expression, ABC):
    """Common base for single-operand operators, rendered as ``py_ast.UnaryOp``.

    Subclasses set ``op`` to the unary operator node class to instantiate.

    :param operand: the expression the operator applies to.
    """

    op: ClassVar[type[py_ast.unaryop]]

    def __init__(self, operand: Expression):
        self.operand = operand

    def as_ast(self, *, include_comments: bool = False) -> py_ast.expr:
        """Return the ``py_ast.UnaryOp`` applying ``op`` to the operand."""
        operator_node = self.op(**DEFAULT_AST_ARGS_ADD)
        operand_node = self.operand.as_ast()
        return py_ast.UnaryOp(op=operator_node, operand=operand_node, **DEFAULT_AST_ARGS)
class Not(UnaryOperator):
    """Logical ``not`` operator."""

    # Unary operator node class instantiated by UnaryOperator.as_ast().
    op = py_ast.Not
class UAdd(UnaryOperator):
    """Unary positive (``+``) operator."""

    # Unary operator node class instantiated by UnaryOperator.as_ast().
    op = py_ast.UAdd
class USub(UnaryOperator):
    """Unary negation (``-``) operator."""

    # Unary operator node class instantiated by UnaryOperator.as_ast().
    op = py_ast.USub
class Invert(UnaryOperator):
    """Bitwise inversion (``~``) operator."""

    # Unary operator node class instantiated by UnaryOperator.as_ast().
    op = py_ast.Invert
[docs]
def simplify(codegen_ast: CodeGenAstType, simplifier: Callable[[CodeGenAstType], CodeGenAst | None]):
    """
    Run *simplifier* over *codegen_ast* again and again until a full pass changes nothing.

    For each node the simplifier returns either ``None`` (leave the node
    alone) or the CodeGenAst object that should take its place. The tree is
    rewritten in place via :func:`rewriting_traverse`, and the same
    *codegen_ast* object is returned.
    """
    changed = True

    def rewriter(node: CodeGenAstType) -> CodeGenAstType:
        nonlocal changed
        replacement = simplifier(node)
        if replacement is None:
            return node
        # Any replacement means another pass is needed.
        changed = True
        return replacement

    while changed:
        changed = False
        rewriting_traverse(codegen_ast, rewriter)
    return codegen_ast
[docs]
def rewriting_traverse(
    node: CodeGenAstType | Sequence[CodeGenAstType],
    func: Callable[[CodeGenAstType], CodeGenAstType],
    _visited: set[int] | None = None,
):
    """
    Call *func* on *node* and on every CodeGenAst node reachable from it.

    Children are found by walking instance attributes (``__dict__``) and the
    contents of lists, tuples and dicts, not a hand-maintained list of
    fields. When *func* returns a different object, the original node is
    morphed into it in place via :func:`morph_into`, and traversal continues
    through the morphed node's attributes.

    ``_visited`` holds the ``id()`` of each node already processed so that
    circular references (e.g. Block.scope → Function → body → Block) do not
    recurse forever. Callers normally leave it unset.
    """
    seen = set() if _visited is None else _visited
    key = id(node)
    if key in seen:
        return

    if isinstance(node, (CodeGenAst, CodeGenAstList)):
        seen.add(key)
        replacement = func(node)
        if replacement is not node:
            morph_into(node, replacement)
        # Read __dict__ only after a possible morph so the new attributes are walked.
        children = node.__dict__.values()
    elif isinstance(node, (list, tuple)):
        children = node
    elif isinstance(node, dict):
        children = node.values()  # type: ignore[reportUnknownVariableType]
    else:
        # Anything else (plain values, other containers) is not traversed.
        return

    for child in children:
        rewriting_traverse(child, func, seen)  # type: ignore[reportUnknownVariableType]
[docs]
def morph_into(item: object, new_item: object) -> None:
    """Turn *item* into a stand-in for *new_item* without changing its identity.

    *item* takes on *new_item*'s class and then shares (not copies) its
    ``__dict__``, so existing references to *item* see the new behavior.
    """
    for attribute in ("__class__", "__dict__"):
        setattr(item, attribute, getattr(new_item, attribute))
def empty_If() -> py_ast.If:
    """
    Build a placeholder ``If`` AST node.

    The node is created with ``test=None`` and an empty ``orelse``; the
    caller must fill in the `test` attribute afterwards.
    """
    placeholder = py_ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS)  # type: ignore[reportArgumentType]
    return placeholder
# Plain Python scalar values that have a direct literal form.
type SimplePythonObj = bool | str | bytes | int | float | None
# Everything `auto()` accepts: scalars, existing Expression/E objects, and
# (recursively) slices, lists, tuples, sets and dicts built from those.
type Autoable = (
    SimplePythonObj
    | Expression
    | E
    | slice[Autoable]
    | list[Autoable]
    | tuple[Autoable, ...]
    | set[Autoable]
    | dict[Autoable, Autoable]
)
# Typing-only overloads for `auto()`: each maps an input type to the
# Expression subclass annotated as its result.
@overload
def auto(value: bool) -> Bool: ...  # type: ignore[overload-overlap]  # bool before int/float is intentional
@overload
def auto(value: str) -> String: ...
@overload
def auto(value: bytes) -> Bytes: ...
@overload
def auto(value: int) -> Number: ...
@overload
def auto(value: float) -> Number: ...
@overload
def auto(value: None) -> NoneExpr: ...
@overload
def auto(value: slice[Autoable]) -> Slice: ...
@overload
def auto(value: Expression) -> Expression: ...
@overload
def auto(value: E) -> Expression: ...
@overload
def auto(value: list[Autoable]) -> List: ...
@overload
def auto(value: tuple[Autoable, ...]) -> Tuple: ...
@overload
def auto(value: set[Autoable]) -> Set: ...
@overload
def auto(value: dict[Autoable, Autoable]) -> Dict: ...
[docs]
def auto(value: Autoable) -> Expression:
    """
    Convert a plain Python object into a codegen Expression.

    Handles bool, str, bytes, int, float, None and slice, and recursively
    list, tuple, set and dict.

    Mixtures are supported too: containers that hold both plain Python
    objects and values that are already Expression or E-objects. Existing
    Expressions are returned unchanged and E-objects are unwrapped.
    """
    # bool is tested before int/float because bool is a subclass of int.
    if isinstance(value, bool):
        return Bool(value)
    if isinstance(value, str):
        return String(value)
    if isinstance(value, bytes):
        return Bytes(value)
    if isinstance(value, (int, float)):
        return Number(value)
    if value is None:
        return constants.None_
    if isinstance(value, slice):
        # Missing bounds stay None; present bounds are converted recursively.
        start, stop, step = (
            None if bound is None else auto(bound) for bound in (value.start, value.stop, value.step)
        )
        return Slice(start=start, stop=stop, step=step)
    if isinstance(value, Expression):
        return value
    if isinstance(value, E):
        return E_to_Expression(value)
    if isinstance(value, list):
        return List([auto(item) for item in value])
    if isinstance(value, tuple):
        return Tuple([auto(item) for item in value])
    if isinstance(value, set):
        # Sets are emitted in order of their elements' repr.
        ordered = sorted(value, key=repr)
        return Set([auto(item) for item in ordered])
    if isinstance(value, dict):  # type: ignore[reportUnnecessaryIsInstance]
        return Dict([(auto(k), auto(v)) for k, v in value.items()])
    assert_never(value)
# Accepted wherever the E operator overloads take an operand; converted with
# ELike_to_Expression().
type ELike = E | Autoable
"""
E-objects or things used in similar contexts (Python literals)
"""
# Accepted by the builder functions; converted with E_to_Expression().
type ExpressionLike = Expression | E
"""
Expression objects, or things used in similar contexts (E-objects)
"""
class E:
    """Wrapper that lets ordinary Python syntax build up an Expression.

    Attribute access, calls, subscripts and operators applied to an ``E``
    each return a new ``E`` wrapping the corresponding Expression built from
    the wrapped one. Operands may be other E-objects or anything
    :func:`ELike_to_Expression` accepts.

    :param expr: the Expression being wrapped.
    """

    def __init__(self, expr: Expression) -> None:
        self._the_expr = expr

    def __getattr__(self, name: str) -> E:
        return E(self._the_expr.attr(name))

    def __call__(self, *args: ELike, **kwargs: ELike) -> E:
        exp_args = [ELike_to_Expression(arg) for arg in args]
        exp_kwargs = {key: ELike_to_Expression(val) for key, val in kwargs.items()}
        return E(self._the_expr.call(exp_args, exp_kwargs))

    # Subscript

    def __getitem__(self, key: ELike) -> E:
        return E(self._the_expr.subscript(ELike_to_Expression(key)))

    # Arithmetic operators (each with its reflected form, so a plain Python
    # value may appear on the left-hand side)

    def __add__(self, other: ELike) -> E:
        return E(self._the_expr.add(ELike_to_Expression(other)))

    def __radd__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).add(self._the_expr))

    def __sub__(self, other: ELike) -> E:
        return E(self._the_expr.sub(ELike_to_Expression(other)))

    def __rsub__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).sub(self._the_expr))

    def __mul__(self, other: ELike) -> E:
        return E(self._the_expr.mul(ELike_to_Expression(other)))

    def __rmul__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).mul(self._the_expr))

    def __truediv__(self, other: ELike) -> E:
        return E(self._the_expr.div(ELike_to_Expression(other)))

    def __rtruediv__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).div(self._the_expr))

    def __floordiv__(self, other: ELike) -> E:
        return E(self._the_expr.floordiv(ELike_to_Expression(other)))

    def __rfloordiv__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).floordiv(self._the_expr))

    def __mod__(self, other: ELike) -> E:
        return E(self._the_expr.mod(ELike_to_Expression(other)))

    def __rmod__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).mod(self._the_expr))

    def __pow__(self, other: ELike) -> E:
        return E(self._the_expr.pow(ELike_to_Expression(other)))

    def __rpow__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).pow(self._the_expr))

    def __matmul__(self, other: ELike) -> E:
        return E(self._the_expr.matmul(ELike_to_Expression(other)))

    def __rmatmul__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).matmul(self._the_expr))

    # Bitwise (reflected forms mirror the arithmetic operators above, so
    # e.g. ``1 | x.e`` works just like ``1 + x.e``)

    def __and__(self, other: ELike) -> E:
        return E(self._the_expr.bitand(ELike_to_Expression(other)))

    def __rand__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).bitand(self._the_expr))

    def __or__(self, other: ELike) -> E:
        return E(self._the_expr.bitor(ELike_to_Expression(other)))

    def __ror__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).bitor(self._the_expr))

    def __xor__(self, other: ELike) -> E:
        return E(self._the_expr.xor(ELike_to_Expression(other)))

    def __rxor__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).xor(self._the_expr))

    def __lshift__(self, other: ELike) -> E:
        return E(self._the_expr.lshift(ELike_to_Expression(other)))

    def __rlshift__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).lshift(self._the_expr))

    def __rshift__(self, other: ELike) -> E:
        return E(self._the_expr.rshift(ELike_to_Expression(other)))

    def __rrshift__(self, other: ELike) -> E:
        return E(ELike_to_Expression(other).rshift(self._the_expr))

    def __invert__(self) -> E:
        return E(self._the_expr.invert())

    # Comparison operators
    # NOTE: these return E-objects, not booleans.

    def __eq__(self, other: ELike) -> E:  # type: ignore
        return E(self._the_expr.eq(ELike_to_Expression(other)))

    def __ne__(self, other: ELike) -> E:  # type: ignore
        return E(self._the_expr.ne(ELike_to_Expression(other)))

    def __lt__(self, other: ELike) -> E:  # type: ignore
        return E(self._the_expr.lt(ELike_to_Expression(other)))

    def __gt__(self, other: ELike) -> E:  # type: ignore
        return E(self._the_expr.gt(ELike_to_Expression(other)))

    def __le__(self, other: ELike) -> E:  # type: ignore
        return E(self._the_expr.le(ELike_to_Expression(other)))

    def __ge__(self, other: ELike) -> E:  # type: ignore
        return E(self._the_expr.ge(ELike_to_Expression(other)))

    # Boolean operators - these cannot exist: and/or are short-circuiting
    # operators that cannot be overridden

    # Membership
    # It looks like we can't override __contains__ effectively
    # to do what we need to support `in` and `not in`
    # - Python coerces to bool afterwards,
    # - `__bool__` has to return True/False
    # - we can't do `not in`

    # Identity operators - again can't be overridden

    # Unary operators

    def __pos__(self) -> E:
        return E(self._the_expr.pos())

    def __neg__(self) -> E:
        return E(self._the_expr.neg())
def E_to_Expression(val: E | Expression) -> Expression:
    """Return the Expression behind *val*: unwrap an E-object, pass anything else through."""
    return val._the_expr if isinstance(val, E) else val  # type: ignore[reportPrivateUsage]
def ELike_to_Expression(val: ELike) -> Expression:
    """Return an Expression for *val*: unwrap an E-object, otherwise convert with :func:`auto`."""
    return val._the_expr if isinstance(val, E) else auto(val)  # type: ignore[reportPrivateUsage]
class Enames:
    """Attribute-style access to the names of a scope, as E-objects.

    ``enames.x`` looks up ``"x"`` via the scope's ``name()`` method and
    wraps the result in :class:`E`.

    :param scope: the scope whose names are exposed.
    """

    def __init__(self, scope: Scope) -> None:
        # Double-underscore (name-mangled) attribute, so it is found through
        # normal lookup and never goes through __getattr__ below.
        self.__scope = scope

    def __getattr__(self, name: str) -> E:
        looked_up = self.__scope.name(name)
        return E(looked_up)
[docs]
class constants:
    """
    Useful pre-made Expression constants
    """

    # Shared ``None`` literal expression.
    None_: NoneExpr = NoneExpr()
    # Boolean literal expressions, built via auto().
    True_: Bool = auto(True)
    False_: Bool = auto(False)