Browse Source

feat(mge): add atlas_subgraph module

GitOrigin-RevId: 11530383c0
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
3dbac4f47f
4 changed files with 69 additions and 4 deletions
  1. +11
    -0
      python_module/megengine/functional/external.py
  2. +28
    -1
      python_module/megengine/module/external.py
  3. BIN
      python_module/test/unit/module/AtlasRuntimeOprTest.basic.om
  4. +30
    -3
      python_module/test/unit/module/test_external.py

+ 11
- 0
python_module/megengine/functional/external.py View File

@@ -35,6 +35,17 @@ def cambricon_subgraph(


@wrap_io_tensor
def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]:
"""Load a serialized Atlas subgraph (i.e. om model) and
execute the operations defined in the subgraph.

:param inputs: List of input tensors of the subgraph.
:param data: The serialized subgraph.
"""
return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data)


@wrap_io_tensor
def extern_opr_subgraph(
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
) -> List[Tensor]:


+ 28
- 1
python_module/megengine/module/external.py View File

@@ -8,7 +8,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional.external import cambricon_subgraph, extern_opr_subgraph
from ..functional.external import (
atlas_subgraph,
cambricon_subgraph,
extern_opr_subgraph,
)
from .module import Module


@@ -41,6 +45,29 @@ class CambriconSubgraph(Module):
return outputs


class AtlasSubgraph(Module):
r"""Load a serialized Atlas subgraph.

See :func:`~.atlas_subgraph` for more details.
"""

def __init__(self, data):
super(AtlasSubgraph, self).__init__()
self._data = data

@property
def data(self):
return self._data.tobytes()

@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)

def forward(self, inputs):
outputs = atlas_subgraph(inputs, self._data)
return outputs


class ExternOprSubgraph(Module):
r"""Load a serialized extern opr subgraph.
"""


BIN
python_module/test/unit/module/AtlasRuntimeOprTest.basic.om View File


+ 30
- 3
python_module/test/unit/module/test_external.py View File

@@ -13,10 +13,10 @@ import numpy as np
import megengine as mge
from megengine import tensor
from megengine.module import Module
from megengine.module.external import CambriconSubgraph
from megengine.module.external import AtlasSubgraph, CambriconSubgraph


class MyModule(Module):
class CambriconModule(Module):
def __init__(self, data):
super().__init__()
self.cambricon = CambriconSubgraph(data, "subnet0", True)
@@ -31,7 +31,7 @@ def test_cambricon_module():
model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f:
data = f.read()
m = MyModule(data)
m = CambriconModule(data)
inputs = []
inputs.append(tensor(dtype=np.float16, device="cambricon0"))
inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
@@ -41,3 +41,30 @@ def test_cambricon_module():
return pred

pred = inference(inputs)


class AtlasModule(Module):
def __init__(self, data):
super().__init__()
self.atlas = AtlasSubgraph(data)

def forward(self, inputs):
out = self.atlas(inputs)
return out


def test_atlas_module():
model = "AtlasRuntimeOprTest.basic.om"
model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f:
data = f.read()
m = AtlasModule(data)
inputs = []
inputs.append(tensor(dtype=np.float32, device="atlas0"))
inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32))

def inference(inps):
pred = m(inps)
return pred

pred = inference(inputs)

Loading…
Cancel
Save