From 565466c25f55e07dc7a9f512ba6122b26767a246 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 11 Oct 2021 16:16:08 +0800
Subject: [PATCH] feat(lite): auto deduce output tensor shape before model
 forward

GitOrigin-RevId: 78e00dab5da3fcc91bb53d8588c06b5b25295e19
---
 lite/pylite/megenginelite/struct.py        |  1 +
 lite/pylite/megenginelite/tensor.py        |  6 ++++--
 lite/src/mge/common.cpp                    |  6 ++++++
 lite/src/mge/network_impl.cpp              | 28 ++++++++++++++++++++++++++++
 lite/src/mge/network_impl.h                |  4 ++++
 lite/src/misc.cpp                          |  4 ++++
 lite/test/test_network.cpp                 |  5 +++++
 src/core/include/megbrain/graph/var_node.h |  7 +++++++
 8 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/lite/pylite/megenginelite/struct.py b/lite/pylite/megenginelite/struct.py
index 6c1e55b0..57a7505e 100644
--- a/lite/pylite/megenginelite/struct.py
+++ b/lite/pylite/megenginelite/struct.py
@@ -31,6 +31,7 @@ class LiteDataType(IntEnum):
     LITE_INT16 = 3
     LITE_INT8 = 4
     LITE_UINT8 = 5
+    LITE_UINT16 = 6
 
 
 class LiteTensorPhase(IntEnum):
diff --git a/lite/pylite/megenginelite/tensor.py b/lite/pylite/megenginelite/tensor.py
index 256f2ee2..0af8f27f 100644
--- a/lite/pylite/megenginelite/tensor.py
+++ b/lite/pylite/megenginelite/tensor.py
@@ -22,6 +22,7 @@ _lite_type_to_nptypes = {
     LiteDataType.LITE_UINT8: np.uint8,
     LiteDataType.LITE_INT8: np.int8,
     LiteDataType.LITE_INT16: np.int16,
+    LiteDataType.LITE_UINT16: np.uint16,
     LiteDataType.LITE_HALF: np.float16,
 }
 
@@ -33,6 +34,7 @@ _str_nptypes_to_lite_nptypes = {
     np.dtype("uint8"): LiteDataType.LITE_UINT8,
     np.dtype("int8"): LiteDataType.LITE_INT8,
     np.dtype("int16"): LiteDataType.LITE_INT16,
+    np.dtype("uint16"): LiteDataType.LITE_UINT16,
     np.dtype("float16"): LiteDataType.LITE_HALF,
 }
 
@@ -43,7 +45,7 @@ ctype_to_lite_dtypes = {
     c_ubyte: LiteDataType.LITE_UINT8,
     c_byte: LiteDataType.LITE_INT8,
     c_short: LiteDataType.LITE_INT16,
-    c_ushort: LiteDataType.LITE_INT16,
+    c_ushort: LiteDataType.LITE_UINT16,
 }
 
 
@@ -83,7 +85,7 @@ class LiteLayout(Structure):
 
     def __repr__(self):
         data = {
-            "shapes": list(self.shapes),
+            "shapes": list(self.shapes)[0 : self.ndim],
             "ndim": self.ndim,
             "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
         }
diff --git a/lite/src/mge/common.cpp b/lite/src/mge/common.cpp
index d1855fa4..6b0ceb4d 100644
--- a/lite/src/mge/common.cpp
+++ b/lite/src/mge/common.cpp
@@ -100,6 +100,9 @@ LTensorLayout lite::to_impl_layout(const Layout& layout) {
         case LiteDataType::LITE_INT16:
             mge_layout.dtype = mgb::dtype::Int16();
             break;
+        case LiteDataType::LITE_UINT16:
+            mge_layout.dtype = mgb::dtype::Uint16();
+            break;
         default:
             LITE_THROW(mgb::ssprintf(
                     "unsupport dtype in lite enum id is %d.",
@@ -133,6 +136,9 @@ Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
         case mgb::DTypeEnum::Int16:
             layout.data_type = LiteDataType::LITE_INT16;
             break;
+        case mgb::DTypeEnum::Uint16:
+            layout.data_type = LiteDataType::LITE_UINT16;
+            break;
         case mgb::DTypeEnum::Int8:
             layout.data_type = LiteDataType::LITE_INT8;
             break;
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index b1911a8d..a8a98d86 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -442,6 +442,24 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
     }
 }
 
+void NetworkImplDft::try_infer_tensor_layout(
+        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
+    auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
+    auto infer_trait = var.node()->get_static_infer_trait();
+    if (std::get<0>(infer_trait)) {
+        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
+        if (!shape) {
+            LITE_WARN(
+                    "Lite infer output shape failed, maybe the model is "
+                    "dynamic "
+                    "shape.\n");
+            return;
+        }
+        Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
+        tensor->set_layout(layout);
+    }
+}
+
 void NetworkImplDft::update_io() {
     update_input();
     update_output();
@@ -564,6 +582,14 @@ void NetworkImplDft::update_output() {
                 out_it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
+            mgb::SymbolVar var;
+            for (auto&& out_var : m_load_result.output_var_list) {
+                if (out_var.node()->name() == out_it->name) {
+                    var = out_var;
+                    break;
+                }
+            }
+            try_infer_tensor_layout(out_it->lite_tensor, var);
         }
         //! user not set, use default output
     } else {
@@ -579,12 +605,14 @@ void NetworkImplDft::update_output() {
                 it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
+            try_infer_tensor_layout(it->lite_tensor, out);
         } else {
             IOInner output;
             output.name = out.node()->name();
             output.lite_tensor = std::make_shared<Tensor>(
                     device_id, stream_id, device_type, true);
             m_network_io->outputs.push_back({output});
+            try_infer_tensor_layout(output.lite_tensor, out);
         }
     }
 }
diff --git a/lite/src/mge/network_impl.h b/lite/src/mge/network_impl.h
index bafad7f7..d370253c 100644
--- a/lite/src/mge/network_impl.h
+++ b/lite/src/mge/network_impl.h
@@ -201,6 +201,10 @@ private:
     //! compile the graph to get the execute function
     void compile_graph();
 
+    //! try to infer output tensor layout
+    void try_infer_tensor_layout(
+            std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var);
+
 private:
     bool m_async = false;
     bool m_is_cpu_inplace_mode = false;
diff --git a/lite/src/misc.cpp b/lite/src/misc.cpp
index da7ed6e2..186f616c 100644
--- a/lite/src/misc.cpp
+++ b/lite/src/misc.cpp
@@ -102,6 +102,8 @@ LiteLogLevel lite::get_log_level() {
 }
 
 std::string lite::ssprintf(const char* format, ...) {
+    if (!format)
+        return "";
     va_list ap;
     va_start(ap, format);
     auto ret = svsprintf(format, ap);
@@ -110,6 +112,8 @@ std::string lite::ssprintf(const char* format, ...) {
 }
 
 void lite::print_log(LiteLogLevel level, const char* format, ...) {
+    if (!format)
+        return;
     if (static_cast<int>(level) < static_cast<int>(get_log_level())) {
         return;
     }
diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index 6b1d08b1..9ddfa48b 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -90,6 +90,11 @@ TEST(TestNetWork, GetAllName) {
     auto input_names = network->get_all_input_name();
     auto output_names = network->get_all_output_name();
 
+    auto output_tensor = network->get_output_tensor(0);
+    auto out_layout = output_tensor->get_layout();
+    ASSERT_EQ(out_layout.ndim, 2);
+    ASSERT_EQ(out_layout.shapes[0], 1);
+    ASSERT_EQ(out_layout.shapes[1], 1000);
     ASSERT_EQ(input_names.size(), 1);
     ASSERT_EQ(output_names.size(), 1);
     ASSERT_TRUE(input_names[0] == "data");
diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h
index f8784a84..63cd5f01 100644
--- a/src/core/include/megbrain/graph/var_node.h
+++ b/src/core/include/megbrain/graph/var_node.h
@@ -488,6 +488,13 @@ public:
      */
     MemAllocPlan& init_mem_plan(const DeviceTensorND* fixed_alloc = nullptr);
 
+    /*!
+     * \brief get the shape and value infer trait
+     */
+    const std::tuple<void*, void*>& get_static_infer_trait() {
+        return m_static_infer_trait;
+    }
+
 private:
     //! whether its memory should be allocated by mgb system during graph
     //! execution; initialized in VarNodeMemManager::reset_opr_seq()
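
With this change an output tensor's layout is filled in from the graph's
static shape inference at load time, so callers can read the output shape
(and pre-allocate buffers) before the first forward() instead of after it.
A minimal pylite sketch of the intended flow; the entry points used here
(LiteNetwork.load, get_all_output_name, get_io_tensor) follow the existing
megenginelite API, and "model.mge" is a placeholder path:

    from megenginelite import LiteNetwork

    network = LiteNetwork()
    network.load("./model.mge")  # placeholder model path

    # The output layout is already deduced here, before any forward() call.
    out_name = network.get_all_output_name()[0]
    out_tensor = network.get_io_tensor(out_name)
    print(out_tensor.layout)  # e.g. ndim=2, shapes [1, 1000] for a classifier

    # Models with data-dependent (dynamic) output shapes fail the deduction:
    # the C++ side emits a LITE_WARN and the layout stays unset until after
    # forward()/wait(), as before.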
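
The tensor.py changes are user-visible as well: LITE_UINT16 now round-trips
with np.uint16, c_ushort no longer mis-maps to LITE_INT16, and
LiteLayout.__repr__ prints only the first ndim entries of the fixed-size
shapes array. A small sketch, assuming LiteLayout accepts a shape sequence
plus a numpy dtype name as the tables above suggest:

    from megenginelite import LiteDataType, LiteLayout

    layout = LiteLayout([1, 1000], "uint16")
    assert layout.data_type == LiteDataType.LITE_UINT16

    # __repr__ now slices the ctypes shapes array to ndim entries, so the
    # trailing zeros of the fixed-size array are no longer printed:
    print(layout)  # roughly: shapes [1, 1000], ndim 2, data_type uint16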