You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 807 B

123456789101112131415161718192021222324252627282930
  1. import copy
  2. from typing import Any, Dict, List
  3. from ..expr import Expr, is_constant, is_getattr
  4. from ..node import Node, TensorNode
  def register_obj(objs: List[Any], _dict: Dict):
      """Return a decorator that registers its target in ``_dict``.

      Args:
          objs: a single key or a list of keys. Any value that is not a
              ``list`` (including a tuple) is treated as one key.
          _dict: the registry mapping; it is mutated in place each time the
              returned decorator is applied.

      Returns:
          A decorator that stores the decorated object under every key in
          ``objs`` and returns that object unchanged, so it can be used with
          ``@`` syntax on a definition.
      """
      # ``typing.List`` is a deprecated alias of ``list`` (PEP 585); the builtin
      # is the idiomatic isinstance target and gives the same result.
      keys = objs if isinstance(objs, list) else [objs]

      def _register(any_obj: Any):
          for key in keys:
              _dict[key] = any_obj
          return any_obj

      return _register
  def get_const_value(expr: Expr, fall_back: Any = None):
      """Return a deep copy of the value an expression evaluates to, if it is constant-like.

      If ``expr`` is a ``Node``, its ``.expr`` attribute is inspected instead.
      Two kinds of expression yield a value:

      * a getattr expression whose first output is a ``TensorNode``. It is
        interpreted against the ``owner`` of its first input (presumably the
        module the attribute is read from; confirm), and that owner must not
        be ``None``.
      * a constant expression, interpreted with no arguments.

      In both cases the first element of the interpretation result is
      deep-copied, so callers cannot mutate the original object.

      Args:
          expr: an ``Expr`` or a ``Node`` wrapping one.
          fall_back: returned unchanged when neither case applies.
      """
      target = expr.expr if isinstance(expr, Node) else expr

      if is_getattr(target) and isinstance(target.outputs[0], TensorNode):
          owner = target.inputs[0].owner
          assert owner is not None
          return copy.deepcopy(target.interpret(owner)[0])

      if is_constant(target):
          return copy.deepcopy(target.interpret()[0])

      return fall_back