@@ -454,9 +454,8 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) { | |||||
} | } | ||||
void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) { | void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) { | ||||
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); | |||||
auto infer_trait = var.node()->get_static_infer_trait(); | |||||
if (std::get<0>(infer_trait)) { | |||||
if (var.node()->capable_shape_infer()) { | |||||
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); | |||||
auto shape = static_infer_mgr.infer_shape_fallible(var.node()); | auto shape = static_infer_mgr.infer_shape_fallible(var.node()); | ||||
if (!shape) { | if (!shape) { | ||||
LITE_WARN( | LITE_WARN( | ||||
@@ -101,6 +101,21 @@ TEST(TestNetWork, GetAllName) { | |||||
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | ||||
} | } | ||||
TEST(TestNetWork, LoadFBSModel) { | |||||
Config config; | |||||
std::string model_path = "./ax.mge"; | |||||
std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
network->load_model(model_path); | |||||
auto output_tensor = network->get_output_tensor(0); | |||||
auto out_layout = output_tensor->get_layout(); | |||||
ASSERT_EQ(out_layout.ndim, 4); | |||||
ASSERT_EQ(out_layout.shapes[0], 1); | |||||
ASSERT_EQ(out_layout.shapes[1], 1); | |||||
ASSERT_EQ(out_layout.shapes[2], 40); | |||||
ASSERT_EQ(out_layout.shapes[3], 180); | |||||
} | |||||
TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) { | TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) { | ||||
Config config; | Config config; | ||||
auto lite_tensor = get_input_data("./input_data.npy"); | auto lite_tensor = get_input_data("./input_data.npy"); | ||||
@@ -892,6 +892,16 @@ StaticInferManagerImpl::TagHandler* StaticInferManagerImpl::get_tag_handler_for_ | |||||
return c.value; | return c.value; | ||||
} | } | ||||
bool StaticInferManagerImpl::has_shape_infer(Tag tag) const { | |||||
auto&& c = get_tag_trait_container(tag); | |||||
return c.shape != nullptr; | |||||
} | |||||
bool StaticInferManagerImpl::has_value_infer(Tag tag) const { | |||||
auto&& c = get_tag_trait_container(tag); | |||||
return c.value != nullptr; | |||||
} | |||||
StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep( | StaticInferManagerImpl::TagTraitBase* StaticInferManagerImpl::get_tag_trait_for_dep( | ||||
const DepElement& dep) { | const DepElement& dep) { | ||||
TagHandler* ret; | TagHandler* ret; | ||||
@@ -66,6 +66,16 @@ public: | |||||
MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag); | MGE_WIN_DECLSPEC_FUC TagHandler* get_tag_handler_for_value(Tag tag); | ||||
/*! | /*! | ||||
* \brief check if there is a registered shape infer func in tag | |||||
*/ | |||||
bool has_shape_infer(Tag tag) const; | |||||
/*! | |||||
* \brief check if there is a registered value infer func in tag | |||||
*/ | |||||
bool has_value_infer(Tag tag) const; | |||||
/*! | |||||
* \brief clear registered handler for a tag; this is only used in error | * \brief clear registered handler for a tag; this is only used in error | ||||
* handling in opr creation | * handling in opr creation | ||||
*/ | */ | ||||
@@ -578,6 +578,18 @@ bool VarNode::is_graph_dest_varnode() { | |||||
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; | return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; | ||||
} | } | ||||
bool VarNode::capable_shape_infer() { | |||||
auto&& mgr = | |||||
ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); | |||||
return mgr.has_shape_infer(this); | |||||
} | |||||
bool VarNode::capable_value_infer() { | |||||
auto&& mgr = | |||||
ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); | |||||
return mgr.has_value_infer(this); | |||||
} | |||||
VarNode& VarNode::add_flag(Flag flag) { | VarNode& VarNode::add_flag(Flag flag) { | ||||
modify_flag(flag, m_flag | flag); | modify_flag(flag, m_flag | flag); | ||||
return *this; | return *this; | ||||
@@ -495,11 +495,14 @@ public: | |||||
const DeviceTensorND* fixed_alloc = nullptr); | const DeviceTensorND* fixed_alloc = nullptr); | ||||
/*! | /*! | ||||
* \brief get the shape and value infer trait | |||||
* \brief check infer shape capablity by check m_static_infer_trait's shape infer | |||||
*/ | */ | ||||
const std::tuple<void*, void*>& get_static_infer_trait() { | |||||
return m_static_infer_trait; | |||||
} | |||||
MGE_WIN_DECLSPEC_FUC bool capable_shape_infer(); | |||||
/*! | |||||
* \brief check infer shape capablity by check m_static_infer_trait's value infer | |||||
*/ | |||||
MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); | |||||
private: | private: | ||||
//! whether its memory should be allocated by mgb system during graph | //! whether its memory should be allocated by mgb system during graph | ||||