diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp
index e6a5a1ff..073fbcec 100644
--- a/src/core/impl/graph/var_node.cpp
+++ b/src/core/impl/graph/var_node.cpp
@@ -576,10 +576,10 @@ VarNode& VarNode::add_flag(Flag flag) {
 void VarNode::modify_flag(Flag delta, Flag new_flag) {
     if (contain_flag(Flag::FLAG_FREEZED)) {
-        mgb_assert((delta & (
-                Flag::NO_MEM_RECLAIM |
-                Flag::NO_SYS_STATIC_MEM_ALLOC |
-                Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
+        mgb_assert(
+                (delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
+                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
+                (new_flag & Flag::MEMORY_NO_NEED));
 
         mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
                 var_node_mem_manager().optimize_started(),
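Sketch (not part of the patch, hypothetical helper): the relaxed assertion above allows MEMORY_NO_NEED to be added even after a var's flags are frozen, which is what the preprocess hooks further below do at execution time:

    // hypothetical helper; mirrors what scn_do_execute_preprocess() does below
    void mark_weight_no_need(mgb::cg::VarNode* var) {
        // FLAG_FREEZED is present once the graph is compiled; before this
        // patch, adding a flag other than the three mem-alloc flags asserted
        var->add_flag(mgb::cg::VarNode::Flag::MEMORY_NO_NEED);
    }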
diff --git a/src/core/impl/graph/var_node_mem_mgr.cpp b/src/core/impl/graph/var_node_mem_mgr.cpp
index 0127a00b..8b612c9d 100644
--- a/src/core/impl/graph/var_node_mem_mgr.cpp
+++ b/src/core/impl/graph/var_node_mem_mgr.cpp
@@ -24,6 +24,8 @@
 #include "megbrain/utils/timer.h"
 #include "megbrain/utils/arith_helper.h"
 
+#include "megbrain/opr/io.h"
+
 #include

 using namespace mgb;

@@ -36,7 +38,6 @@ void call_mem_status_changed(cg::OperatorNodeBase* opr) {
     if (cb.on_mem_status_changed.valid())
         cb.on_mem_status_changed.val()();
 }
-
 } // namespace

 /* ==================== StaticDeviceMemoryManager ==================== */

@@ -393,11 +394,12 @@ bool VarNodeMemManager::alloc_var_node_mem_static() {

 bool VarNodeMemManager::update_static_alloc_plan() {
     // check whether unchanged
+    bool free_no_need_memory = free_combine_memory_no_need_var();
     if (!m_owner_graph->static_infer_comp_seq_manager()
                  .update_static_check_shape_change() &&
         !m_first_static_plan_run &&
         !m_impure_mem_plan_mgr.check_need_realloc()) {
-        return false;
+        return free_no_need_memory;
     }

     if (m_first_static_plan_run)

@@ -494,6 +496,96 @@ bool VarNodeMemManager::make_static_var_tensor_from_alloc_plan() {
     return true;
 }

+bool VarNodeMemManager::free_combine_memory_no_need_var() {
+    if (!m_owner_graph->options().graph_opt.weight_preprocess ||
+        m_already_free_no_need_mem) {
+        return false;
+    }
+    bool reordered = false;
+    //! free storage that is no longer needed
+    for (auto opr : *m_opr_seq) {
+        if (opr->try_cast_final<opr::SharedDeviceTensor>() ||
+            opr->try_cast_final<opr::SharedDeviceTensorWithFormat>()) {
+            auto opr_base =
+                    static_cast<opr::intl::SharedDeviceTensorBase*>(opr);
+            auto var = opr_base->output(0);
+            if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
+                var->dev_tensor_valid() && !var->dev_tensor().empty()) {
+                //! the tensor can be freed only if its share count is 1
+                if (opr_base->dev_data().use_count() == 1) {
+                    auto layout = var->layout();
+                    var->m_dev_tensor.reset(
+                            DeviceTensorStorage{var->comp_node()}, layout);
+                    opr_base->free_dev_data();
+                    mgb_log_debug(
+                            "preprocessed weight is freed, var name = %s, "
+                            "var layout = %s",
+                            var->name().c_str(), layout.to_string().c_str());
+                }
+                m_already_free_no_need_mem = true;
+            }
+        }
+        bool memory_need_reorder = false;
+        if (opr->try_cast_final<opr::MultipleDeviceTensorHolder>() ||
+            opr->try_cast_final<opr::MultipleDeviceTensorWithFormatHolder>()) {
+            auto opr_base =
+                    static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
+                            opr);
+            for (size_t index = 0; index < opr_base->output().size(); index++) {
+                auto var = opr_base->output(index);
+                if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
+                    var->dev_tensor_valid() && !var->dev_tensor().empty()) {
+                    //! the tensor can be freed only if its share count is 1
+                    if (opr_base->values()[index].use_count() == 1) {
+                        auto layout = var->layout();
+                        var->m_dev_tensor.reset(
+                                DeviceTensorStorage{var->comp_node()}, layout);
+                        opr_base->mutable_values()[index]->reset(
+                                DeviceTensorStorage{var->comp_node()}, layout);
+                        memory_need_reorder = true;
+                        mgb_log_debug(
+                                "preprocessed weight is freed, var name "
+                                "= %s, var layout = %s",
+                                var->name().c_str(),
+                                layout.to_string().c_str());
+                    }
+                    m_already_free_no_need_mem = true;
+                }
+            }
+        }
+        //! reorder the other, still needed outputs, because they share the
+        //! same chunk of device memory with the freed vars, see
+        //! BatchedDeviceValueLoader
+        if (memory_need_reorder) {
+            auto opr_base =
+                    static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
+                            opr);
+            auto comp_node = opr_base->output(0)->comp_node();
+            bool is_device_opr =
+                    comp_node.mem_node() != CompNode::default_cpu().mem_node();
+            if (is_device_opr) {
+                for (size_t index = 0; index < opr_base->output().size();
+                     index++) {
+                    auto var = opr_base->output(index);
+                    if (!var->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
+                        DeviceTensorStorage storage(var->comp_node());
+                        size_t size = var->layout().span().dist_byte();
+                        storage.ensure_size(size);
+                        storage.copy_from(var->m_dev_tensor.storage(), size);
+
+                        var->m_dev_tensor.reset(storage, var->layout());
+                        opr_base->mutable_values()[index]->reset(storage,
+                                                                 var->layout());
+                        reordered = true;
+                    }
+                }
+                //! sync to make sure the memory copy has finished
+                comp_node.sync();
+            }
+        }
+    }
+    return reordered;
+}
+
 void VarNodeMemManager::init_dynamic_alloc_opr_info() {
     mgb_assert(m_first_static_plan_run);
     m_need_post_exec_action_vars.clear();
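For illustration (not part of the patch): free_combine_memory_no_need_var() is called at the top of update_static_alloc_plan(), so the intended lifecycle takes two executions — the first runs weight preprocessing and flags the weights, the next plan update frees them. A minimal sketch mirroring the tests added below:

    graph->options().graph_opt.weight_preprocess = true;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();  // preprocess weights; unused inputs get MEMORY_NO_NEED
    func->execute();  // update_static_alloc_plan() frees the flagged storage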
diff --git a/src/core/impl/graph/var_node_mem_mgr.h b/src/core/impl/graph/var_node_mem_mgr.h
index 3dbeacb6..271ab17b 100644
--- a/src/core/impl/graph/var_node_mem_mgr.h
+++ b/src/core/impl/graph/var_node_mem_mgr.h
@@ -174,6 +174,14 @@ class VarNodeMemManager {
     bool alloc_var_node_mem_static();

     /*!
+     * \brief free the memory of vars with the MEMORY_NO_NEED flag
+     *
+     * \return whether the memory of a MEMORY_NO_NEED var, or of another var
+     *      sharing its storage chunk, has changed
+     */
+    bool free_combine_memory_no_need_var();
+
+    /*!
      * \brief initialize static memory allocation plan
      *
      * This can be used with custom StaticDeviceMemoryAllocator so static

@@ -407,7 +415,8 @@ class VarNodeMemManager {
         bool check_need_realloc();
     };

-    bool m_first_static_plan_run = true, m_optimize_started = false;
+    bool m_first_static_plan_run = true, m_optimize_started = false,
+         m_already_free_no_need_mem = false;
     ComputingGraphImpl *m_owner_graph;
     ThinHashMap m_node_mem_trait;
     NullableHashMap
diff --git a/src/core/impl/tensor.cpp b/src/core/impl/tensor.cpp
index 1ef2ab7d..d811056e 100644
--- a/src/core/impl/tensor.cpp
+++ b/src/core/impl/tensor.cpp
@@ -449,7 +449,11 @@ DEF(resize, &)(const TensorShape& shape) {
 }

 DEF(reset, &)(TensorStorage storage, const TensorLayout &layout) {
-    mgb_assert(!layout.ndim || storage.valid_span(layout.span()));
+    //! The storage to be reset must either satisfy the layout or be empty.
+    //! Empty storage occurs after weight preprocess frees a weight: the
+    //! memory is released while the layout is kept for runtime checks.
+    mgb_assert(!layout.ndim || storage.valid_span(layout.span()) ||
+               storage.empty());
     m_storage = std::move(storage);
     m_layout = layout;
     return static_cast<ChainReturnType&>(*this);
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index 27b33565..127d4084 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -98,7 +98,8 @@ struct GraphCommonOptimizeOptions {
     //! whether to enable fast-run profiled winograd opr replace
     bool weight_winograd_transform = false;
     //! whether to enable weight preprocess, if enabled it may use more
-    //! memory, default disable now
+    //! memory; disabled by default for now. When weight preprocess is
+    //! enabled, the input shape must not change
     bool weight_preprocess = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h
index 31e4fef2..cc1c9b07 100644
--- a/src/core/include/megbrain/graph/var_node.h
+++ b/src/core/include/megbrain/graph/var_node.h
@@ -589,7 +589,7 @@ class VarNode final: public GraphNodeBase {
     friend class imperative::ProxyGraph;
 };

-enum class VarNode::Flag: uint32_t {
+enum class VarNode::Flag : uint32_t {
     //! do not allocate memory by the system allocator even if shape could be
     //! inferred
     NO_SYS_MEM_ALLOC = 1 << 0,
@@ -667,6 +667,12 @@
      * after FLAG_FREEZED is present.
      */
     FLAG_FREEZED = 1 << 10,
+
+    /*!
+     * this flag indicates that the data of this var has been processed and is
+     * no longer needed, so its memory can be freed; used by weight preprocess
+     * to save memory
+     */
+    MEMORY_NO_NEED = 1 << 11,
 };
 MGB_DEF_ENUM_CLASS_BIT_OPR(VarNode::Flag)
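Sketch (not part of the patch, assuming a CompNode cn): the relaxed reset() contract above — empty storage plus a valid layout — is what free_dev_data() in io.h below relies on:

    DeviceTensorND t{cn, TensorShape{32, 32, 1, 1}};
    // release the storage but keep the layout for runtime shape/layout checks
    t.reset(DeviceTensorStorage{cn}, t.layout());
    mgb_assert(t.empty() && t.layout().ndim == 4);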
diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp
index 04cfb43f..3655587f 100644
--- a/src/core/test/graph/misc.cpp
+++ b/src/core/test/graph/misc.cpp
@@ -1920,4 +1920,236 @@ TEST(TestGraph, NaiveRecord2NCHW44) {
     func->execute().wait();
 }

+namespace {
+template <typename DnnOp, typename... Args>
+typename DnnOp::Algorithm* try_find_any_weight_preprocess_algo(
+        DnnOp* dnn_op, const char* mgb_info, Maybe<bool>& found,
+        Args&& ...args) {
+    if (found.valid()) {
+        if (found.val()) {
+            return dnn_op->execution_policy().algorithm;
+        } else {
+            return nullptr;
+        }
+    }
+    for (auto&& algo : dnn_op->get_all_algorithms(
+                std::forward<Args>(args)...)) {
+        dnn_op->execution_policy().algorithm = algo;
+        auto layouts = dnn_op->deduce_preprocessed_filter_layout(
+                std::forward<Args>(args)...);
+        if (layouts.empty()) continue;
+        bool valid = false;
+        for (auto&& l : layouts) {
+            if (!l.is_empty()) {
+                valid = true;
+                break;
+            }
+        }
+        if (valid) {
+            found.emplace(true);
+            return algo;
+        }
+    }
+    found.emplace(false);
+    mgb_log_warn("Can't find weight preprocess algo for op %s", mgb_info);
+    return nullptr;
+}
+
+void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt.weight_preprocess = true;
+    graph->options().comp_node_seq_record_level = record_level;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+    auto x = mkvar("x", {1, 32, 16, 16});
+    // ConvBias test dense
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 0;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1});
+    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
+    Maybe<bool> wp1, wp2;
+    conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>()
+            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
+                return try_find_any_weight_preprocess_algo(
+                        opr->cast_final_safe<opr::ConvBias>().megdnn_opr(),
+                        opr->cname(), wp1,
+                        opr->input(0)->layout(), opr->input(1)->layout(),
+                        opr->input(2)->layout(), TensorLayout{},
+                        opr->output(0)->layout());
+            });
+    // Convolution
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 0;
+    param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
+    auto w2 = mkcvar("w2", {32, 32, 1, 1});
+    auto y = opr::Convolution::make(conv1, w2, param_conv);
+    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
+            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
+                return try_find_any_weight_preprocess_algo(
+                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
+                        opr->cname(), wp2,
+                        opr->input(0)->layout(), opr->input(1)->layout(),
+                        opr->output(0)->layout());
+            });
+
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+    //! the first execution flags vars whose memory is no longer needed
+    func->execute();
+    //! the second execution frees the flagged memory
+    func->execute();
+    auto check = [&](SymbolVar v) {
+        ASSERT_TRUE(v.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
+        ASSERT_TRUE(v.node()->dev_tensor().empty());
+        ASSERT_TRUE(v.node()->owner_opr()
+                            ->cast_final_safe<opr::SharedDeviceTensor>()
+                            .get_dev_tensor()
+                            .empty());
+    };
+    ASSERT_TRUE(wp1.valid() && wp2.valid());
+    if (wp1.val()) {
+        check(w1);
+    }
+    if (wp2.val()) {
+        check(w2);
+    }
+}
+} // anonymous namespace
+
+TEST(TestGraph, FreeMemoryInWeightPreprocess) {
+    test_free_memory_in_weight_preprocess(0, CompNode::load("xpu0"));
+}
+
+TEST(TestGraph, RecordFreeMemoryInWeightPreprocess) {
+    test_free_memory_in_weight_preprocess(1, CompNode::load("cpu0"));
+}
+
+namespace {
+MGB_DEFINE_OPR_CLASS(HostValueReader, cg::SingleCNOutshapePureByInshapeOprBase) // {
+    void scn_do_execute() override {
+        auto&& hv = owner_graph()->static_infer_manager().infer_value(input(0));
+        MGB_MARK_USED_VAR(hv);
+    }
+
+    NodeProp* do_make_node_prop() const override {
+        auto ret = Super::do_make_node_prop();
+        ret->dep_map()[input(0)] = NodeProp::DepType::HOST_VALUE;
+        return ret;
+    }
+
+    void get_output_var_shape(
+            const TensorShapeArray &,
+            TensorShapeArray &out_shape) const override {
+        out_shape.at(0) = {};
+    }
+
+public:
+    HostValueReader(VarNode* inp)
+            : Super{inp->owner_graph(), {}, "host_value_reader", {inp}} {
+        add_input({inp});
+        using F = VarNode::Flag;
+        add_output(None)
+                ->add_flag(F::ALLOW_EMPTY_SHAPE)
+                .add_flag(F::VOLATILE_CONTENT);
+    }
+
+    static SymbolVar make(SymbolVar inp) {
+        return inp.node()->owner_graph()->insert_opr(
+                std::make_unique<HostValueReader>(inp.node()))->output(0);
+    }
+};
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(HostValueReader);
+} // anonymous namespace
mkcvar("w", {32, 32, 1, 1}); + auto y = opr::Convolution::make(x, w); + Maybe found; + y.node()->owner_opr()->cast_final_safe() + .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) { + return try_find_any_weight_preprocess_algo( + opr->cast_final_safe().megdnn_opr(), + opr->cname(), found, + opr->input(0)->layout(), opr->input(1)->layout(), + opr->output(0)->layout()); + }); + auto reader = HostValueReader::make(w); + + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y), {reader, {}}}); + func->execute(); + // FIXME: failed on second execution due to requiring host value of the empty + // tensor which was freed in weight preprocess + func->execute(); + ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED)); + ASSERT_FALSE(w.node()->dev_tensor().empty()); + ASSERT_FALSE(w.node()->owner_opr() + ->cast_final_safe() + .get_dev_tensor() + .empty()); +} + +TEST(TestGraph, FreeMemoryInWeightPreprocessWithMultiReader) { + HostTensorGenerator<> gen; + CompNode cn = CompNode::load("xpux"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt.weight_preprocess = true; + graph->options().var_sanity_check_first_run = false; + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn)) + .rename(name); + }; + auto x = mkvar("x", {1, 32, 16, 16}); + auto w = mkcvar("w", {32, 32, 1, 1}); + auto y = opr::Convolution::make(x, w); + Maybe found; + y.node()->owner_opr()->cast_final_safe() + .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) { + return try_find_any_weight_preprocess_algo( + opr->cast_final_safe().megdnn_opr(), + opr->cname(), found, + opr->input(0)->layout(), opr->input(1)->layout(), + opr->output(0)->layout()); + }); + auto y1 = w * 2 + 1; + + HostTensorND host_y, host_y1; + auto func = graph->compile({ + make_callback_copy(y, host_y), make_callback_copy(y1, host_y1)}); + func->execute(); + // FIXME: failed on second execution due to calculate expression + // (w * 2 + 1) with empty tensor + func->execute(); + ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED)); + ASSERT_FALSE(w.node()->dev_tensor().empty()); + ASSERT_FALSE(w.node()->owner_opr() + ->cast_final_safe() + .get_dev_tensor() + .empty()); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 659956e3..2188e5ee 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -138,39 +138,36 @@ public: void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter( cg::OperatorNodeBase& opr) { - if (!mixin_allow_weight_preprocess(opr)) + if (!mixin_allow_weight_preprocess(opr)) { return; - + } auto new_layout = deduce_preprocessed_filter_layout(); + size_t new_size = new_layout.size(); + //! No preprocess layout means no need weight preprocess if (new_layout.empty()) { - // Weight preprocess was needed before, but no longer needed. - if (m_preprocessed_filter) { - m_preprocessed_filter.reset(); - m_filter_storage.clear(); + return; + } + //! 
diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp
index 659956e3..2188e5ee 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -138,39 +138,36 @@ public:

 void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter(
         cg::OperatorNodeBase& opr) {
-    if (!mixin_allow_weight_preprocess(opr))
+    if (!mixin_allow_weight_preprocess(opr)) {
         return;
-
+    }
     auto new_layout = deduce_preprocessed_filter_layout();
+    size_t new_size = new_layout.size();
+    //! an empty layout array means weight preprocess is not needed
     if (new_layout.empty()) {
-        // Weight preprocess was needed before, but no longer needed.
-        if (m_preprocessed_filter) {
-            m_preprocessed_filter.reset();
-            m_filter_storage.clear();
+        return;
+    }
+    //! if all layouts are empty, weight preprocess is not needed either
+    bool layout_valid = false;
+    for (auto&& layout : new_layout) {
+        if (!layout.is_empty()) {
+            layout_valid = true;
         }
+    }
+    if (!layout_valid) {
         return;
     }
-    bool should_update = false;
-    size_t new_size = new_layout.size();
-    if (!m_preprocessed_filter ||
-        m_preprocessed_filter->tensors.size() != new_size) {
-        should_update = true;
-    } else {
+    if (m_preprocessed_filter) {
         for (size_t i = 0; i < new_size; i++) {
-            if (!new_layout[i].eq_layout(
-                        m_preprocessed_filter->tensors[i].layout)) {
-                should_update = true;
-                break;
-            }
+            mgb_assert(new_layout[i].eq_layout(
+                               m_preprocessed_filter->tensors[i].layout),
+                       "weight preprocess layout changed, please keep input "
+                       "shape unchanged when weight preprocess is enabled");
         }
-    }
-    if (!should_update) return;
-
-    if (!m_preprocessed_filter) {
-        m_preprocessed_filter.reset(new PreprocessedFilter{});
     }
+    m_preprocessed_filter.reset(new PreprocessedFilter{});
     m_preprocessed_filter->tensors.resize(new_size);
     m_filter_storage.resize(new_size);
     m_preprocessed_filter->algorithm_id = nullptr;
@@ -327,6 +324,14 @@ void ConvolutionForward::scn_do_execute_preprocess() {
             input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
             output(0)->layout(), preprocessed_filter(),
             intl::get_megdnn_workspace_from_var(output().back()));
+    //! Flag input(1) as unused from now on; it can be freed when no other
+    //! var depends on its dev_value, host_value or shape.
+    auto receiver_info =
+            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
+    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
+        receiver_info.shape == 0) {
+        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
+    }
 }

 /* ==================== ConvolutionBackwardData ==================== */
@@ -959,6 +964,14 @@ void ConvBiasForward::scn_do_execute_preprocess() {
             input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
             bias_layout, z_layout, output(0)->layout(), preprocessed_filter(),
             intl::get_megdnn_workspace_from_var(output().back()));
+    //! Flag input(1) as unused from now on; it can be freed when no other
+    //! var depends on its dev_value, host_value or shape.
+    auto receiver_info =
+            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
+    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
+        receiver_info.shape == 0) {
+        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
+    }
 }

 /* ===================== LocalShareForward ==================== */
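For illustration (not part of the patch): the receiver check above frees a weight only when the preprocessing opr is its sole consumer. With a second reader the guard rejects it, as the MultiReader test above exercises:

    auto y  = opr::Convolution::make(x, w);  // dev_value reader #1
    auto y1 = w * 2 + 1;                     // dev_value reader #2
    // receiver_info.dev_value == 2, so MEMORY_NO_NEED is never added to w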
diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp
index 6bf35a34..ee0da3e5 100644
--- a/src/opr/impl/io.cpp
+++ b/src/opr/impl/io.cpp
@@ -142,8 +142,10 @@ void intl::DeviceTensorHolder::add_output(DType dtype) {
 }

 void intl::DeviceTensorHolder::record_execute_deps(ExecDependencyArray& deps) {
-    deps.emplace_back(
-            std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
+    if (!output(0)->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
+        deps.emplace_back(
+                std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
+    }
 }

 /* ===================== Host2DeviceCopy ===================== */
@@ -801,14 +803,19 @@ class intl::MultipleDeviceTensorHolderBase::DevValuesExecDep final
     SmallVector<DeviceTensorStorage> m_vals;

 public:
-    explicit DevValuesExecDep(const ValueArray& vals) {
-        for (auto&& val : vals) {
-            m_vals.emplace_back(std::move(val->storage()));
+    explicit DevValuesExecDep(const ValueArray& vals,
+                              MultipleDeviceTensorHolderBase* opr) {
+        mgb_assert(vals.size() == opr->output().size(),
+                   "the number of values differs from the number of output "
+                   "vars");
+        for (size_t index = 0; index < vals.size(); index++) {
+            if (!opr->output(index)->contain_flag(
+                        VarNode::Flag::MEMORY_NO_NEED)) {
+                m_vals.emplace_back(std::move(vals[index]->storage()));
+            }
         }
     }
 };
-
 intl::MultipleDeviceTensorHolderBase::MultipleDeviceTensorHolderBase(
         ComputingGraph& graph, ValueArray values,
         const OperatorNodeConfig& config)
@@ -887,8 +894,7 @@ intl::MultipleDeviceTensorHolderBase::do_make_node_prop() const {

 void intl::MultipleDeviceTensorHolderBase::record_execute_deps(
         ExecDependencyArray& deps) {
-    deps.emplace_back(
-            std::make_unique<DevValuesExecDep>(values()));
+    deps.emplace_back(std::make_unique<DevValuesExecDep>(values(), this));
 }

 /* ===================== MultipleDeviceTensorHolder ===================== */
diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp
index 5065ff61..293f1540 100644
--- a/src/opr/impl/search_policy/algo_chooser.cpp
+++ b/src/opr/impl/search_policy/algo_chooser.cpp
@@ -173,9 +173,15 @@ size_t AlgoChooser<Opr>::setup_algo(const ConvTensorLayouts& layouts,
         return 0;
     }

+    ImplAlgo algo = nullptr;
     ExeContext ctx(layouts, megdnn_opr, mgb_opr, allow_weight_preprocess);

-    auto algo = get_algo(ctx);
+    if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
+        algo = algo_choose_hook(mgb_opr);
+    }
+    if (!algo) {
+        algo = get_algo(ctx);
+    }
     size_t workspace = ctx.get_workspace_size_bytes(algo);
     mgb_log_debug(
             "%s: tensor layouts(%s %s, %s %s) -> (%s %s): algo=%s "
@@ -360,16 +366,29 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
     if (!m_allow_weight_preprocess)
         return;
     auto opr = _(m_megdnn_opr);
-    auto layout = APPLY(opr->deduce_preprocessed_filter_layout(args...),
-                        m_layouts);
-    if (layout.empty())
+    auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
+                         m_layouts);
+    //! an empty layout array means weight preprocess is not needed
+    if (layouts.empty()) {
         return;
+    }
+    //! if all layouts are empty, weight preprocess is not needed either
+    bool layout_valid = false;
+    for (auto&& layout : layouts) {
+        if (!layout.is_empty()) {
+            layout_valid = true;
+        }
+    }
+    if (!layout_valid) {
+        return;
+    }
+
     result = PreprocessFilter<Opr>{};
     auto& res = result.val();
     res.algorithm_id = nullptr;
-    res.tensors.resize(layout.size());
-    for (size_t i = 0; i < layout.size(); i++) {
-        res.tensors[i] = megdnn::TensorND(nullptr, layout[i]);
+    res.tensors.resize(layouts.size());
+    for (size_t i = 0; i < layouts.size(); i++) {
+        res.tensors[i] = megdnn::TensorND(nullptr, layouts[i]);
     }
     });
     return result;
diff --git a/src/opr/include/megbrain/opr/dnn/convolution.h b/src/opr/include/megbrain/opr/dnn/convolution.h
index 0f5afa71..b3071817 100644
--- a/src/opr/include/megbrain/opr/dnn/convolution.h
+++ b/src/opr/include/megbrain/opr/dnn/convolution.h
@@ -25,6 +25,9 @@ namespace mixin {
 class Convolution {
 public:
     using ExecutionPolicy = megdnn::param::ExecutionPolicy;
+    using Algorithm = megdnn::detail::Algorithm;
+    using AlgoChooserHook =
+            std::function<Algorithm*(const cg::OperatorNodeBase*)>;

     const ExecutionPolicy& execution_policy() const {
         if (!m_policy_accessed) {
@@ -55,6 +58,16 @@ class Convolution {

     virtual std::pair param_blob() const = 0;

+    /*!
+     * \brief register a hook to implement a custom algo chooser
+     */
+    void setup_algo_chooser(AlgoChooserHook&& func) {
+        m_algo_chooser = func;
+    }
+    AlgoChooserHook algo_chooser() const {
+        return m_algo_chooser;
+    }
+
 protected:
     ~Convolution();

@@ -63,6 +76,8 @@ class Convolution {

     std::unique_ptr m_profile_cache;

+    AlgoChooserHook m_algo_chooser;
+
     virtual void init_profile_cache() = 0;

     //! init output desc for conv backward data oprs; it handles both grad
diff --git a/src/opr/include/megbrain/opr/io.h b/src/opr/include/megbrain/opr/io.h
index 8c842efa..d6b4c1ec 100644
--- a/src/opr/include/megbrain/opr/io.h
+++ b/src/opr/include/megbrain/opr/io.h
@@ -99,6 +99,11 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
         return *m_dev_data;
     }

+    void free_dev_data() {
+        m_dev_data->reset(DeviceTensorStorage{m_dev_data->comp_node()},
+                          m_dev_data->layout());
+    }
+
     const std::shared_ptr<DeviceTensorND>& dev_data() const {
         return m_dev_data;
     }
@@ -122,6 +127,10 @@ public:
             const OperatorNodeConfig& config);

     const ValueArray& values() const { return m_values; }
+    ValueArray& mutable_values() {
+        return m_values;
+    }
+
 protected:
     ValueArray m_values;

@@ -292,7 +301,7 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
     static SymbolVar make_const(ComputingGraph& graph,
                                 const HostTensorND& value,
                                 const OperatorNodeConfig& config = {}) {
-        return make(graph, value, false, config);
+        return make(graph, value, true, config);
     }
 };
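Usage sketch of the new setup_algo_chooser() hook (not part of the patch; distilled from the tests above, with `found` being a Maybe<bool> declared by the caller). Returning nullptr from the hook falls back to the regular get_algo() path:

    auto&& conv = y.node()->owner_opr()->cast_final_safe<opr::Convolution>();
    conv.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
        // force an algorithm that has a non-empty preprocessed filter layout
        return try_find_any_weight_preprocess_algo(
                opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
                opr->cname(), found,
                opr->input(0)->layout(), opr->input(1)->layout(),
                opr->output(0)->layout());
    });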