You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

external.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # -*- coding: utf-8 -*-
  2. # pylint: disable=redefined-builtin
from typing import Iterable, List, Optional, Sequence

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
  6. def extern_opr_subgraph(
  7. inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, output_dtypes
  8. ):
  9. r"""Load a serialized extern opr subgraph and fake execute the operator.
  10. Args:
  11. inputs: list of input tensors.
  12. output_shapes: The output shapes.
  13. dump_name: The serialized subgraph name.
  14. dump_data: The serialized subgraph.
  15. """
  16. if not isinstance(inputs, Iterable):
  17. inputs = (inputs,)
  18. op = builtin.ExternOpr(
  19. output_shapes, dump_name, dump_data, len(dump_data), output_dtypes
  20. )
  21. return apply(op, *inputs)
  22. def tensorrt_runtime_opr(inputs, *, data: bytes = None):
  23. # empty model will give None result
  24. if data is None:
  25. return None
  26. op = builtin.TensorRTRuntime(data, len(data))
  27. # return sequence of outputs
  28. return apply(op, *inputs)
  29. def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable):
  30. r"""Load a serialized Cambricon model as a runtime operator in MegEngine.
  31. Args:
  32. inputs: list of input tensors.
  33. data: the serialized Cambricon model.
  34. symbol: name of the function in Cambricon model.
  35. tensor_dim_mutable: whether the input tensors' shapes are mutable
  36. in ``cnrtModel_t``.
  37. """
  38. op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable)
  39. return apply(op, *inputs)
  40. def atlas_runtime_opr(inputs, data):
  41. r"""Load a serialized Atlas model as a runtime operator in MegEngine.
  42. Args:
  43. inputs: list of input tensors.
  44. data: the serialized Atlas model.
  45. """
  46. op = builtin.AtlasRuntime(data, len(data))
  47. return apply(op, *inputs)
  48. def magicmind_runtime_opr(inputs, data):
  49. r"""Load a serialized MagicMind model as a runtime operator in MegEngine.
  50. Args:
  51. inputs: list of input tensors.
  52. data: the serialized MagicMind model.
  53. """
  54. op = builtin.MagicMindRuntime(data, len(data))
  55. return apply(op, *inputs)