@@ -35,6 +35,17 @@ def cambricon_subgraph( | |||||
@wrap_io_tensor | @wrap_io_tensor | ||||
def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]: | |||||
"""Load a serialized Atlas subgraph (i.e. om model) and | |||||
execute the operations defined in the subgraph. | |||||
:param inputs: List of input tensors of the subgraph. | |||||
:param data: The serialized subgraph. | |||||
""" | |||||
return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data) | |||||
@wrap_io_tensor | |||||
def extern_opr_subgraph( | def extern_opr_subgraph( | ||||
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | ||||
) -> List[Tensor]: | ) -> List[Tensor]: | ||||
@@ -8,7 +8,11 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import numpy as np | import numpy as np | ||||
from ..functional.external import cambricon_subgraph, extern_opr_subgraph | |||||
from ..functional.external import ( | |||||
atlas_subgraph, | |||||
cambricon_subgraph, | |||||
extern_opr_subgraph, | |||||
) | |||||
from .module import Module | from .module import Module | ||||
@@ -41,6 +45,29 @@ class CambriconSubgraph(Module): | |||||
return outputs | return outputs | ||||
class AtlasSubgraph(Module): | |||||
r"""Load a serialized Atlas subgraph. | |||||
See :func:`~.atlas_subgraph` for more details. | |||||
""" | |||||
def __init__(self, data): | |||||
super(AtlasSubgraph, self).__init__() | |||||
self._data = data | |||||
@property | |||||
def data(self): | |||||
return self._data.tobytes() | |||||
@data.setter | |||||
def data(self, val): | |||||
self._data = np.frombuffer(val, dtype=np.uint8) | |||||
def forward(self, inputs): | |||||
outputs = atlas_subgraph(inputs, self._data) | |||||
return outputs | |||||
class ExternOprSubgraph(Module): | class ExternOprSubgraph(Module): | ||||
r"""Load a serialized extern opr subgraph. | r"""Load a serialized extern opr subgraph. | ||||
""" | """ | ||||
@@ -13,10 +13,10 @@ import numpy as np | |||||
import megengine as mge | import megengine as mge | ||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.module import Module | from megengine.module import Module | ||||
from megengine.module.external import CambriconSubgraph | |||||
from megengine.module.external import AtlasSubgraph, CambriconSubgraph | |||||
class MyModule(Module): | |||||
class CambriconModule(Module): | |||||
def __init__(self, data): | def __init__(self, data): | ||||
super().__init__() | super().__init__() | ||||
self.cambricon = CambriconSubgraph(data, "subnet0", True) | self.cambricon = CambriconSubgraph(data, "subnet0", True) | ||||
@@ -31,7 +31,7 @@ def test_cambricon_module(): | |||||
model = os.path.join(os.path.dirname(__file__), model) | model = os.path.join(os.path.dirname(__file__), model) | ||||
with open(model, "rb") as f: | with open(model, "rb") as f: | ||||
data = f.read() | data = f.read() | ||||
m = MyModule(data) | |||||
m = CambriconModule(data) | |||||
inputs = [] | inputs = [] | ||||
inputs.append(tensor(dtype=np.float16, device="cambricon0")) | inputs.append(tensor(dtype=np.float16, device="cambricon0")) | ||||
inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) | inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) | ||||
@@ -41,3 +41,30 @@ def test_cambricon_module(): | |||||
return pred | return pred | ||||
pred = inference(inputs) | pred = inference(inputs) | ||||
class AtlasModule(Module): | |||||
def __init__(self, data): | |||||
super().__init__() | |||||
self.atlas = AtlasSubgraph(data) | |||||
def forward(self, inputs): | |||||
out = self.atlas(inputs) | |||||
return out | |||||
def test_atlas_module(): | |||||
model = "AtlasRuntimeOprTest.basic.om" | |||||
model = os.path.join(os.path.dirname(__file__), model) | |||||
with open(model, "rb") as f: | |||||
data = f.read() | |||||
m = AtlasModule(data) | |||||
inputs = [] | |||||
inputs.append(tensor(dtype=np.float32, device="atlas0")) | |||||
inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32)) | |||||
def inference(inps): | |||||
pred = m(inps) | |||||
return pred | |||||
pred = inference(inputs) |