diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 1cfde0fa..ae106921 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) { } void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr tensor, Var var) { - auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); - auto infer_trait = var.node()->get_static_infer_trait(); - if (std::get<0>(infer_trait)) { + if (var.node()->capable_shape_infer()) { + auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); auto shape = static_infer_mgr.infer_shape_fallible(var.node()); if (!shape) { LITE_WARN( diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index f786b92c..8418daf5 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) { ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); } +TEST(TestNetWork, LoadFBSModel) { + Config config; + std::string model_path = "./ax.mge"; + std::shared_ptr network = std::make_shared(config); + network->load_model(model_path); + + auto output_tensor = network->get_output_tensor(0); + auto out_layout = output_tensor->get_layout(); + ASSERT_EQ(out_layout.ndim, 4); + ASSERT_EQ(out_layout.shapes[0], 1); + ASSERT_EQ(out_layout.shapes[1], 1); + ASSERT_EQ(out_layout.shapes[2], 40); + ASSERT_EQ(out_layout.shapes[3], 180); +} + TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) { Config config; auto lite_tensor = get_input_data("./input_data.npy"); diff --git a/src/core/impl/graph/static_infer_impl.cpp b/src/core/impl/graph/static_infer_impl.cpp index 04773cf9..5aa84a0e 100644 --- a/src/core/impl/graph/static_infer_impl.cpp +++ b/src/core/impl/graph/static_infer_impl.cpp @@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_ return c.value; } +bool StaticInferManagerImpl::has_shape_infer(Tag tag) const { + auto&& c = get_tag_trait_container(tag); + return c.shape != nullptr; +} + +bool StaticInferManagerImpl::has_value_infer(Tag tag) const { + auto&& c = get_tag_trait_container(tag); + return c.value != nullptr; +} + StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep( const DepElement& dep) { TagHandler* ret; diff --git a/src/core/impl/graph/static_infer_impl.h b/src/core/impl/graph/static_infer_impl.h index 572017a0..e7854302 100644 --- a/src/core/impl/graph/static_infer_impl.h +++ b/src/core/impl/graph/static_infer_impl.h @@ -66,6 +66,16 @@ public: MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag); /*! + * \brief check if there is a registered shape infer func in tag + */ + bool has_shape_infer(Tag tag) const; + + /*! + * \brief check if there is a registered value infer func in tag + */ + bool has_value_infer(Tag tag) const; + + /*! * \brief clear registered handler for a tag; this is only used in error * handling in opr creation */ diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp index 833da401..a9ceb364 100644 --- a/src/core/impl/graph/var_node.cpp +++ b/src/core/impl/graph/var_node.cpp @@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() { return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; } +bool VarNode::capable_shape_infer() { + auto&& mgr = + ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); + return mgr.has_shape_infer(this); +} + +bool VarNode::capable_value_infer() { + auto&& mgr = + ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); + return mgr.has_value_infer(this); +} + VarNode& VarNode::add_flag(Flag flag) { modify_flag(flag, m_flag | flag); return *this; diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index 74db600a..b09b4157 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -495,11 +495,14 @@ public: const DeviceTensorND* fixed_alloc = nullptr); /*! - * \brief get the shape and value infer trait + * \brief check infer shape capablity by check m_static_infer_trait's shape infer */ - const std::tuple& get_static_infer_trait() { - return m_static_infer_trait; - } + MGE_WIN_DECLSPEC_FUC bool capable_shape_infer(); + + /*! + * \brief check infer shape capablity by check m_static_infer_trait's value infer + */ + MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); private: //! whether its memory should be allocated by mgb system during graph