Browse Source

feat(imperative/src): python wrapper for cambricon and atlas runtime opr

GitOrigin-RevId: bd969d1339
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
c8697a7005
10 changed files with 213 additions and 9 deletions
  1. +5
    -1
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +22
    -0
      imperative/python/megengine/device.py
  3. +27
    -0
      imperative/python/megengine/functional/external.py
  4. +11
    -5
      imperative/python/megengine/jit/tracing.py
  5. +54
    -1
      imperative/python/megengine/module/external.py
  6. +3
    -2
      imperative/python/megengine/utils/comp_graph_tools.py
  7. +2
    -0
      imperative/python/src/common.cpp
  8. +36
    -0
      imperative/src/impl/ops/atlas_runtime.cpp
  9. +37
    -0
      imperative/src/impl/ops/cambricon_runtime.cpp
  10. +16
    -0
      src/core/include/megbrain/ir/ops.td

+ 5
- 1
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -529,7 +529,11 @@ class InputNode(OpNode):


@property @property
def device(self): def device(self):
return self.outputs[0].device
var = self.outputs[0]
if isinstance(var, VarNode):
return var.device
else:
return var.comp_node


@property @property
def dtype(self): def dtype(self):


+ 22
- 0
imperative/python/megengine/device.py View File

@@ -36,6 +36,10 @@ def _str2device_type(type_str: str, allow_unspec: bool = True):
return DeviceType.CPU return DeviceType.CPU
elif type_str == "GPU" or type_str == "CUDA": elif type_str == "GPU" or type_str == "CUDA":
return DeviceType.CUDA return DeviceType.CUDA
elif type_str == "CAMBRICON":
return DeviceType.CAMBRICON
elif type_str == "ATLAS":
return DeviceType.ATLAS
else: else:
assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu" assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu"
return DeviceType.UNSPEC return DeviceType.UNSPEC
@@ -65,6 +69,24 @@ def is_cuda_available() -> bool:
return CompNode._get_device_count(t, False) > 0 return CompNode._get_device_count(t, False) > 0




def is_cambricon_available() -> bool:
"""
Returns whether cambricon device is available on this system.

"""
t = _str2device_type("cambricon")
return CompNode._get_device_count(t, False) > 0


def is_atlas_available() -> bool:
"""
Returns whether atlas device is available on this system.

"""
t = _str2device_type("atlas")
return CompNode._get_device_count(t, False) > 0


def set_default_device(device: str = "xpux"): def set_default_device(device: str = "xpux"):
r""" r"""
Sets default computing node. Sets default computing node.


+ 27
- 0
imperative/python/megengine/functional/external.py View File

@@ -20,3 +20,30 @@ def tensorrt_runtime_opr(inputs, *, data: bytes = None):
op = builtin.TensorRTRuntime(data, len(data)) op = builtin.TensorRTRuntime(data, len(data))
# return sequence of outputs # return sequence of outputs
return apply(op, *inputs) return apply(op, *inputs)


def cambricon_runtime_opr(inputs, data, symbol, tensor_dim_mutable):
r"""
Load a serialized Cambricon model as a runtime operator in MegEngine.

:param inputs: list of input tensors.
:param data: the serialized Cambricon model.
:param symbol: name of the function in Cambricon model.
:param tensor_dim_mutable: whether the input tensors' shapes are mutable
in ``cnrtModel_t``.
"""

op = builtin.CambriconRuntime(data, len(data), symbol, tensor_dim_mutable)
return apply(op, *inputs)


def atlas_runtime_opr(inputs, data):
r"""
Load a serialized Atlas model as a runtime operator in MegEngine.

:param inputs: list of input tensors.
:param data: the serialized Atlas model.
"""

op = builtin.AtlasRuntime(data, len(data))
return apply(op, *inputs)

+ 11
- 5
imperative/python/megengine/jit/tracing.py View File

@@ -786,7 +786,11 @@ class trace:
) )
output_names = output_names or self._output_names output_names = output_names or self._output_names


dumped_device = as_device("xpux")
def dumped_device(info):
device_name = info.device.logical_name
if device_name[:3] in ("cpu", "gpu", "xpu"):
return as_device("xpux")
return info.device


h2v = {} h2v = {}
graph = G.Graph() graph = G.Graph()
@@ -794,19 +798,21 @@ class trace:
# apply graph_opt_level in dump # apply graph_opt_level in dump
if self._graph_opt_level is not None: if self._graph_opt_level is not None:
graph.options.graph_opt_level = self._graph_opt_level graph.options.graph_opt_level = self._graph_opt_level

for i, h in enumerate(self._arg_bindings): for i, h in enumerate(self._arg_bindings):
info = self._tinfo[h] info = self._tinfo[h]
h2v[h] = graph.make_h2d( h2v[h] = graph.make_h2d(
dtype=info.dtype, dtype=info.dtype,
device=dumped_device,
device=dumped_device(info),
shape=info.shape or (1,), shape=info.shape or (1,),
name=arg_names[i] if arg_names else None, name=arg_names[i] if arg_names else None,
) )
for k, h in self._kwarg_bindings.items(): for k, h in self._kwarg_bindings.items():
info = self._tinfo[h] info = self._tinfo[h]
h2v[h] = graph.make_h2d( h2v[h] = graph.make_h2d(
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
dtype=info.dtype,
device=dumped_device(info),
shape=info.shape or (1,),
name=k,
) )


