# Copyright (c) 2022-2026 CNES.
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""Typed AST walker for partition filter expressions.
Replaces the v2 ``eval()``-based filter. Only a small, total subset of Python
syntax is permitted: comparisons, ``and``/``or``/``not``, ``in``, names that
match partition key components, and integer/string literals.
"""
from typing import Any
import ast
from collections.abc import Callable
import operator
from ..errors import ExpressionError
#: A partition key is a tuple of (component, value) pairs, e.g.
#: (("year", 2024), ("month", 3)).
PartitionKey = tuple[tuple[str, Any], ...]
#: A predicate is a function that takes a partition-key dict and returns a bool.
Predicate = Callable[[dict[str, Any]], bool]
_ALLOWED_NODES: tuple[type[ast.AST], ...] = (
ast.Expression,
ast.BoolOp,
ast.UnaryOp,
ast.Compare,
ast.Name,
ast.Load,
ast.Constant,
ast.Tuple,
ast.List,
ast.Set,
ast.And,
ast.Or,
ast.Not,
ast.Eq,
ast.NotEq,
ast.Lt,
ast.LtE,
ast.Gt,
ast.GtE,
ast.In,
ast.NotIn,
)
[docs]
def compile_filter(expr: str | None) -> Predicate:
"""Compile a filter expression to a predicate over partition-key dicts.
``expr=None`` or empty string returns a tautology.
Args:
expr: The filter expression to compile. This should be a string
containing a valid Python expression using only the allowed syntax.
Returns:
A predicate function that takes a partition-key dict and returns a bool
indicating whether the partition key satisfies the filter expression.
Raises:
ExpressionError: If the expression contains syntax errors or uses
disallowed syntax.
"""
if not expr:
return lambda _ctx: True
try:
tree = ast.parse(expr, mode="eval")
except SyntaxError as exc:
raise ExpressionError(f"invalid filter expression: {exc}") from exc
for node in ast.walk(tree):
if not isinstance(node, _ALLOWED_NODES):
raise ExpressionError(
f"disallowed syntax in filter: {type(node).__name__}"
)
return _make_predicate(tree.body)
def _make_predicate(node: ast.AST) -> Predicate:
fn = _compile(node)
return lambda ctx: bool(fn(ctx))
def _compile(node: ast.AST) -> Callable[[dict[str, Any]], Any]:
handler = _NODE_HANDLERS.get(type(node))
if handler is None:
raise ExpressionError(f"unsupported node: {type(node).__name__}")
return handler(node)
def _compile_constant(node: ast.Constant) -> Callable[[dict[str, Any]], Any]:
v = node.value
return lambda _c: v
def _compile_name(node: ast.Name) -> Callable[[dict[str, Any]], Any]:
n = node.id
def _name(ctx: dict[str, Any]) -> Any:
if n not in ctx:
raise ExpressionError(f"unknown partition key {n!r}")
return ctx[n]
return _name
def _compile_sequence(
node: ast.Tuple | ast.List | ast.Set,
) -> Callable[[dict[str, Any]], Any]:
children = [_compile(e) for e in node.elts]
return lambda c: tuple(f(c) for f in children)
def _compile_boolop(node: ast.BoolOp) -> Callable[[dict[str, Any]], Any]:
children = [_compile(v) for v in node.values]
if isinstance(node.op, ast.And):
return lambda c: all(f(c) for f in children)
return lambda c: any(f(c) for f in children)
def _compile_unaryop(node: ast.UnaryOp) -> Callable[[dict[str, Any]], Any]:
if not isinstance(node.op, ast.Not):
raise ExpressionError(f"unsupported unary op: {type(node.op).__name__}")
inner = _compile(node.operand)
return lambda c: not inner(c)
def _compile_compare(node: ast.Compare) -> Callable[[dict[str, Any]], Any]:
left = _compile(node.left)
ops = node.ops
rights = [_compile(c) for c in node.comparators]
def _cmp(ctx: dict[str, Any]) -> bool:
cur = left(ctx)
for op, r in zip(ops, rights, strict=True):
rv = r(ctx)
if not _apply_op(op, cur, rv):
return False
cur = rv
return True
return _cmp
_NODE_HANDLERS: dict[
type[ast.AST], Callable[[Any], Callable[[dict[str, Any]], Any]]
] = {
ast.Constant: _compile_constant,
ast.Name: _compile_name,
ast.Tuple: _compile_sequence,
ast.List: _compile_sequence,
ast.Set: _compile_sequence,
ast.BoolOp: _compile_boolop,
ast.UnaryOp: _compile_unaryop,
ast.Compare: _compile_compare,
}
_CMP_OPS: dict[type[ast.cmpop], Callable[[Any, Any], bool]] = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.In: lambda a, b: a in b,
ast.NotIn: lambda a, b: a not in b,
}
def _apply_op(op: ast.cmpop, a: Any, b: Any) -> bool:
fn = _CMP_OPS.get(type(op))
if fn is None:
raise ExpressionError(
f"unsupported comparison operator: {type(op).__name__}"
)
return fn(a, b)
def key_to_dict(key: PartitionKey) -> dict[str, Any]:
"""Return ``key`` as a ``{component: value}`` mapping."""
return dict(key)