"""Helpers for AST (Abstract Syntax Tree).

Provides :func:`parse`, which builds a syntax tree with type comments enabled
when the parser supports them, and :func:`unparse`, which renders expression
nodes back to source text.
"""

import sys
from typing import Dict, List, Optional, Type, overload

if sys.version_info > (3, 8):
    import ast
else:
    try:
        # Prefer typed_ast when it is installed: it can read type comments on
        # interpreters whose built-in parser cannot.
        from typed_ast import ast3 as ast
    except ImportError:
        import ast  # type: ignore


# Operator node class -> the token it is written as in source code.
# Note that the unary operators UAdd/USub share their token with Add/Sub.
OPERATORS: Dict[Type[ast.AST], str] = {
    ast.Add: "+",
    ast.And: "and",
    ast.BitAnd: "&",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Invert: "~",
    ast.LShift: "<<",
    ast.MatMult: "@",
    ast.Mult: "*",
    ast.Mod: "%",
    ast.Not: "not",
    ast.Pow: "**",
    ast.Or: "or",
    ast.RShift: ">>",
    ast.Sub: "-",
    ast.UAdd: "+",
    ast.USub: "-",
}


def parse(code: str, mode: str = 'exec') -> "ast.AST":
    """Build a syntax tree from *code*, keeping type comments when possible.

    The first attempt passes ``type_comments=True``, which the built-in parser
    accepts on Python 3.8+.  Parsing is then retried without that option when:

    * a :exc:`SyntaxError` is raised, so that invalid type comments are
      ignored instead of failing the whole parse
      (refs: https://github.com/sphinx-doc/sphinx/issues/8652); or
    * a :exc:`TypeError` is raised because the parser does not know the
      keyword.  If typed_ast is installed, it still reads type comments.

    A genuine syntax error in *code* propagates from the retry.
    """
    try:
        # the type_comments keyword only exists on py38+ ast.parse
        return ast.parse(code, mode=mode, type_comments=True)  # type: ignore
    except (SyntaxError, TypeError):
        return ast.parse(code, mode=mode)


@overload
def unparse(node: None, code: str = '') -> None:
    ...


@overload
def unparse(node: ast.AST, code: str = '') -> str:
    ...
def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
    """Unparse an AST to string.

    :param node: the node to render.  ``None`` is returned as ``None`` and a
        plain ``str`` is returned unchanged.
    :param code: the source text *node* was parsed from.  On Python 3.8+ it is
        used to reproduce numeric literals exactly as written (e.g. ``0x10``
        stays ``0x10`` rather than becoming ``16``).
    :raises NotImplementedError: if *node*, or any node inside it, is of a type
        this cut-down unparser does not support.
    """
    if node is None:
        return None
    elif isinstance(node, str):
        return node
    return _UnparseVisitor(code).visit(node)


# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render a limited subset of expression nodes back to source text.

    Every ``visit_*`` method returns a string.  Node types without a handler
    fall through to :meth:`generic_visit`, which raises
    :exc:`NotImplementedError`.
    """

    def __init__(self, code: str = '') -> None:
        # Original source text; visit_Constant uses it to copy numeric
        # literals verbatim.
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the source token for an operator node, looked up in OPERATORS."""
        return OPERATORS[node.__class__]

    # Register visit_Add, visit_Sub, ... for every operator class.  Inside a
    # class body, locals() is the class namespace, so these become methods.
    for _op in OPERATORS:
        locals()['visit_{}'.format(_op.__name__)] = _visit_op

    def visit_arg(self, node: ast.arg) -> str:
        """Render one parameter name together with its annotation, if any."""
        if node.annotation:
            return "%s: %s" % (node.arg, self.visit(node.annotation))
        else:
            return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str:
        """Unparse a single argument to a string."""
        name = self.visit(arg)
        if default:
            # PEP 8 style: spaces around "=" only when the parameter is annotated.
            if arg.annotation:
                name += " = %s" % self.visit(default)
            else:
                name += "=%s" % self.visit(default)
        return name

    def visit_arguments(self, node: ast.arguments) -> str:
        """Render a whole parameter list, without the surrounding parentheses."""
        # ``defaults`` belongs to the *last* positional parameters (including
        # positional-only ones), so left-pad it with None to one entry per
        # positional parameter.
        defaults: List[Optional[ast.expr]] = list(node.defaults)
        positionals = len(node.args)
        posonlyargs = 0
        if hasattr(node, "posonlyargs"):  # for py38+
            posonlyargs += len(node.posonlyargs)  # type:ignore
            positionals += posonlyargs
        for _ in range(len(defaults), positionals):
            defaults.insert(0, None)

        kw_defaults: List[Optional[ast.expr]] = list(node.kw_defaults)
        for _ in range(len(kw_defaults), len(node.kwonlyargs)):
            kw_defaults.insert(0, None)

        args: List[str] = []
        if hasattr(node, "posonlyargs"):  # for py38+
            for i, arg in enumerate(node.posonlyargs):  # type: ignore
                args.append(self._visit_arg_with_default(arg, defaults[i]))

            if node.posonlyargs:  # type: ignore
                args.append('/')

        for i, arg in enumerate(node.args):
            args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        # Without *args, a bare "*" is needed to introduce keyword-only parameters.
        if node.kwonlyargs and not node.vararg:
            args.append('*')
        for i, arg in enumerate(node.kwonlyargs):
            args.append(self._visit_arg_with_default(arg, kw_defaults[i]))

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        return "%s.%s" % (self.visit(node.value), node.attr)

    def visit_BinOp(self, node: ast.BinOp) -> str:
        # NOTE(review): operands are never parenthesized, so ``(a + b) * c``
        # renders as ``a + b * c``.
        return " ".join(self.visit(e) for e in [node.left, node.op, node.right])

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        op = " %s " % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        """Render a call: positional arguments first, then keyword arguments."""
        args = [self.visit(e) for e in node.args]
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``**kwargs`` unpacking is stored as a keyword without a name.
                args.append("**%s" % self.visit(keyword.value))
            else:
                args.append("%s=%s" % (keyword.arg, self.visit(keyword.value)))
        return "%s(%s)" % (self.visit(node.func), ", ".join(args))

    def visit_Constant(self, node: ast.Constant) -> str:  # type: ignore
        """Render a literal, copying numbers from the source text when available."""
        if node.value is Ellipsis:
            return "..."
        elif isinstance(node.value, (int, float, complex)):
            if self.code and sys.version_info > (3, 8):
                segment = ast.get_source_segment(self.code, node)  # type: ignore
                # get_source_segment() returns None when location info is
                # missing; fall back to repr() rather than returning None.
                if segment is not None:
                    return segment
        return repr(node.value)

    def visit_Dict(self, node: ast.Dict) -> str:
        items: List[str] = []
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``{**other}``: dict unpacking is stored with a None key.
                items.append("**" + self.visit(value))
            else:
                items.append(self.visit(key) + ": " + self.visit(value))
        return "{" + ", ".join(items) + "}"

    def visit_Index(self, node: ast.Index) -> str:
        # Pre-3.9 subscripts wrap their slice value in an Index node.
        return self.visit(node.value)

    def visit_Lambda(self, node: ast.Lambda) -> str:
        # Only the signature is rendered; the body is elided as "...".
        return "lambda %s: ..." % self.visit(node.args)

    def visit_List(self, node: ast.List) -> str:
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Starred(self, node: ast.Starred) -> str:
        """Render iterable unpacking such as ``*args`` in calls or displays."""
        return "*%s" % self.visit(node.value)

    def visit_Subscript(self, node: ast.Subscript) -> str:
        """Render ``value[slice]``, writing a non-empty tuple slice without parens."""
        def is_simple_tuple(value: ast.AST) -> bool:
            return (
                isinstance(value, ast.Tuple) and
                bool(value.elts) and
                not any(isinstance(elt, ast.Starred) for elt in value.elts)
            )

        if is_simple_tuple(node.slice):
            elts = ", ".join(self.visit(e) for e in node.slice.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        elif isinstance(node.slice, ast.Index) and is_simple_tuple(node.slice.value):
            elts = ", ".join(self.visit(e) for e in node.slice.value.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        else:
            return "%s[%s]" % (self.visit(node.value), self.visit(node.slice))

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        # The operator is UAdd, USub, Invert or Not (``+x``, ``-x``, ``~x``,
        # ``not x``); only the keyword operator needs a separating space.
        if isinstance(node.op, ast.Not):
            return "%s %s" % (self.visit(node.op), self.visit(node.operand))
        return "%s%s" % (self.visit(node.op), self.visit(node.operand))

    def visit_Tuple(self, node: ast.Tuple) -> str:
        if len(node.elts) == 0:
            return "()"
        elif len(node.elts) == 1:
            # A one-element tuple needs its trailing comma.
            return "(%s,)" % self.visit(node.elts[0])
        else:
            return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    if sys.version_info < (3, 8):
        # these ast nodes were deprecated in python 3.8
        def visit_Bytes(self, node: ast.Bytes) -> str:
            return repr(node.s)

        def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
            return "..."

        def visit_NameConstant(self, node: ast.NameConstant) -> str:
            return repr(node.value)

        def visit_Num(self, node: ast.Num) -> str:
            return repr(node.n)

        def visit_Str(self, node: ast.Str) -> str:
            return repr(node.s)

    def generic_visit(self, node: ast.AST):
        """Reject any node type that has no dedicated ``visit_*`` handler."""
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)