You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

external.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # -*- coding: utf-8 -*-
  2. # pylint: disable=redefined-builtin
  3. import numpy as np
  4. from ..functional.external import (
  5. atlas_runtime_opr,
  6. cambricon_runtime_opr,
  7. extern_opr_subgraph,
  8. magicmind_runtime_opr,
  9. tensorrt_runtime_opr,
  10. )
  11. from .module import Module
  12. class ExternOprSubgraph(Module):
  13. r"""Load a serialized ExternOpr subgraph.
  14. See :func:`~.extern_opr` for more details.
  15. """
  16. def __init__(
  17. self, output_shapes, dump_name, dump_data, output_dtypes=None, **kwargs
  18. ):
  19. super(ExternOprSubgraph, self).__init__(**kwargs)
  20. self._output_shapes = output_shapes
  21. self._dump_name = dump_name
  22. self._dump_data = dump_data
  23. self._output_dtypes = output_dtypes
  24. if self._output_dtypes is None:
  25. self._output_dtypes = [np.float32] * len(output_shapes)
  26. @property
  27. def data(self):
  28. return self._dump_data
  29. @data.setter
  30. def data(self, val):
  31. self._dump_data = np.frombuffer(val, dtype=np.uint8)
  32. @property
  33. def name(self):
  34. return self._dump_name
  35. @name.setter
  36. def name(self, val):
  37. self._dump_name = val
  38. def forward(self, *inputs):
  39. return extern_opr_subgraph(
  40. inputs,
  41. output_shapes=self._output_shapes,
  42. dump_name=self._dump_name,
  43. dump_data=self._dump_data,
  44. output_dtypes=self._output_dtypes,
  45. )
  46. class TensorrtRuntimeSubgraph(Module):
  47. r"""Load a serialized TensorrtRuntime subgraph.
  48. See :func:`~.tensorrt_runtime_opr` for more details.
  49. """
  50. def __init__(self, data, **kwargs):
  51. super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
  52. self._data = data
  53. @property
  54. def data(self):
  55. return self._data
  56. @data.setter
  57. def data(self, val):
  58. self._data = np.frombuffer(val, dtype=np.uint8)
  59. def forward(self, *inputs):
  60. return tensorrt_runtime_opr(inputs, data=self._data)
  61. class CambriconRuntimeSubgraph(Module):
  62. r"""Load a serialized CambriconRuntime subgraph.
  63. See :func:`~.cambricon_runtime_opr` for more details.
  64. """
  65. def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
  66. super(CambriconRuntimeSubgraph, self).__init__(**kwargs)
  67. self._data = data
  68. self.symbol = symbol
  69. self.tensor_dim_mutable = tensor_dim_mutable
  70. @property
  71. def data(self):
  72. return self._data
  73. @data.setter
  74. def data(self, val):
  75. self._data = np.frombuffer(val, dtype=np.uint8)
  76. def forward(self, *inputs):
  77. outputs = cambricon_runtime_opr(
  78. inputs, self._data, self.symbol, self.tensor_dim_mutable
  79. )
  80. return outputs
  81. class AtlasRuntimeSubgraph(Module):
  82. r"""Load a serialized AtlasRuntime subgraph.
  83. See :func:`~.atlas_runtime_opr` for more details.
  84. """
  85. def __init__(self, data, **kwargs):
  86. super(AtlasRuntimeSubgraph, self).__init__(**kwargs)
  87. self._data = data
  88. @property
  89. def data(self):
  90. return self._data
  91. @data.setter
  92. def data(self, val):
  93. self._data = np.frombuffer(val, dtype=np.uint8)
  94. def forward(self, *inputs):
  95. return atlas_runtime_opr(inputs, data=self._data)
  96. class MagicMindRuntimeSubgraph(Module):
  97. r"""Load a serialized MagicMindRuntime subgraph.
  98. See :func:`~.magicmind_runtime_opr` for more details.
  99. """
  100. def __init__(self, data, **kwargs):
  101. super(MagicMindRuntimeSubgraph, self).__init__(**kwargs)
  102. self._data = data
  103. @property
  104. def data(self):
  105. return self._data
  106. @data.setter
  107. def data(self, val):
  108. self._data = np.frombuffer(val, dtype=np.uint8)
  109. def forward(self, *inputs):
  110. return magicmind_runtime_opr(inputs, data=self._data)