
fix(mgb/tensorrt): fix trt runtime, padding channel to a multiple of 4 when using kCHW4 IOFormat

GitOrigin-RevId: c5f1ed70da
release-1.5
Megvii Engine Team huangxinda, 3 years ago
commit 2cd9823210
5 changed files with 113 additions and 8 deletions
  1. +6   -0   imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +1   -0   imperative/python/src/graph_rt.cpp
  3. +7   -0   sdk/load-and-run/dump_with_testcase_mge.py
  4. +1   -1   src/tensorrt/impl/tensorrt_runtime_opr.cpp
  5. +98  -7   src/tensorrt/test/tensorrt_runtime.cpp

+6 -0  imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -266,6 +266,8 @@ def optimize_for_inference(dest_vars, **kwargs):
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
* enable_fuse_preprocess: whether to fuse astype / pad_channel / dimshuffle and
similar oprs starting from the h2d opr.
"""
inference_options = GraphOptimizeOptions()
inference_optimize_layout_transform_map = {
@@ -291,6 +293,8 @@ def optimize_for_inference(dest_vars, **kwargs):
inference_options.fuse_conv_bias_nonlinearity = True
if kwargs.pop("enable_fuse_conv_bias_with_z", False):
inference_options.fuse_conv_bias_with_z = True
if kwargs.pop("enable_fuse_preprocess", False):
inference_options.fuse_preprocess = True

if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
@@ -335,6 +339,8 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]:
ret["enable_fuse_conv_bias_nonlinearity"] = True
if inference_options.fuse_conv_bias_with_z:
ret["enable_fuse_conv_bias_with_z"] = True
if inference_options.fuse_preprocess:
ret["enable_fuse_preprocess"] = True

return ret
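
The new enable_fuse_preprocess switch rides the same kwargs path as the existing fusion options. A minimal usage sketch (hypothetical: dest_vars is assumed to hold the output VarNodes of an already-built graph, and the exact return value depends on the MegEngine version):

    from megengine.core.tensor.megbrain_graph import optimize_for_inference

    # request the new preprocess fusion alongside an existing fusion pass
    optimized = optimize_for_inference(
        dest_vars,
        enable_fuse_preprocess=True,             # fuse astype / pad_channel / dimshuffle after h2d
        enable_fuse_conv_bias_nonlinearity=True,
    )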



+1 -0  imperative/python/src/graph_rt.cpp

@@ -251,6 +251,7 @@ void init_graph_rt(py::module m) {
.def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
.def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
.def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
.def_readwrite("fuse_preprocess", &_OptimizeForInferenceOptions::fuse_preprocess)
.def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform)
;



+7 -0  sdk/load-and-run/dump_with_testcase_mge.py

@@ -309,6 +309,7 @@ def optimize_for_inference(args, outputs):
"enable_chwn4",
"enable_fuse_conv_bias_nonlinearity",
"enable_fuse_conv_bias_with_z",
"eaable_fuse_preprocess",
]
kwargs = {}
for k in args_list:
@@ -465,6 +466,12 @@ def main():
"nvidia GPU (this optimization pass will result in mismatch "
"of the precision of output of training and inference)",
)
parser.add_argument(
"--enable-fuse-preprocess",
action="store_true",
help="fuse astype\pad_channel\dimshuffle and etc opr "
"from h2d opr",
)
args = parser.parse_args()

feeds = make_feeds(args)


+1 -1  src/tensorrt/impl/tensorrt_runtime_opr.cpp

@@ -117,7 +117,7 @@ void TensorRTRuntimeOpr::get_output_var_shape(
chan_pos = 1;
}
dims.nbDims = dims.nbDims + 1;
dims.d[chan_pos] = dims.d[chan_pos] / 4;
dims.d[chan_pos] = (dims.d[chan_pos] + 3) / 4;
dims.d[dims.nbDims - 1] = 4;
}
#endif
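
The one-line change above replaces truncating division with ceiling division when deriving the packed channel count for the kCHW4 format, so a channel count that is not a multiple of 4 no longer collapses to zero. A small Python sketch of the arithmetic (illustrative values only, not taken from the diff):

    def packed_channels(c):
        # round the channel count up to the next group of 4, as the fixed line does
        return (c + 3) // 4

    assert 3 // 4 == 0              # old behaviour: a 3-channel input yields 0 groups
    assert packed_channels(3) == 1  # new behaviour: 3 channels -> 1 group of 4
    assert packed_channels(8) == 2  # multiples of 4 are unchanged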


+98 -7  src/tensorrt/test/tensorrt_runtime.cpp

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/comp_node_env.h"
@@ -14,12 +15,13 @@
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megbrain/utils/debug.h"
#include "megbrain/opr/basic_arith.h"

#if MGB_ENABLE_TENSOR_RT

#include "make_trt_net.h"
#include "megbrain/tensorrt/tensorrt_opr.h"
#include "megbrain/tensorrt/tensorrt_runtime_opr.h"
#include "make_trt_net.h"

#include <random>

@@ -29,8 +31,6 @@ using namespace nvinfer1;
template <typename T>
using TensorRTUniquePtr = intl::TensorRTUniquePtr<T>;



TEST(TestOprTensorRT, RuntimeBasic) {
REQUIRE_GPU(1);
intl::SimpleTensorRTNetwork net;
@@ -61,7 +61,6 @@ TEST(TestOprTensorRT, RuntimeBasic) {
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}


TEST(TestOprTensorRT, RuntimeBasicBatched) {
REQUIRE_GPU(1);
intl::BatchedTensorRTNetwork net;
@@ -80,7 +79,9 @@ TEST(TestOprTensorRT, RuntimeBasicBatched) {
builder->buildCudaEngine(*trt_net)};
#endif
TensorRTUniquePtr<IHostMemory> mem{cuda_engine->serialize(), {}};
auto nx = opr::Broadcast::make(net.x, {1, net.x.shape()[0], net.x.shape()[1], net.x.shape()[2]});
auto nx = opr::Broadcast::make(
net.x,
{1, net.x.shape()[0], net.x.shape()[1], net.x.shape()[2]});
return TensorRTRuntimeOpr::make(mem->data(), mem->size(), {nx})[0];
};
auto y2 = make_trt();
@@ -93,7 +94,6 @@ TEST(TestOprTensorRT, RuntimeBasicBatched) {
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}


TEST(TestOprTensorRT, ConcatRuntimeBasic) {
REQUIRE_GPU(1);
intl::ConcatConvTensorRTNetwork net;
@@ -168,6 +168,97 @@ TEST(TestOprTensorRT, RuntimeChangeBatchSize) {
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}

#if NV_TENSOR_RT_VERSION >= 6001
TEST(TestOprTensorRT, IOFormatFree) {
size_t N = 1, C = 3, H = 7, W = 7;
REQUIRE_GPU(1);
TensorRTUniquePtr<IBuilder> builder{
createInferBuilder(TensorRTOpr::Logger::instance()), {}};
nvinfer1::NetworkDefinitionCreationFlags flags;
::memset(&flags, 0, sizeof(nvinfer1::NetworkDefinitionCreationFlags));
flags = 1 << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
TensorRTUniquePtr<INetworkDefinition> network{
builder->createNetworkV2(flags), {}};
auto cast = [](size_t i) { return static_cast<int>(i); };
ITensor* data = network->addInput(
"data", DataType::kINT8, Dims4{cast(N), cast(C), cast(H), cast(W)});
TensorFormats formats = 1
<< static_cast<int>(nvinfer1::TensorFormat::kCHW4);
data->setAllowedFormats(formats);
data->setDynamicRange(-127.f * 1.2f, 127.f * 1.2f);
HostTensorGenerator<> fgen;
auto mean = fgen({N, C, H, W});
Weights mean_weights{DataType::kFLOAT, nullptr, 0};
mean_weights.values = mean->raw_ptr();
mean_weights.count = N * C * H * W;
auto constant = network->addConstant(
Dims4{cast(N), cast(C), cast(H), cast(W)}, mean_weights);
auto out = network->addElementWise(*network->getInput(0),
*constant->getOutput(0),
ElementWiseOperation::kSUB);
out->getOutput(0)->setDynamicRange(-127.f * 2.3f, 127.f * 2.3f);
network->markOutput(*out->getOutput(0));
network->getInput(0)->setType(DataType::kINT8);
network->getOutput(0)->setType(DataType::kFLOAT);
network->getOutput(0)->setAllowedFormats(
1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));
TensorRTUniquePtr<IBuilderConfig> build_config{
builder->createBuilderConfig()};
build_config->setFlag(BuilderFlag::kINT8);
build_config->setFlag(BuilderFlag::kSTRICT_TYPES);
TensorRTUniquePtr<ICudaEngine> cuda_engine{
builder->buildEngineWithConfig(*network, *build_config)};
TensorRTUniquePtr<IHostMemory> mem{cuda_engine->serialize(), {}};

HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name),
dtype);
};
auto x = mkvar("x", {N, C, H, W}, dtype::QuantizedS8(1.2f));
auto fx = opr::TypeCvt::make(x, dtype::Float32());
auto wval = opr::SharedDeviceTensor::make(*graph, *mean).rename("mean");
auto z = fx - wval;
HostTensorND y1;
auto func1 = graph->compile({make_callback_copy(z, y1)});
func1->execute();

TensorShape shp{N, 1, H, W};
auto host = std::make_shared<HostTensorND>(x.node()->comp_node(), x.node()->dtype());
host->resize(shp);
auto ptr = host->raw_ptr();
size_t size_bytes = TensorLayout{shp, x.node()->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding = opr::ImmutableTensor::make(*graph, *host);
x = opr::Concat::make({x, padding}, 1);

auto nchw2nchw4 = [](SymbolVar x) {
auto xshp = opr::GetVarShape::make(x);

auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
return y1;
};
x = nchw2nchw4(x);
auto trt = TensorRTRuntimeOpr::make(mem->data(), mem->size(), {x})[0];
HostTensorND y2;
auto func2 = graph->compile({make_callback_copy(trt, y2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(y1, y2);
}
#endif
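
The test above pads the 3-channel input to 4 channels by concatenating a zero tensor, then the nchw2nchw4 helper regroups the padded channels into the layout the kCHW4 engine expects. A shape-only Python sketch of that reshape plus dimshuffle (numpy is used purely for illustration):

    import numpy as np

    x = np.zeros((1, 4, 7, 7))          # NCHW after padding C from 3 to 4
    y = x.reshape(1, 4 // 4, 4, 7, 7)   # split C into (C // 4, 4)
    y = y.transpose(0, 1, 3, 4, 2)      # move the 4-channel group innermost (NCHW4)
    assert y.shape == (1, 1, 7, 7, 4)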

#endif // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
