Browse Source

feat(mge): add atlas_subgraph module

GitOrigin-RevId: 11530383c0
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
3dbac4f47f
4 changed files with 69 additions and 4 deletions
  1. +11
    -0
      python_module/megengine/functional/external.py
  2. +28
    -1
      python_module/megengine/module/external.py
  3. BIN
      python_module/test/unit/module/AtlasRuntimeOprTest.basic.om
  4. +30
    -3
      python_module/test/unit/module/test_external.py

+ 11
- 0
python_module/megengine/functional/external.py View File

@@ -35,6 +35,17 @@ def cambricon_subgraph(




@wrap_io_tensor @wrap_io_tensor
def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]:
"""Load a serialized Atlas subgraph (i.e. om model) and
execute the operations defined in the subgraph.

:param inputs: List of input tensors of the subgraph.
:param data: The serialized subgraph.
"""
return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data)


@wrap_io_tensor
def extern_opr_subgraph( def extern_opr_subgraph(
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
) -> List[Tensor]: ) -> List[Tensor]:


+ 28
- 1
python_module/megengine/module/external.py View File

@@ -8,7 +8,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np


from ..functional.external import cambricon_subgraph, extern_opr_subgraph
from ..functional.external import (
atlas_subgraph,
cambricon_subgraph,
extern_opr_subgraph,
)
from .module import Module from .module import Module




@@ -41,6 +45,29 @@ class CambriconSubgraph(Module):
return outputs return outputs




class AtlasSubgraph(Module):
r"""Load a serialized Atlas subgraph.

See :func:`~.atlas_subgraph` for more details.
"""

def __init__(self, data):
super(AtlasSubgraph, self).__init__()
self._data = data

@property
def data(self):
return self._data.tobytes()

@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)

def forward(self, inputs):
outputs = atlas_subgraph(inputs, self._data)
return outputs


class ExternOprSubgraph(Module): class ExternOprSubgraph(Module):
r"""Load a serialized extern opr subgraph. r"""Load a serialized extern opr subgraph.
""" """


BIN
python_module/test/unit/module/AtlasRuntimeOprTest.basic.om View File


+ 30
- 3
python_module/test/unit/module/test_external.py View File

@@ -13,10 +13,10 @@ import numpy as np
import megengine as mge import megengine as mge
from megengine import tensor from megengine import tensor
from megengine.module import Module from megengine.module import Module
from megengine.module.external import CambriconSubgraph
from megengine.module.external import AtlasSubgraph, CambriconSubgraph




class MyModule(Module):
class CambriconModule(Module):
def __init__(self, data): def __init__(self, data):
super().__init__() super().__init__()
self.cambricon = CambriconSubgraph(data, "subnet0", True) self.cambricon = CambriconSubgraph(data, "subnet0", True)
@@ -31,7 +31,7 @@ def test_cambricon_module():
model = os.path.join(os.path.dirname(__file__), model) model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f: with open(model, "rb") as f:
data = f.read() data = f.read()
m = MyModule(data)
m = CambriconModule(data)
inputs = [] inputs = []
inputs.append(tensor(dtype=np.float16, device="cambricon0")) inputs.append(tensor(dtype=np.float16, device="cambricon0"))
inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
@@ -41,3 +41,30 @@ def test_cambricon_module():
return pred return pred


pred = inference(inputs) pred = inference(inputs)


class AtlasModule(Module):
def __init__(self, data):
super().__init__()
self.atlas = AtlasSubgraph(data)

def forward(self, inputs):
out = self.atlas(inputs)
return out


def test_atlas_module():
model = "AtlasRuntimeOprTest.basic.om"
model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f:
data = f.read()
m = AtlasModule(data)
inputs = []
inputs.append(tensor(dtype=np.float32, device="atlas0"))
inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32))

def inference(inps):
pred = m(inps)
return pred

pred = inference(inputs)

Loading…
Cancel
Save