
fix(mgb/tensorrt): fix trt runtime, padding channel to a multiple of 4 when using kCHW4 IOFormat

GitOrigin-RevId: c5f1ed70da
release-1.5
Megvii Engine Team huangxinda 3 years ago
parent
commit
2cd9823210
5 changed files with 113 additions and 8 deletions
  1. +6 -0   imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +1 -0   imperative/python/src/graph_rt.cpp
  3. +7 -0   sdk/load-and-run/dump_with_testcase_mge.py
  4. +1 -1   src/tensorrt/impl/tensorrt_runtime_opr.cpp
  5. +98 -7  src/tensorrt/test/tensorrt_runtime.cpp

+6 -0   imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -266,6 +266,8 @@ def optimize_for_inference(dest_vars, **kwargs):
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
* enable_fuse_preprocess: whether to fuse preprocessing oprs such as
astype/pad_channel/dimshuffle that follow the h2d opr.
""" """
inference_options = GraphOptimizeOptions() inference_options = GraphOptimizeOptions()
inference_optimize_layout_transform_map = { inference_optimize_layout_transform_map = {
@@ -291,6 +293,8 @@ def optimize_for_inference(dest_vars, **kwargs):
inference_options.fuse_conv_bias_nonlinearity = True
if kwargs.pop("enable_fuse_conv_bias_with_z", False):
inference_options.fuse_conv_bias_with_z = True
if kwargs.pop("enable_fuse_preprocess", False):
inference_options.fuse_preprocess = True


if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
@@ -335,6 +339,8 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]:
ret["enable_fuse_conv_bias_nonlinearity"] = True ret["enable_fuse_conv_bias_nonlinearity"] = True
if inference_options.fuse_conv_bias_with_z: if inference_options.fuse_conv_bias_with_z:
ret["enable_fuse_conv_bias_with_z"] = True ret["enable_fuse_conv_bias_with_z"] = True
if inference_options.fuse_preprocess:
ret["enable_fuse_preprocess"] = True


return ret
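For context, the new option is consumed through the existing **kwargs plumbing of optimize_for_inference shown above. A minimal usage sketch (dest_vars, the output VarNodes of an already-built graph, is assumed to exist in the caller's code; only enable_fuse_preprocess comes from this commit):

    from megengine.core.tensor.megbrain_graph import optimize_for_inference

    # The new kwarg simply sets GraphOptimizeOptions.fuse_preprocess = True,
    # so astype / pad_channel / dimshuffle preprocessing can be fused into
    # the h2d copy before the graph is dumped for inference.
    optimized_vars = optimize_for_inference(dest_vars, enable_fuse_preprocess=True)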




+1 -0   imperative/python/src/graph_rt.cpp

@@ -251,6 +251,7 @@ void init_graph_rt(py::module m) {
.def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp) .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
.def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity) .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
.def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z) .def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
.def_readwrite("fuse_preprocess", &_OptimizeForInferenceOptions::fuse_preprocess)
.def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform) .def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform)
; ;
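The new def_readwrite exposes the option to Python as a plain attribute; a quick sketch of what it makes possible (GraphOptimizeOptions is the Python-side name used in megbrain_graph.py above; the import path is an assumption, not part of this commit):

    from megengine.core.tensor.megbrain_graph import GraphOptimizeOptions

    opts = GraphOptimizeOptions()
    opts.fuse_preprocess = True  # backed by _OptimizeForInferenceOptions::fuse_preprocess
    assert opts.fuse_preprocess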




+7 -0   sdk/load-and-run/dump_with_testcase_mge.py

@@ -309,6 +309,7 @@ def optimize_for_inference(args, outputs):
"enable_chwn4", "enable_chwn4",
"enable_fuse_conv_bias_nonlinearity", "enable_fuse_conv_bias_nonlinearity",
"enable_fuse_conv_bias_with_z", "enable_fuse_conv_bias_with_z",
"eaable_fuse_preprocess",
]
kwargs = {}
for k in args_list:
@@ -465,6 +466,12 @@ def main():
"nvidia GPU (this optimization pass will result in mismatch " "nvidia GPU (this optimization pass will result in mismatch "
"of the precision of output of training and inference)", "of the precision of output of training and inference)",
) )
parser.add_argument(
"--enable-fuse-preprocess",
action="store_true",
help="fuse astype\pad_channel\dimshuffle and etc opr "
"from h2d opr",
)
args = parser.parse_args()


feeds = make_feeds(args)


+1 -1   src/tensorrt/impl/tensorrt_runtime_opr.cpp

@@ -117,7 +117,7 @@ void TensorRTRuntimeOpr::get_output_var_shape(
chan_pos = 1;
}
dims.nbDims = dims.nbDims + 1;
dims.d[chan_pos] = dims.d[chan_pos] / 4;
dims.d[chan_pos] = (dims.d[chan_pos] + 3) / 4;
dims.d[dims.nbDims - 1] = 4;
}
#endif
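A note on the one-line change above: with the kCHW4 IOFormat the output is reported in an NCHW4-style layout (an extra trailing dimension of 4), so the channel dimension must be divided by 4 after rounding up to the next multiple of 4. The removed line truncated C/4 and produced a wrong shape whenever C was not already a multiple of 4; the new line rounds up first. An illustrative sketch of the shape computation (not code from this commit; the helper name is made up):

    def nchw4_output_shape(n, c, h, w):
        # round the channel count up to the next multiple of 4, then split it
        # into groups of 4, mirroring (dims.d[chan_pos] + 3) / 4 above
        return (n, (c + 3) // 4, h, w, 4)

    assert nchw4_output_shape(1, 3, 7, 7) == (1, 1, 7, 7, 4)  # old code gave 0 channel groups
    assert nchw4_output_shape(1, 8, 7, 7) == (1, 2, 7, 7, 4)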


+98 -7  src/tensorrt/test/tensorrt_runtime.cpp

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/


#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
@@ -14,12 +15,13 @@
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h" #include "megbrain/test/megdnn_helper.h"
#include "megbrain/utils/debug.h" #include "megbrain/utils/debug.h"
#include "megbrain/opr/basic_arith.h"


#if MGB_ENABLE_TENSOR_RT


#include "make_trt_net.h"
#include "megbrain/tensorrt/tensorrt_opr.h" #include "megbrain/tensorrt/tensorrt_opr.h"
#include "megbrain/tensorrt/tensorrt_runtime_opr.h" #include "megbrain/tensorrt/tensorrt_runtime_opr.h"
#include "make_trt_net.h"


#include <random>


@@ -29,8 +31,6 @@ using namespace nvinfer1;
template <typename T>
using TensorRTUniquePtr = intl::TensorRTUniquePtr<T>;




TEST(TestOprTensorRT, RuntimeBasic) {
REQUIRE_GPU(1);
intl::SimpleTensorRTNetwork net;
@@ -61,7 +61,6 @@ TEST(TestOprTensorRT, RuntimeBasic) {
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}



TEST(TestOprTensorRT, RuntimeBasicBatched) {
REQUIRE_GPU(1);
intl::BatchedTensorRTNetwork net;
@@ -80,7 +79,9 @@ TEST(TestOprTensorRT, RuntimeBasicBatched) {
builder->buildCudaEngine(*trt_net)};
#endif
TensorRTUniquePtr<IHostMemory> mem{cuda_engine->serialize(), {}};
auto nx = opr::Broadcast::make(net.x, {1, net.x.shape()[0], net.x.shape()[1], net.x.shape()[2]});
auto nx = opr::Broadcast::make(
net.x,
{1, net.x.shape()[0], net.x.shape()[1], net.x.shape()[2]});
return TensorRTRuntimeOpr::make(mem->data(), mem->size(), {nx})[0];
};
auto y2 = make_trt();
@@ -93,7 +94,6 @@ TEST(TestOprTensorRT, RuntimeBasicBatched) {
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}



TEST(TestOprTensorRT, ConcatRuntimeBasic) {
REQUIRE_GPU(1);
intl::ConcatConvTensorRTNetwork net;
@@ -168,6 +168,97 @@ TEST(TestOprTensorRT, RuntimeChangeBatchSize) {
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}


#if NV_TENSOR_RT_VERSION >= 6001
TEST(TestOprTensorRT, IOFormatFree) {
size_t N = 1, C = 3, H = 7, W = 7;
REQUIRE_GPU(1);
TensorRTUniquePtr<IBuilder> builder{
createInferBuilder(TensorRTOpr::Logger::instance()), {}};
nvinfer1::NetworkDefinitionCreationFlags flags;
::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
flags = 1 << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
TensorRTUniquePtr<INetworkDefinition> network{
builder->createNetworkV2(flags), {}};
auto cast = [](size_t i) { return static_cast<int>(i); };
ITensor* data = network->addInput(
"data", DataType::kINT8, Dims4{cast(N), cast(C), cast(H), cast(W)});
TensorFormats formats = 1
<< static_cast<int>(nvinfer1::TensorFormat::kCHW4);
data->setAllowedFormats(formats);
data->setDynamicRange(-127.f * 1.2f, 127.f * 1.2f);
HostTensorGenerator<> fgen;
auto mean = fgen({N, C, H, W});
Weights mean_weights{DataType::kFLOAT, nullptr, 0};
mean_weights.values = mean->raw_ptr();
mean_weights.count = N * C * H * W;
auto constant = network->addConstant(
Dims4{cast(N), cast(C), cast(H), cast(W)}, mean_weights);
auto out = network->addElementWise(*network->getInput(0),
*constant->getOutput(0),
ElementWiseOperation::kSUB);
out->getOutput(0)->setDynamicRange(-127.f * 2.3f, 127.f * 2.3f);
network->markOutput(*out->getOutput(0));
network->getInput(0)->setType(DataType::kINT8);
network->getOutput(0)->setType(DataType::kFLOAT);
network->getOutput(0)->setAllowedFormats(
1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
TensorRTUniquePtr<IBuilderConfig> build_config{
builder->createBuilderConfig()};
build_config->setFlag(BuilderFlag::kINT8);
build_config->setFlag(BuilderFlag::kSTRICT_TYPES);
TensorRTUniquePtr<ICudaEngine> cuda_engine{
builder->buildEngineWithConfig(*network, *build_config)};
TensorRTUniquePtr<IHostMemory> mem{cuda_engine->serialize(), {}};

HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name),
dtype);
};
auto x = mkvar("x", {N, C, H, W}, dtype::QuantizedS8(1.2f));
auto fx = opr::TypeCvt::make(x, dtype::Float32());
auto wval = opr::SharedDeviceTensor::make(*graph, *mean).rename("mean");
auto z = fx - wval;
HostTensorND y1;
auto func1 = graph->compile({make_callback_copy(z, y1)});
func1->execute();

TensorShape shp{N, 1, H, W};
auto host = std::make_shared<HostTensorND>(x.node()->comp_node(), x.node()->dtype());
host->resize(shp);
auto ptr = host->raw_ptr();
size_t size_bytes = TensorLayout{shp, x.node()->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding = opr::ImmutableTensor::make(*graph, *host);
x = opr::Concat::make({x, padding}, 1);

auto nchw2nchw4 = [](SymbolVar x) {
auto xshp = opr::GetVarShape::make(x);

auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
return y1;
};
x = nchw2nchw4(x);
auto trt = TensorRTRuntimeOpr::make(mem->data(), mem->size(), {x})[0];
HostTensorND y2;
auto func2 = graph->compile({make_callback_copy(trt, y2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(y1, y2);
}
#endif
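For reference, the Concat(zero padding) + Reshape + Dimshuffle({0, 1, 3, 4, 2}) sequence the new test uses to feed the kCHW4 engine corresponds to the following NumPy relayout (illustrative only; the helper is not part of the test):

    import numpy as np

    def nchw_to_nchw4(x):
        # zero-pad channels to a multiple of 4, split them into groups of 4 and
        # move the group-of-4 axis last: (N, C, H, W) -> (N, ceil(C/4), H, W, 4)
        n, c, h, w = x.shape
        pad = (-c) % 4
        if pad:
            x = np.concatenate([x, np.zeros((n, pad, h, w), x.dtype)], axis=1)
        return x.reshape(n, x.shape[1] // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

    assert nchw_to_nchw4(np.zeros((1, 3, 7, 7), np.int8)).shape == (1, 1, 7, 7, 4)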

#endif // MGB_ENABLE_TENSOR_RT


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
