Browse Source

feat(mge): tensorrt runtime opr

GitOrigin-RevId: 2fdd00adcb
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
abe3c165ba
6 changed files with 107 additions and 1 deletions
  1. +1
    -1
      .gitattributes
  2. +22
    -0
      imperative/python/megengine/functional/external.py
  3. +37
    -0
      imperative/python/megengine/module/external.py
  4. +8
    -0
      imperative/python/test/unit/module/test_external.py
  5. +32
    -0
      imperative/src/impl/ops/tensorrt_runtime.cpp
  6. +7
    -0
      src/core/include/megbrain/ir/ops.td

+ 1
- 1
.gitattributes View File

@@ -5,4 +5,4 @@ dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text
sdk/c-opr-loaders/mc40/example/sinopec_nv12_extra.neu filter=lfs diff=lfs merge=lfs -text
*.caffemodel filter=lfs diff=lfs merge=lfs -text

+ 22
- 0
imperative/python/megengine/functional/external.py View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from typing import Sequence

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin


def tensorrt_runtime_opr(inputs, *, data: bytes = None):
# empty model will give None result
if data is None:
return None
op = builtin.TensorRTRuntime(data, len(data))
# return sequence of outputs
return apply(op, *inputs)

+ 37
- 0
imperative/python/megengine/module/external.py View File

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
import numpy as np

from ..functional.external import tensorrt_runtime_opr
from .module import Module


class TensorrtRuntimeSubgraph(Module):
r"""Load a serialized TensorrtRuntime subgraph.

See :func:`~.tensorrt_runtime_opr` for more details.
"""

def __init__(
self, data,
):
super(TensorrtRuntimeSubgraph, self).__init__()
self._data = data

@property
def data(self):
return self._data

@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)

def forward(self, *inputs):
return tensorrt_runtime_opr(inputs, data=self._data)

+ 8
- 0
imperative/python/test/unit/module/test_external.py View File

@@ -6,14 +6,20 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import os import os
import platform


import numpy as np import numpy as np
import pytest import pytest


import megengine as mge import megengine as mge
import megengine.utils.comp_graph_tools as cgtools
from megengine import Tensor from megengine import Tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
from megengine.module import Module from megengine.module import Module
from megengine.module.external import TensorrtRuntimeSubgraph




class MyModule(Module): class MyModule(Module):
@@ -44,3 +50,5 @@ def test_cambricon_module():
return pred return pred


pred = inference([inp]) pred = inference([inp])



+ 32
- 0
imperative/src/impl/ops/tensorrt_runtime.cpp View File

@@ -0,0 +1,32 @@
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_runtime_opr.h"
namespace mgb::imperative {

namespace { namespace tensorrt_runtime {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const TensorRTRuntime&>(def);
SymbolVarArray sinputs(inputs.begin(), inputs.end());
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
}
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // tensorrt_runtime

} // namespace mgb::imperative
#endif // MGB_ENABLE_TENSOR_RT

+ 7
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -241,4 +241,11 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara


def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;


def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size
);
}

#endif // MGB_OPS #endif // MGB_OPS

Loading…
Cancel
Save