Browse Source

refactor(megbrain): refactor try infer tensor layout in lite avoiding using megbrain interface

GitOrigin-RevId: 9799e67102
master
Megvii Engine Team 2 years ago
parent
commit
551cad4955
3 changed files with 4 additions and 24 deletions
  1. +4
    -2
      lite/src/mge/network_impl.cpp
  2. +0
    -12
      src/core/impl/graph/var_node.cpp
  3. +0
    -10
      src/core/include/megbrain/graph/var_node.h

+ 4
- 2
lite/src/mge/network_impl.cpp View File

@@ -660,8 +660,10 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
} }


void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) { void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
if (var.node()->capable_shape_infer()) {
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
using InferType = mgb::cg::static_infer::InferType;
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
if (static_infer_mgr.get_infer_type(var.node()).shape &
(InferType::CONST | InferType::RT_STATIC)) {
auto shape = static_infer_mgr.infer_shape_fallible(var.node()); auto shape = static_infer_mgr.infer_shape_fallible(var.node());
if (!shape) { if (!shape) {
LITE_WARN( LITE_WARN(


+ 0
- 12
src/core/impl/graph/var_node.cpp View File

@@ -596,18 +596,6 @@ bool VarNode::is_graph_dest_varnode() {
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0; return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0;
} }


bool VarNode::capable_shape_infer() {
auto&& mgr =
ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
return mgr.has_shape_infer(this);
}

bool VarNode::capable_value_infer() {
auto&& mgr =
ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
return mgr.has_value_infer(this);
}

VarNode& VarNode::add_flag(Flag flag) { VarNode& VarNode::add_flag(Flag flag) {
modify_flag(flag, m_flag | flag); modify_flag(flag, m_flag | flag);
return *this; return *this;


+ 0
- 10
src/core/include/megbrain/graph/var_node.h View File

@@ -488,16 +488,6 @@ public:
MGE_WIN_DECLSPEC_FUC MemAllocPlan& init_mem_plan( MGE_WIN_DECLSPEC_FUC MemAllocPlan& init_mem_plan(
const DeviceTensorND* fixed_alloc = nullptr); const DeviceTensorND* fixed_alloc = nullptr);


/*!
* \brief check infer shape capablity by check m_static_infer_trait's shape infer
*/
MGE_WIN_DECLSPEC_FUC bool capable_shape_infer();

/*!
* \brief check infer shape capablity by check m_static_infer_trait's value infer
*/
MGE_WIN_DECLSPEC_FUC bool capable_value_infer();

//! whether the var is graph output, if it is output, the Flag of //! whether the var is graph output, if it is output, the Flag of
//! NO_SYS_MEM_ALLOC can be modified. //! NO_SYS_MEM_ALLOC can be modified.
MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode(); MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode();


Loading…
Cancel
Save