You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

external.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=redefined-builtin
import numpy as np

from ..functional.external import (
    atlas_runtime_opr,
    cambricon_runtime_opr,
    extern_opr_subgraph,
    magicmind_runtime_opr,
    tensorrt_runtime_opr,
)
from .module import Module
  18. class ExternOprSubgraph(Module):
  19. r"""Load a serialized ExternOpr subgraph.
  20. See :func:`~.extern_opr` for more details.
  21. """
  22. def __init__(
  23. self, output_shapes, dump_name, dump_data, output_dtypes=None, **kwargs
  24. ):
  25. super(ExternOprSubgraph, self).__init__(**kwargs)
  26. self._output_shapes = output_shapes
  27. self._dump_name = dump_name
  28. self._dump_data = dump_data
  29. self._output_dtypes = output_dtypes
  30. if self._output_dtypes is None:
  31. self._output_dtypes = [np.float32] * len(output_shapes)
  32. @property
  33. def data(self):
  34. return self._dump_data
  35. @data.setter
  36. def data(self, val):
  37. self._dump_data = np.frombuffer(val, dtype=np.uint8)
  38. @property
  39. def name(self):
  40. return self._dump_name
  41. @name.setter
  42. def name(self, val):
  43. self._dump_name = val
  44. def forward(self, *inputs):
  45. return extern_opr_subgraph(
  46. inputs,
  47. output_shapes=self._output_shapes,
  48. dump_name=self._dump_name,
  49. dump_data=self._dump_data,
  50. output_dtypes=self._output_dtypes,
  51. )
  52. class TensorrtRuntimeSubgraph(Module):
  53. r"""Load a serialized TensorrtRuntime subgraph.
  54. See :func:`~.tensorrt_runtime_opr` for more details.
  55. """
  56. def __init__(self, data, **kwargs):
  57. super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
  58. self._data = data
  59. @property
  60. def data(self):
  61. return self._data
  62. @data.setter
  63. def data(self, val):
  64. self._data = np.frombuffer(val, dtype=np.uint8)
  65. def forward(self, *inputs):
  66. return tensorrt_runtime_opr(inputs, data=self._data)
  67. class CambriconRuntimeSubgraph(Module):
  68. r"""Load a serialized CambriconRuntime subgraph.
  69. See :func:`~.cambricon_runtime_opr` for more details.
  70. """
  71. def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
  72. super(CambriconRuntimeSubgraph, self).__init__(**kwargs)
  73. self._data = data
  74. self.symbol = symbol
  75. self.tensor_dim_mutable = tensor_dim_mutable
  76. @property
  77. def data(self):
  78. return self._data
  79. @data.setter
  80. def data(self, val):
  81. self._data = np.frombuffer(val, dtype=np.uint8)
  82. def forward(self, *inputs):
  83. outputs = cambricon_runtime_opr(
  84. inputs, self._data, self.symbol, self.tensor_dim_mutable
  85. )
  86. return outputs
  87. class AtlasRuntimeSubgraph(Module):
  88. r"""Load a serialized AtlasRuntime subgraph.
  89. See :func:`~.atlas_runtime_opr` for more details.
  90. """
  91. def __init__(self, data, **kwargs):
  92. super(AtlasRuntimeSubgraph, self).__init__(**kwargs)
  93. self._data = data
  94. @property
  95. def data(self):
  96. return self._data
  97. @data.setter
  98. def data(self, val):
  99. self._data = np.frombuffer(val, dtype=np.uint8)
  100. def forward(self, *inputs):
  101. return atlas_runtime_opr(inputs, data=self._data)
  102. class MagicMindRuntimeSubgraph(Module):
  103. r"""Load a serialized MagicMindRuntime subgraph.
  104. See :func:`~.magicmind_runtime_opr` for more details.
  105. """
  106. def __init__(self, data, **kwargs):
  107. super(MagicMindRuntimeSubgraph, self).__init__(**kwargs)
  108. self._data = data
  109. @property
  110. def data(self):
  111. return self._data
  112. @data.setter
  113. def data(self, val):
  114. self._data = np.frombuffer(val, dtype=np.uint8)
  115. def forward(self, *inputs):
  116. return magicmind_runtime_opr(inputs, data=self._data)