Browse Source

feat(mge): tensorrt runtime opr

GitOrigin-RevId: 2fdd00adcb
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
abe3c165ba
6 changed files with 107 additions and 1 deletions
  1. +1
    -1
      .gitattributes
  2. +22
    -0
      imperative/python/megengine/functional/external.py
  3. +37
    -0
      imperative/python/megengine/module/external.py
  4. +8
    -0
      imperative/python/test/unit/module/test_external.py
  5. +32
    -0
      imperative/src/impl/ops/tensorrt_runtime.cpp
  6. +7
    -0
      src/core/include/megbrain/ir/ops.td

+ 1
- 1
.gitattributes View File

@@ -5,4 +5,4 @@ dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text
sdk/c-opr-loaders/mc40/example/sinopec_nv12_extra.neu filter=lfs diff=lfs merge=lfs -text
*.caffemodel filter=lfs diff=lfs merge=lfs -text

+ 22
- 0
imperative/python/megengine/functional/external.py View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from typing import Sequence

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin


def tensorrt_runtime_opr(inputs, *, data: bytes = None):
# empty model will give None result
if data is None:
return None
op = builtin.TensorRTRuntime(data, len(data))
# return sequence of outputs
return apply(op, *inputs)

+ 37
- 0
imperative/python/megengine/module/external.py View File

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
import numpy as np

from ..functional.external import tensorrt_runtime_opr
from .module import Module


class TensorrtRuntimeSubgraph(Module):
r"""Load a serialized TensorrtRuntime subgraph.

See :func:`~.tensorrt_runtime_opr` for more details.
"""

def __init__(
self, data,
):
super(TensorrtRuntimeSubgraph, self).__init__()
self._data = data

@property
def data(self):
return self._data

@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)

def forward(self, *inputs):
return tensorrt_runtime_opr(inputs, data=self._data)

+ 8
- 0
imperative/python/test/unit/module/test_external.py View File

@@ -6,14 +6,20 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import os
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.utils.comp_graph_tools as cgtools
from megengine import Tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
from megengine.module import Module
from megengine.module.external import TensorrtRuntimeSubgraph


class MyModule(Module):
@@ -44,3 +50,5 @@ def test_cambricon_module():
return pred

pred = inference([inp])



+ 32
- 0
imperative/src/impl/ops/tensorrt_runtime.cpp View File

@@ -0,0 +1,32 @@
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_runtime_opr.h"
namespace mgb::imperative {

namespace { namespace tensorrt_runtime {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const TensorRTRuntime&>(def);
SymbolVarArray sinputs(inputs.begin(), inputs.end());
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
}
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // tensorrt_runtime

} // namespace mgb::imperative
#endif // MGB_ENABLE_TENSOR_RT

+ 7
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -241,4 +241,11 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara

def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;

def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
let extraArguments = (ins
MgbStringAttr:$buf,
MgbSizeTAddr:$buf_size
);
}

#endif // MGB_OPS

Loading…
Cancel
Save