You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

external.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=redefined-builtin
  10. import numpy as np
  11. from ..functional.external import (
  12. atlas_runtime_opr,
  13. cambricon_runtime_opr,
  14. extern_opr_subgraph,
  15. magicmind_runtime_opr,
  16. tensorrt_runtime_opr,
  17. )
  18. from .module import Module
  19. class ExternOprSubgraph(Module):
  20. r"""Load a serialized ExternOpr subgraph.
  21. See :func:`~.extern_opr` for more details.
  22. """
  23. def __init__(
  24. self, output_shapes, dump_name, dump_data, output_dtypes=None, **kwargs
  25. ):
  26. super(ExternOprSubgraph, self).__init__(**kwargs)
  27. self._output_shapes = output_shapes
  28. self._dump_name = dump_name
  29. self._dump_data = dump_data
  30. self._output_dtypes = output_dtypes
  31. if self._output_dtypes is None:
  32. self._output_dtypes = [np.float32] * len(output_shapes)
  33. @property
  34. def data(self):
  35. return self._dump_data
  36. @data.setter
  37. def data(self, val):
  38. self._dump_data = np.frombuffer(val, dtype=np.uint8)
  39. @property
  40. def name(self):
  41. return self._dump_name
  42. @name.setter
  43. def name(self, val):
  44. self._dump_name = val
  45. def forward(self, *inputs):
  46. return extern_opr_subgraph(
  47. inputs,
  48. output_shapes=self._output_shapes,
  49. dump_name=self._dump_name,
  50. dump_data=self._dump_data,
  51. output_dtypes=self._output_dtypes,
  52. )
  53. class TensorrtRuntimeSubgraph(Module):
  54. r"""Load a serialized TensorrtRuntime subgraph.
  55. See :func:`~.tensorrt_runtime_opr` for more details.
  56. """
  57. def __init__(self, data, **kwargs):
  58. super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
  59. self._data = data
  60. @property
  61. def data(self):
  62. return self._data
  63. @data.setter
  64. def data(self, val):
  65. self._data = np.frombuffer(val, dtype=np.uint8)
  66. def forward(self, *inputs):
  67. return tensorrt_runtime_opr(inputs, data=self._data)
  68. class CambriconRuntimeSubgraph(Module):
  69. r"""Load a serialized CambriconRuntime subgraph.
  70. See :func:`~.cambricon_runtime_opr` for more details.
  71. """
  72. def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
  73. super(CambriconRuntimeSubgraph, self).__init__(**kwargs)
  74. self._data = data
  75. self.symbol = symbol
  76. self.tensor_dim_mutable = tensor_dim_mutable
  77. @property
  78. def data(self):
  79. return self._data
  80. @data.setter
  81. def data(self, val):
  82. self._data = np.frombuffer(val, dtype=np.uint8)
  83. def forward(self, *inputs):
  84. outputs = cambricon_runtime_opr(
  85. inputs, self._data, self.symbol, self.tensor_dim_mutable
  86. )
  87. return outputs
  88. class AtlasRuntimeSubgraph(Module):
  89. r"""Load a serialized AtlasRuntime subgraph.
  90. See :func:`~.atlas_runtime_opr` for more details.
  91. """
  92. def __init__(self, data, **kwargs):
  93. super(AtlasRuntimeSubgraph, self).__init__(**kwargs)
  94. self._data = data
  95. @property
  96. def data(self):
  97. return self._data
  98. @data.setter
  99. def data(self, val):
  100. self._data = np.frombuffer(val, dtype=np.uint8)
  101. def forward(self, *inputs):
  102. return atlas_runtime_opr(inputs, data=self._data)
  103. class MagicMindRuntimeSubgraph(Module):
  104. r"""Load a serialized MagicMindRuntime subgraph.
  105. See :func:`~.magicmind_runtime_opr` for more details.
  106. """
  107. def __init__(self, data, **kwargs):
  108. super(MagicMindRuntimeSubgraph, self).__init__(**kwargs)
  109. self._data = data
  110. @property
  111. def data(self):
  112. return self._data
  113. @data.setter
  114. def data(self, val):
  115. self._data = np.frombuffer(val, dtype=np.uint8)
  116. def forward(self, *inputs):
  117. return magicmind_runtime_opr(inputs, data=self._data)