for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
@@ -833,7 +839,7 @@ class trace:
h2v[h] = graph.make_const( h2v[h] = graph.make_const(
info.bound_data.numpy(), info.bound_data.numpy(),
dtype=info.dtype, dtype=info.dtype,
device=dumped_device,
device=dumped_device(info),
name=info.name, name=info.name,
) )
ivars.append(h2v[h]) ivars.append(h2v[h])


+ 54
- 1
imperative/python/megengine/module/external.py View File

@@ -9,7 +9,11 @@
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
import numpy as np import numpy as np


from ..functional.external import tensorrt_runtime_opr
from ..functional.external import (
atlas_runtime_opr,
cambricon_runtime_opr,
tensorrt_runtime_opr,
)
from .module import Module from .module import Module




@@ -33,3 +37,52 @@ class TensorrtRuntimeSubgraph(Module):


def forward(self, *inputs): def forward(self, *inputs):
return tensorrt_runtime_opr(inputs, data=self._data) return tensorrt_runtime_opr(inputs, data=self._data)


class CambriconRuntimeSubgraph(Module):
r"""Load a serialized CambriconRuntime subgraph.

See :func:`~.cambricon_runtime_opr` for more details.
"""

def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
super(CambriconRuntimeSubgraph, self).__init__(**kwargs)
self._data = data
self.symbol = symbol
self.tensor_dim_mutable = tensor_dim_mutable

@property
def data(self):
return self._data

@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)

def forward(self, *inputs):
outputs = cambricon_runtime_opr(
inputs, self._data, self.symbol, self.tensor_dim_mutable
)
return outputs


class AtlasRuntimeSubgraph(Module):
r"""Load a serialized AtlasRuntime subgraph.

See :func:`~.atlas_runtime_opr` for more details.
"""

def __init__(self, data, **kwargs):
super(AtlasRuntimeSubgraph, self).__init__(**kwargs)
self._data = data

@property
def data(self):
return self._data

@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)

def forward(self, *inputs):
return atlas_runtime_opr(inputs, data=self._data)

+ 3
- 2
imperative/python/megengine/utils/comp_graph_tools.py View File

@@ -427,8 +427,9 @@ class GraphInference:
list(self._inp_dict.keys()), list(inputs.keys()) list(self._inp_dict.keys()), list(inputs.keys())
) )
for key in self._inp_dict: for key in self._inp_dict:
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())

self._inp_dict[key].set_value(
Tensor(inputs[key], device=self._inp_dict[key].device)._dev_tensor()
)
self._func.execute() self._func.execute()
self._func.wait() self._func.wait()




+ 2
- 0
imperative/python/src/common.cpp View File

@@ -171,6 +171,8 @@ void init_common(py::module m) {
.value("UNSPEC", CompNode::DeviceType::UNSPEC) .value("UNSPEC", CompNode::DeviceType::UNSPEC)
.value("CUDA", CompNode::DeviceType::CUDA) .value("CUDA", CompNode::DeviceType::CUDA)
.value("CPU", CompNode::DeviceType::CPU) .value("CPU", CompNode::DeviceType::CPU)
.value("CAMBRICON", CompNode::DeviceType::CAMBRICON)
.value("ATLAS", CompNode::DeviceType::ATLAS)
.value("MULTITHREAD", CompNode::DeviceType::MULTITHREAD) .value("MULTITHREAD", CompNode::DeviceType::MULTITHREAD)
.value("MAX_DEVICE_ID", CompNode::DeviceType::MAX_DEVICE_ID); .value("MAX_DEVICE_ID", CompNode::DeviceType::MAX_DEVICE_ID);




+ 36
- 0
imperative/src/impl/ops/atlas_runtime.cpp View File

@@ -0,0 +1,36 @@
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"

#if MGB_ATLAS
#include "megbrain/opr/atlas_runtime_op.h"
namespace mgb::imperative {

namespace {
namespace atlas_runtime {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const AtlasRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()};
return opr::AtlasRuntimeOpr::make(op.buf.c_str(), op.buf_size,
symbol_var_inputs, config);
}
OP_TRAIT_REG(AtlasRuntime, AtlasRuntime)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace atlas_runtime
} // namespace

} // namespace mgb::imperative
#endif

+ 37
- 0
imperative/src/impl/ops/cambricon_runtime.cpp View File

@@ -0,0 +1,37 @@
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"

#if MGB_CAMBRICON
#include "megbrain/cambricon/cambricon_runtime_opr.h"
namespace mgb::imperative {

namespace {
namespace cambricon_runtime {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const CambriconRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()};
return opr::CambriconRuntimeOpr::make(op.buf.c_str(), op.buf_size,
op.symbol, symbol_var_inputs,
op.tensor_dim_mutable, config);
}
OP_TRAIT_REG(CambriconRuntime, CambriconRuntime)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace cambricon_runtime
} // namespace

} // namespace mgb::imperative
#endif

+ 16
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -266,6 +266,22 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
); );
} }


def AtlasRuntime: MgbHashableOp<"AtlasRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size
);
}

def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size,
MgbStringAttr:$symbol,
MgbBoolAttr:$tensor_dim_mutable
);
}

def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;


#endif // MGB_OPS #endif // MGB_OPS

Loading…
Cancel
Save