import copy
from typing import Any, Dict, List

from ..expr import Expr, is_constant, is_getattr
from ..node import Node, TensorNode


def register_obj(objs: List[Any], _dict: Dict):
    """Build a decorator that records the decorated object in ``_dict``.

    Args:
        objs: the key, or a list of keys, under which the decorated object
            is stored. Any value that is not a ``list`` (including a tuple)
            is treated as one single key.
        _dict: the mapping that receives one ``key -> decorated object``
            entry per key. It is mutated in place when the decorator runs.

    Returns:
        A decorator that stores its argument in ``_dict`` under every key
        and returns that argument unchanged.
    """
    # Runtime check against the builtin ``list`` rather than the deprecated
    # ``typing.List`` alias; both accept exactly the same instances.
    if not isinstance(objs, list):
        objs = [objs]

    def _register(any_obj: Any):
        for obj in objs:
            _dict[obj] = any_obj
        return any_obj

    return _register


def get_const_value(expr: Expr, fall_back: Any = None):
    """Return a deep copy of the value ``expr`` evaluates to, if available.

    Args:
        expr: an ``Expr``, or a ``Node`` whose ``.expr`` attribute is used
            in its place.
        fall_back: value returned when ``expr`` matches neither of the two
            handled cases below.

    Returns:
        A deep copy of the first element returned by ``expr.interpret(...)``
        when ``is_getattr(expr)`` holds and the first output is a
        ``TensorNode``, or when ``is_constant(expr)`` holds; otherwise
        ``fall_back`` itself (not copied).

    Raises:
        AssertionError: if the getattr case applies but the owner of the
            expression's first input is ``None``.
    """
    if isinstance(expr, Node):
        expr = expr.expr

    if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode):
        module = expr.inputs[0].owner
        # Raised explicitly instead of via ``assert`` so the check is still
        # performed under ``python -O``; the exception type is unchanged.
        if module is None:
            raise AssertionError(
                "cannot evaluate getattr expr: owner of its first input is None"
            )
        # Deep copy so callers cannot mutate the value held by the owner.
        return copy.deepcopy(expr.interpret(module)[0])

    if is_constant(expr):
        return copy.deepcopy(expr.interpret()[0])

    return fall_back