GitOrigin-RevId: 5f91acb909
release-1.1
@@ -576,10 +576,10 @@ VarNode& VarNode::add_flag(Flag flag) {
 void VarNode::modify_flag(Flag delta, Flag new_flag) {
     if (contain_flag(Flag::FLAG_FREEZED)) {
-        mgb_assert((delta & (
-                        Flag::NO_MEM_RECLAIM |
-                        Flag::NO_SYS_STATIC_MEM_ALLOC |
-                        Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
+        mgb_assert(
+                (delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
+                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
+                (new_flag & Flag::MEMORY_NO_NEED));
         mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
                     var_node_mem_manager().optimize_started(),
@@ -24,6 +24,8 @@
 #include "megbrain/utils/timer.h"
 #include "megbrain/utils/arith_helper.h"
+#include "megbrain/opr/io.h"
 #include <chrono>
 using namespace mgb;
@@ -36,7 +38,6 @@ void call_mem_status_changed(cg::OperatorNodeBase* opr) {
     if (cb.on_mem_status_changed.valid())
         cb.on_mem_status_changed.val()();
 }
 } // namespace
 /* ==================== StaticDeviceMemoryManager ==================== */
@@ -393,11 +394,12 @@ bool VarNodeMemManager::alloc_var_node_mem_static() {
 bool VarNodeMemManager::update_static_alloc_plan() {
     // check whether unchanged
+    bool free_no_need_memory = free_combine_memory_no_need_var();
     if (!m_owner_graph->static_infer_comp_seq_manager()
                 .update_static_check_shape_change() &&
         !m_first_static_plan_run &&
         !m_impure_mem_plan_mgr.check_need_realloc()) {
-        return false;
+        return false || free_no_need_memory;
     }
     if (m_first_static_plan_run)
@@ -494,6 +496,96 @@ bool VarNodeMemManager::make_static_var_tensor_from_alloc_plan() {
     return true;
 }
+
+bool VarNodeMemManager::free_combine_memory_no_need_var() {
+    if (!m_owner_graph->options().graph_opt.weight_preprocess ||
+        m_already_free_no_need_mem) {
+        return false;
+    }
+    bool reordered = false;
+    //! free storage that is no longer needed
+    for (auto opr : *m_opr_seq) {
+        if (opr->try_cast_final<opr::SharedDeviceTensor>() ||
+            opr->try_cast_final<opr::SharedDeviceTensorWithFormat>()) {
+            auto opr_base = static_cast<opr::intl::SharedDeviceTensorBase*>(opr);
+            auto var = opr_base->output(0);
+            if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
+                var->dev_tensor_valid() && !var->dev_tensor().empty()) {
+                //! the tensor can only be freed when its share count is 1
+                if (opr_base->dev_data().use_count() == 1) {
+                    auto layout = var->layout();
+                    var->m_dev_tensor.reset(
+                            DeviceTensorStorage{var->comp_node()}, layout);
+                    opr_base->free_dev_data();
+                    mgb_log_debug(
+                            "preprocessed weight is freed, var name = %s, "
+                            "var layout = %s",
+                            var->name().c_str(), layout.to_string().c_str());
+                }
+                m_already_free_no_need_mem = true;
+            }
+        }
+        bool memory_need_reorder = false;
+        if (opr->try_cast_final<opr::MultipleDeviceTensorHolder>() ||
+            opr->try_cast_final<opr::MultipleDeviceTensorWithFormatHolder>()) {
+            auto opr_base =
+                    static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
+                            opr);
+            for (size_t index = 0; index < opr_base->output().size(); index++) {
+                auto var = opr_base->output(index);
+                if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
+                    var->dev_tensor_valid() && !var->dev_tensor().empty()) {
+                    //! the tensor can only be freed when its share count is 1
+                    if (opr_base->values()[index].use_count() == 1) {
+                        auto layout = var->layout();
+                        var->m_dev_tensor.reset(
+                                DeviceTensorStorage{var->comp_node()}, layout);
+                        opr_base->mutable_values()[index]->reset(
+                                DeviceTensorStorage{var->comp_node()}, layout);
+                        memory_need_reorder = true;
+                        mgb_log_debug(
+                                "preprocessed weight is freed, var name "
+                                "= %s, var layout = %s",
+                                var->name().c_str(),
+                                layout.to_string().c_str());
+                    }
+                    m_already_free_no_need_mem = true;
+                }
+            }
+        }
+        //! reorder (copy out) the other still-needed outputs, because they
+        //! share the same chunk of device memory with the no-need vars, see
+        //! BatchedDeviceValueLoader
+        if (memory_need_reorder) {
+            auto opr_base =
+                    static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
+                            opr);
+            auto comp_node = opr_base->output(0)->comp_node();
+            bool is_device_opr =
+                    comp_node.mem_node() != CompNode::default_cpu().mem_node();
+            if (memory_need_reorder && is_device_opr) {
+                for (size_t index = 0; index < opr_base->output().size();
+                     index++) {
+                    auto var = opr_base->output(index);
+                    if (!var->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
+                        DeviceTensorStorage storage(var->comp_node());
+                        size_t size = var->layout().span().dist_byte();
+                        storage.ensure_size(size);
+                        storage.copy_from(var->m_dev_tensor.storage(), size);
+                        var->m_dev_tensor.reset(storage, var->layout());
+                        opr_base->mutable_values()[index]->reset(storage,
+                                                                 var->layout());
+                        reordered = true;
+                    }
+                }
+                //! sync to make sure the memory copy is finished
+                comp_node.sync();
+            }
+        }
+    }
+    return reordered;
+}
+
 void VarNodeMemManager::init_dynamic_alloc_opr_info() {
     mgb_assert(m_first_static_plan_run);
     m_need_post_exec_action_vars.clear();
@@ -174,6 +174,14 @@ class VarNodeMemManager {
     bool alloc_var_node_mem_static();
     /*!
+     * \brief free the memory of vars with the MEMORY_NO_NEED flag
+     *
+     * \return whether the memory of MEMORY_NO_NEED vars, or of other vars
+     * related to them, was changed
+     */
+    bool free_combine_memory_no_need_var();
+
+    /*!
      * \brief initialize static memory allocation plan
      *
      * This can be used with custom StaticDeviceMemoryAllocator so static
@@ -407,7 +415,8 @@ class VarNodeMemManager {
         bool check_need_realloc();
     };
-    bool m_first_static_plan_run = true, m_optimize_started = false;
+    bool m_first_static_plan_run = true, m_optimize_started = false,
+         m_already_free_no_need_mem = false;
     ComputingGraphImpl *m_owner_graph;
     ThinHashMap<VarNode*, VarNodeMemTrait> m_node_mem_trait;
     NullableHashMap<OperatorNodeBase*, DynamicAllocOprInfo>
@@ -449,7 +449,11 @@ DEF(resize, &)(const TensorShape& shape) {
 }
 DEF(reset, &)(TensorStorage storage, const TensorLayout &layout) {
-    mgb_assert(!layout.ndim || storage.valid_span(layout.span()));
+    //! The storage to be reset must either satisfy the layout or be empty.
+    //! Empty storage is used after weight preprocess to save memory while
+    //! still allowing the layout to be checked at runtime.
+    mgb_assert(!layout.ndim || storage.valid_span(layout.span()) ||
+               storage.empty());
     m_storage = std::move(storage);
     m_layout = layout;
     return static_cast<ChainReturnType&>(*this);
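A minimal sketch of the pattern this relaxed assertion enables, assuming the DeviceTensorND / DeviceTensorStorage API used elsewhere in this diff: the backing storage of a preprocessed weight is dropped while its layout is kept, so a later reset() can still be checked against the now-empty storage.

    #include "megbrain/tensor.h"

    // drop the device storage of a tensor but keep its layout; this mirrors the
    // var->m_dev_tensor.reset(...) calls in free_combine_memory_no_need_var()
    void release_storage_keep_layout(mgb::DeviceTensorND& t) {
        auto layout = t.layout();
        t.reset(mgb::DeviceTensorStorage{t.comp_node()}, layout);
    }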
@@ -98,7 +98,8 @@ struct GraphCommonOptimizeOptions {
     //! whether to enable fast-run profiled winograd opr replace
     bool weight_winograd_transform = false;
     //! whether to enable weight preprocess, if enabled it may use more
-    //! memory, default disable now
+    //! memory, disabled by default; when weight preprocess is enabled, the
+    //! input shape should not change
     bool weight_preprocess = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
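A minimal sketch of turning the option on, using the same call as the tests later in this diff; once it is enabled, the compiled function should be fed inputs of a fixed shape.

    #include "megbrain/graph.h"

    // assumption: `graph` was created via ComputingGraph::make()
    void enable_weight_preprocess(mgb::cg::ComputingGraph& graph) {
        graph.options().graph_opt.weight_preprocess = true;
    }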
@@ -589,7 +589,7 @@ class VarNode final: public GraphNodeBase { | |||||
friend class imperative::ProxyGraph; | friend class imperative::ProxyGraph; | ||||
}; | }; | ||||
enum class VarNode::Flag: uint32_t { | |||||
enum class VarNode::Flag : uint32_t { | |||||
//! do not allocate memory by the system allocator even if shape could be | //! do not allocate memory by the system allocator even if shape could be | ||||
//! inferred | //! inferred | ||||
NO_SYS_MEM_ALLOC = 1 << 0, | NO_SYS_MEM_ALLOC = 1 << 0, | ||||
@@ -667,6 +667,12 @@ enum class VarNode::Flag: uint32_t {
      * after FLAG_FREEZED is present.
      */
     FLAG_FREEZED = 1 << 10,
+
+    /*!
+     * this flag indicates that the data of this var has been processed and is
+     * no longer needed, so its memory can be freed; used by weight preprocess
+     * to save memory
+     */
+    MEMORY_NO_NEED = 1 << 11,
 };
 MGB_DEF_ENUM_CLASS_BIT_OPR(VarNode::Flag)
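A minimal sketch of how the new flag is typically checked, following free_combine_memory_no_need_var() earlier in this diff (the helper name is made up here): only vars marked MEMORY_NO_NEED that still hold a valid, non-empty device tensor are candidates for being freed.

    #include "megbrain/graph.h"

    // hypothetical helper: true if this var's preprocessed source data can be dropped
    bool can_free_preprocessed_weight(mgb::cg::VarNode* var) {
        using F = mgb::cg::VarNode::Flag;
        return var->contain_flag(F::MEMORY_NO_NEED) && var->dev_tensor_valid() &&
               !var->dev_tensor().empty();
    }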
@@ -1920,4 +1920,236 @@ TEST(TestGraph, NaiveRecord2NCHW44) {
     func->execute().wait();
 }
+
+namespace {
+template <typename DnnOp, typename... Args>
+typename DnnOp::Algorithm* try_find_any_weight_preprocess_algo(
+        DnnOp* dnn_op, const char* mgb_info, Maybe<bool>& found,
+        Args&& ...args) {
+    if (found.valid()) {
+        if (found.val()) {
+            return dnn_op->execution_policy().algorithm;
+        } else {
+            return nullptr;
+        }
+    }
+    for (auto&& algo : dnn_op->get_all_algorithms(
+                 std::forward<Args>(args)...)) {
+        dnn_op->execution_policy().algorithm = algo;
+        auto layouts = dnn_op->deduce_preprocessed_filter_layout(
+                std::forward<Args>(args)...);
+        if (layouts.empty()) continue;
+        bool valid = false;
+        for (auto&& l : layouts) {
+            if (!l.is_empty()) {
+                valid = true;
+                break;
+            }
+        }
+        if (valid) {
+            found.emplace(true);
+            return algo;
+        }
+    }
+    found.emplace(false);
+    mgb_log_warn("Can't find weight preprocess algo for op %s", mgb_info);
+    return nullptr;
+}
+
+void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt.weight_preprocess = true;
+    graph->options().comp_node_seq_record_level = record_level;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+    auto x = mkvar("x", {1, 32, 16, 16});
+    // ConvBias test dense
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 0;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1});
+    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
+    Maybe<bool> wp1, wp2;
+    conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>()
+            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
+                return try_find_any_weight_preprocess_algo(
+                        opr->cast_final_safe<opr::ConvBias>().megdnn_opr(),
+                        opr->cname(), wp1,
+                        opr->input(0)->layout(), opr->input(1)->layout(),
+                        opr->input(2)->layout(), TensorLayout{},
+                        opr->output(0)->layout());
+            });
+    // Convolution
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 0;
+    param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
+    auto w2 = mkcvar("w2", {32, 32, 1, 1});
+    auto y = opr::Convolution::make(conv1, w2, param_conv);
+    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
+            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
+                return try_find_any_weight_preprocess_algo(
+                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
+                        opr->cname(), wp2,
+                        opr->input(0)->layout(), opr->input(1)->layout(),
+                        opr->output(0)->layout());
+            });
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+    //! the first execution flags the vars whose memory is no longer needed
+    func->execute();
+    //! the second execution frees that memory
+    func->execute();
+    auto check = [&](SymbolVar v) {
+        ASSERT_TRUE(v.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
+        ASSERT_TRUE(v.node()->dev_tensor().empty());
+        ASSERT_TRUE(v.node()->owner_opr()
+                            ->cast_final_safe<opr::SharedDeviceTensor>()
+                            .get_dev_tensor()
+                            .empty());
+    };
+    ASSERT_TRUE(wp1.valid() && wp2.valid());
+    if (wp1.val()) {
+        check(w1);
+    }
+    if (wp2.val()) {
+        check(w2);
+    }
+}
+} // anonymous namespace
+
+TEST(TestGraph, FreeMemoryInWeightPreprocess) {
+    test_free_memory_in_weight_preprocess(0, CompNode::load("xpu0"));
+}
+
+TEST(TestGraph, RecordFreeMemoryInWeightPreprocess) {
+    test_free_memory_in_weight_preprocess(1, CompNode::load("cpu0"));
+}
+
+namespace {
+MGB_DEFINE_OPR_CLASS(HostValueReader, cg::SingleCNOutshapePureByInshapeOprBase) // {
+    void scn_do_execute() override {
+        auto&& hv = owner_graph()->static_infer_manager().infer_value(input(0));
+        MGB_MARK_USED_VAR(hv);
+    }
+    NodeProp* do_make_node_prop() const override {
+        auto ret = Super::do_make_node_prop();
+        ret->dep_map()[input(0)] = NodeProp::DepType::HOST_VALUE;
+        return ret;
+    }
+    void get_output_var_shape(
+            const TensorShapeArray &,
+            TensorShapeArray &out_shape) const override {
+        out_shape.at(0) = {};
+    }
+public:
+    HostValueReader(VarNode* inp)
+            : Super{inp->owner_graph(), {}, "host_value_reader", {inp}} {
+        add_input({inp});
+        using F = VarNode::Flag;
+        add_output(None)
+                ->add_flag(F::ALLOW_EMPTY_SHAPE)
+                .add_flag(F::VOLATILE_CONTENT);
+    }
+    static SymbolVar make(SymbolVar inp) {
+        return inp.node()->owner_graph()->insert_opr(
+                std::make_unique<HostValueReader>(inp.node()))->output(0);
+    }
+};
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(HostValueReader);
+}
+
+TEST(TestGraph, FreeMemoryInWeightPreprocessWithValueInfer) {
+    HostTensorGenerator<> gen;
+    CompNode cn = CompNode::load("xpux");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt.weight_preprocess = true;
+    graph->options().var_sanity_check_first_run = false;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+    auto x = mkvar("x", {1, 32, 16, 16});
+    auto w = mkcvar("w", {32, 32, 1, 1});
+    auto y = opr::Convolution::make(x, w);
+    Maybe<bool> found;
+    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
+            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
+                return try_find_any_weight_preprocess_algo(
+                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
+                        opr->cname(), found,
+                        opr->input(0)->layout(), opr->input(1)->layout(),
+                        opr->output(0)->layout());
+            });
+    auto reader = HostValueReader::make(w);
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y), {reader, {}}});
+    func->execute();
+    // FIXME: failed on the second execution due to requiring the host value of
+    // the empty tensor that was freed in weight preprocess
+    func->execute();
+    ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
+    ASSERT_FALSE(w.node()->dev_tensor().empty());
+    ASSERT_FALSE(w.node()->owner_opr()
+                         ->cast_final_safe<opr::SharedDeviceTensor>()
+                         .get_dev_tensor()
+                         .empty());
+}
+
+TEST(TestGraph, FreeMemoryInWeightPreprocessWithMultiReader) {
+    HostTensorGenerator<> gen;
+    CompNode cn = CompNode::load("xpux");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt.weight_preprocess = true;
+    graph->options().var_sanity_check_first_run = false;
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+    auto x = mkvar("x", {1, 32, 16, 16});
+    auto w = mkcvar("w", {32, 32, 1, 1});
+    auto y = opr::Convolution::make(x, w);
+    Maybe<bool> found;
+    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
+            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
+                return try_find_any_weight_preprocess_algo(
+                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
+                        opr->cname(), found,
+                        opr->input(0)->layout(), opr->input(1)->layout(),
+                        opr->output(0)->layout());
+            });
+    auto y1 = w * 2 + 1;
+    HostTensorND host_y, host_y1;
+    auto func = graph->compile({
+            make_callback_copy(y, host_y), make_callback_copy(y1, host_y1)});
+    func->execute();
+    // FIXME: failed on the second execution due to computing the expression
+    // (w * 2 + 1) with an empty tensor
+    func->execute();
+    ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
+    ASSERT_FALSE(w.node()->dev_tensor().empty());
+    ASSERT_FALSE(w.node()->owner_opr()
+                         ->cast_final_safe<opr::SharedDeviceTensor>()
+                         .get_dev_tensor()
+                         .empty());
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -138,39 +138,36 @@ public:
 void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter(
         cg::OperatorNodeBase& opr) {
-    if (!mixin_allow_weight_preprocess(opr))
+    if (!mixin_allow_weight_preprocess(opr)) {
         return;
+    }
     auto new_layout = deduce_preprocessed_filter_layout();
-    size_t new_size = new_layout.size();
+    //! No preprocess layout means weight preprocess is not needed
     if (new_layout.empty()) {
-        // Weight preprocess was needed before, but no longer needed.
-        if (m_preprocessed_filter) {
-            m_preprocessed_filter.reset();
-            m_filter_storage.clear();
+        return;
+    }
+    //! all layouts being empty means weight preprocess is not needed
+    bool layout_valid = false;
+    for (auto&& layout : new_layout) {
+        if (!layout.is_empty()) {
+            layout_valid = true;
         }
+    }
+    if (!layout_valid) {
         return;
     }
-    bool should_update = false;
+    size_t new_size = new_layout.size();
-    if (!m_preprocessed_filter ||
-        m_preprocessed_filter->tensors.size() != new_size) {
-        should_update = true;
-    } else {
+    if (m_preprocessed_filter) {
         for (size_t i = 0; i < new_size; i++) {
-            if (!new_layout[i].eq_layout(
-                        m_preprocessed_filter->tensors[i].layout)) {
-                should_update = true;
-                break;
-            }
+            mgb_assert(new_layout[i].eq_layout(
+                               m_preprocessed_filter->tensors[i].layout),
+                       "weight preprocess layout changed, please keep input "
+                       "shape unchanged when weight preprocess is enabled");
         }
-    }
-    if (!should_update)
         return;
-    if (!m_preprocessed_filter) {
-        m_preprocessed_filter.reset(new PreprocessedFilter{});
     }
+    m_preprocessed_filter.reset(new PreprocessedFilter{});
     m_preprocessed_filter->tensors.resize(new_size);
     m_filter_storage.resize(new_size);
     m_preprocessed_filter->algorithm_id = nullptr;
@@ -327,6 +324,14 @@ void ConvolutionForward::scn_do_execute_preprocess() {
             input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
             output(0)->layout(), preprocessed_filter(),
             intl::get_megdnn_workspace_from_var(output().back()));
+    //! Flag input(1) as no longer used, so it can be freed when no other
+    //! var depends on its dev_value, host_value or shape.
+    auto receiver_info =
+            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
+    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
+        receiver_info.shape == 0) {
+        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
+    }
 }
 /* ==================== ConvolutionBackwardData ==================== */
@@ -959,6 +964,14 @@ void ConvBiasForward::scn_do_execute_preprocess() {
             input(0)->layout(), input(1)->dev_tensor().as_megdnn(), bias_layout,
             z_layout, output(0)->layout(), preprocessed_filter(),
             intl::get_megdnn_workspace_from_var(output().back()));
+    //! Flag input(1) as no longer used, so it can be freed when no other
+    //! var depends on its dev_value, host_value or shape.
+    auto receiver_info =
+            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
+    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
+        receiver_info.shape == 0) {
+        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
+    }
 }
 /* ===================== LocalShareForward ==================== */
@@ -142,8 +142,10 @@ void intl::DeviceTensorHolder::add_output(DType dtype) {
 }
 void intl::DeviceTensorHolder::record_execute_deps(ExecDependencyArray& deps) {
-    deps.emplace_back(
-            std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
+    if (!output(0)->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
+        deps.emplace_back(
+                std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
+    }
 }
 /* ===================== Host2DeviceCopy ===================== */
@@ -801,14 +803,19 @@ class intl::MultipleDeviceTensorHolderBase::DevValuesExecDep final
     SmallVector<DeviceTensorStorage> m_vals;
 public:
-    explicit DevValuesExecDep(const ValueArray& vals) {
-        for (auto&& val : vals) {
-            m_vals.emplace_back(std::move(val->storage()));
+    explicit DevValuesExecDep(const ValueArray& vals,
+                              MultipleDeviceTensorHolderBase* opr) {
+        mgb_assert(vals.size() == opr->output().size(),
+                   "the output value size differs from the output var size");
+        for (size_t index = 0; index < vals.size(); index++) {
+            if (!opr->output(index)->contain_flag(
+                        VarNode::Flag::MEMORY_NO_NEED)) {
+                m_vals.emplace_back(std::move(vals[index]->storage()));
+            }
         }
     }
 };
 intl::MultipleDeviceTensorHolderBase::MultipleDeviceTensorHolderBase(
         ComputingGraph& graph, ValueArray values,
         const OperatorNodeConfig& config)
@@ -887,8 +894,7 @@ intl::MultipleDeviceTensorHolderBase::do_make_node_prop() const {
 void intl::MultipleDeviceTensorHolderBase::record_execute_deps(
         ExecDependencyArray& deps) {
-    deps.emplace_back(
-            std::make_unique<DevValuesExecDep>(values()));
+    deps.emplace_back(std::make_unique<DevValuesExecDep>(values(), this));
 }
 /* ===================== MultipleDeviceTensorHolder ===================== */
@@ -173,9 +173,15 @@ size_t AlgoChooser<Opr>::setup_algo(const ConvTensorLayouts& layouts,
         return 0;
     }
+    ImplAlgo algo = nullptr;
     ExeContext ctx(layouts, megdnn_opr, mgb_opr, allow_weight_preprocess);
-    auto algo = get_algo(ctx);
+    if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
+        algo = algo_choose_hook(mgb_opr);
+    }
+    if (!algo) {
+        algo = get_algo(ctx);
+    }
     size_t workspace = ctx.get_workspace_size_bytes(algo);
     mgb_log_debug(
             "%s: tensor layouts(%s %s, %s %s) -> (%s %s): algo=%s "
@@ -360,16 +366,29 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
         if (!m_allow_weight_preprocess)
             return;
         auto opr = _(m_megdnn_opr);
-        auto layout = APPLY(opr->deduce_preprocessed_filter_layout(args...),
-                            m_layouts);
-        if (layout.empty())
+        auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
+                             m_layouts);
+        //! No preprocess layout means weight preprocess is not needed
+        if (layouts.empty()) {
             return;
+        }
+        //! all layouts being empty means weight preprocess is not needed
+        bool layout_valid = false;
+        for (auto&& layout : layouts) {
+            if (!layout.is_empty()) {
+                layout_valid = true;
+            }
+        }
+        if (!layout_valid) {
+            return;
+        }
         result = PreprocessFilter<Opr>{};
         auto& res = result.val();
         res.algorithm_id = nullptr;
-        res.tensors.resize(layout.size());
-        for (size_t i = 0; i < layout.size(); i++) {
-            res.tensors[i] = megdnn::TensorND(nullptr, layout[i]);
+        res.tensors.resize(layouts.size());
+        for (size_t i = 0; i < layouts.size(); i++) {
+            res.tensors[i] = megdnn::TensorND(nullptr, layouts[i]);
         }
     });
     return result;
@@ -25,6 +25,9 @@ namespace mixin {
 class Convolution {
 public:
     using ExecutionPolicy = megdnn::param::ExecutionPolicy;
+    using Algorithm = megdnn::detail::Algorithm;
+    using AlgoChooserHook =
+            std::function<Algorithm*(const OperatorNodeBase*)>;
     const ExecutionPolicy& execution_policy() const {
         if (!m_policy_accessed) {
@@ -55,6 +58,16 @@ class Convolution {
     virtual std::pair<const void*, size_t> param_blob() const = 0;
+
+    /*!
+     * \brief register a hook to implement a custom algo chooser
+     */
+    void setup_algo_chooser(AlgoChooserHook&& func) {
+        m_algo_chooser = func;
+    }
+
+    AlgoChooserHook algo_chooser() const {
+        return m_algo_chooser;
+    }
 protected:
     ~Convolution();
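A minimal sketch of the new hook, mirroring the tests earlier in this diff; `y` is assumed to be the output SymbolVar of an opr::Convolution. A hook that returns a null algorithm simply falls back to the default get_algo() path in AlgoChooser::setup_algo().

    #include "megbrain/opr/dnn/convolution.h"

    void keep_default_algo(mgb::cg::SymbolVar y) {
        auto&& conv = y.node()->owner_opr()
                              ->cast_final_safe<mgb::opr::Convolution>();
        conv.setup_algo_chooser([](const mgb::cg::OperatorNodeBase*) {
            // returning nullptr lets setup_algo() call get_algo(ctx) as before
            return nullptr;
        });
    }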
@@ -63,6 +76,8 @@ class Convolution {
     std::unique_ptr<AlgoChooserProfileCache> m_profile_cache;
+    AlgoChooserHook m_algo_chooser;
+
     virtual void init_profile_cache() = 0;
     //! init output desc for conv backward data oprs; it handles both grad
@@ -99,6 +99,11 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
         return *m_dev_data;
     }
+
+    void free_dev_data() {
+        m_dev_data->reset(DeviceTensorStorage{m_dev_data->comp_node()},
+                          m_dev_data->layout());
+    }
     const std::shared_ptr<DeviceTensorND>& dev_data() const {
         return m_dev_data;
     }
@@ -122,6 +127,10 @@ public:
             const OperatorNodeConfig& config);
     const ValueArray& values() const { return m_values; }
+
+    ValueArray& mutable_values() {
+        return m_values;
+    }
 protected:
     ValueArray m_values;
@@ -292,7 +301,7 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
     static SymbolVar make_const(ComputingGraph& graph,
                                 const HostTensorND& value,
                                 const OperatorNodeConfig& config = {}) {
-        return make(graph, value, false, config);
+        return make(graph, value, true, config);
     }
 };