GitOrigin-RevId: 5f91acb909
release-1.1

@@ -576,10 +576,10 @@ VarNode& VarNode::add_flag(Flag flag) {
void VarNode::modify_flag(Flag delta, Flag new_flag) {
    if (contain_flag(Flag::FLAG_FREEZED)) {
        mgb_assert((delta & (
                        Flag::NO_MEM_RECLAIM |
                        Flag::NO_SYS_STATIC_MEM_ALLOC |
                        Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
        mgb_assert(
                (delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
                (new_flag & Flag::MEMORY_NO_NEED));
        mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
                        var_node_mem_manager().optimize_started(),

@@ -24,6 +24,8 @@
#include "megbrain/utils/timer.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/opr/io.h"
#include <chrono>

using namespace mgb;

@@ -36,7 +38,6 @@ void call_mem_status_changed(cg::OperatorNodeBase* opr) {
    if (cb.on_mem_status_changed.valid())
        cb.on_mem_status_changed.val()();
}
} // namespace

/* ==================== StaticDeviceMemoryManager ==================== */

@@ -393,11 +394,12 @@ bool VarNodeMemManager::alloc_var_node_mem_static() {
bool VarNodeMemManager::update_static_alloc_plan() {
    // check whether unchanged
    bool free_no_need_memory = free_combine_memory_no_need_var();
    if (!m_owner_graph->static_infer_comp_seq_manager()
                 .update_static_check_shape_change() &&
        !m_first_static_plan_run &&
        !m_impure_mem_plan_mgr.check_need_realloc()) {
        return false;
        return false || free_no_need_memory;
    }
    if (m_first_static_plan_run)

@@ -494,6 +496,96 @@ bool VarNodeMemManager::make_static_var_tensor_from_alloc_plan() {
    return true;
}
bool VarNodeMemManager::free_combine_memory_no_need_var() {
    if (!m_owner_graph->options().graph_opt.weight_preprocess ||
        m_already_free_no_need_mem) {
        return false;
    }
    bool reordered = false;
    //! free no-need storage
    for (auto opr : *m_opr_seq) {
        if (opr->try_cast_final<opr::SharedDeviceTensor>() ||
            opr->try_cast_final<opr::SharedDeviceTensorWithFormat>()) {
            auto opr_base = static_cast<opr::intl::SharedDeviceTensorBase*>(opr);
            auto var = opr_base->output(0);
            if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
                var->dev_tensor_valid() && !var->dev_tensor().empty()) {
                //! the tensor can only be freed when its share count is 1
                if (opr_base->dev_data().use_count() == 1) {
                    auto layout = var->layout();
                    var->m_dev_tensor.reset(
                            DeviceTensorStorage{var->comp_node()}, layout);
                    opr_base->free_dev_data();
                    mgb_log_debug(
                            "preprocessed weight is freed, var name = %s, "
                            "var layout = %s",
                            var->name().c_str(), layout.to_string().c_str());
                }
                m_already_free_no_need_mem = true;
            }
        }
        bool memory_need_reorder = false;
        if (opr->try_cast_final<opr::MultipleDeviceTensorHolder>() ||
            opr->try_cast_final<opr::MultipleDeviceTensorWithFormatHolder>()) {
            auto opr_base =
                    static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
                            opr);
            for (size_t index = 0; index < opr_base->output().size(); index++) {
                auto var = opr_base->output(index);
                if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
                    var->dev_tensor_valid() && !var->dev_tensor().empty()) {
                    //! the tensor can only be freed when its share count is 1
                    if (opr_base->values()[index].use_count() == 1) {
                        auto layout = var->layout();
                        var->m_dev_tensor.reset(
                                DeviceTensorStorage{var->comp_node()}, layout);
                        opr_base->mutable_values()[index]->reset(
                                DeviceTensorStorage{var->comp_node()}, layout);
                        memory_need_reorder = true;
                        mgb_log_debug(
                                "preprocessed weight is freed, var name "
                                "= %s, var layout = %s",
                                var->name().c_str(),
                                layout.to_string().c_str());
                    }
                    m_already_free_no_need_mem = true;
                }
            }
        }
        //! reorder the other needed outputs, because they share the
        //! same chunk of device memory with the no-need vars; see
        //! BatchedDeviceValueLoader
        if (memory_need_reorder) {
            auto opr_base =
                    static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
                            opr);
            auto comp_node = opr_base->output(0)->comp_node();
            bool is_device_opr =
                    comp_node.mem_node() != CompNode::default_cpu().mem_node();
            if (memory_need_reorder && is_device_opr) {
                for (size_t index = 0; index < opr_base->output().size();
                     index++) {
                    auto var = opr_base->output(index);
                    if (!var->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
                        DeviceTensorStorage storage(var->comp_node());
                        size_t size = var->layout().span().dist_byte();
                        storage.ensure_size(size);
                        storage.copy_from(var->m_dev_tensor.storage(), size);
                        var->m_dev_tensor.reset(storage, var->layout());
                        opr_base->mutable_values()[index]->reset(storage,
                                                                 var->layout());
                        reordered = true;
                    }
                }
                //! sync to make sure the memory copy is finished
                comp_node.sync();
            }
        }
    }
    return reordered;
}
void VarNodeMemManager::init_dynamic_alloc_opr_info() {
    mgb_assert(m_first_static_plan_run);
    m_need_post_exec_action_vars.clear();

@@ -174,6 +174,14 @@ class VarNodeMemManager {
    bool alloc_var_node_mem_static();

    /*!
     * \brief free the memory of vars with the MEMORY_NO_NEED flag
     *
     * \return whether the memory of a MEMORY_NO_NEED var, or of another
     *      var related to it, changed
     */
    bool free_combine_memory_no_need_var();

    /*!
     * \brief initialize static memory allocation plan
     *
     * This can be used with custom StaticDeviceMemoryAllocator so static

@@ -407,7 +415,8 @@ class VarNodeMemManager {
        bool check_need_realloc();
    };

    bool m_first_static_plan_run = true, m_optimize_started = false;
    bool m_first_static_plan_run = true, m_optimize_started = false,
         m_already_free_no_need_mem = false;
    ComputingGraphImpl *m_owner_graph;
    ThinHashMap<VarNode*, VarNodeMemTrait> m_node_mem_trait;
    NullableHashMap<OperatorNodeBase*, DynamicAllocOprInfo>
@@ -449,7 +449,11 @@ DEF(resize, &)(const TensorShape& shape) {
}

DEF(reset, &)(TensorStorage storage, const TensorLayout &layout) {
    mgb_assert(!layout.ndim || storage.valid_span(layout.span()));
    //! The storage being reset must either satisfy the layout or be empty;
    //! empty storage is used after weight preprocess to save memory while
    //! still allowing the layout to be checked at runtime
    mgb_assert(!layout.ndim || storage.valid_span(layout.span()) ||
               storage.empty());
    m_storage = std::move(storage);
    m_layout = layout;
    return static_cast<ChainReturnType&>(*this);
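With the relaxed assertion above, a tensor can be reset to an empty storage while keeping its layout; this is the pattern the weight-preprocess path uses to release device memory without losing layout information. A minimal sketch of that internal pattern, assuming var is a VarNode* whose data is no longer needed (m_dev_tensor is accessible because VarNodeMemManager is a friend of VarNode):

    // release the device storage of a var but keep its layout, mirroring
    // what free_combine_memory_no_need_var() does for preprocessed weights
    auto layout = var->layout();                    // remember the layout
    var->m_dev_tensor.reset(
            DeviceTensorStorage{var->comp_node()},  // empty storage on the same comp node
            layout);                                // layout preserved for later checks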
@@ -98,7 +98,8 @@ struct GraphCommonOptimizeOptions {
    //! whether to enable fast-run profiled winograd opr replace
    bool weight_winograd_transform = false;
    //! whether to enable weight preprocess, if enabled it may use more
    //! memory, default disable now
    //! memory, disabled by default; when weight preprocess is enabled, the
    //! input shape must not change
    bool weight_preprocess = false;
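The option is set on the graph before compilation; the tests added later in this diff use the same pattern. A minimal sketch, assuming y/host_y are defined as in those tests:

    auto graph = ComputingGraph::make();
    // enable weight preprocess; may use extra memory, and input shapes
    // must stay fixed across executions once it is on
    graph->options().graph_opt.weight_preprocess = true;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();  // first run flags the no-need weight vars
    func->execute();  // second run frees the flagged weight storage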
    enum LayoutTransform : uint32_t {
        DEFAULT,

@@ -589,7 +589,7 @@ class VarNode final: public GraphNodeBase {
    friend class imperative::ProxyGraph;
};

enum class VarNode::Flag: uint32_t {
enum class VarNode::Flag : uint32_t {
    //! do not allocate memory by the system allocator even if shape could be
    //! inferred
    NO_SYS_MEM_ALLOC = 1 << 0,

@@ -667,6 +667,12 @@ enum class VarNode::Flag: uint32_t {
     * after FLAG_FREEZED is present.
     */
    FLAG_FREEZED = 1 << 10,

    /*!
     * this flag indicates that the data of this var has already been
     * processed and is not needed later, so it can be freed; it is used by
     * weight preprocess to save memory
     */
    MEMORY_NO_NEED = 1 << 11,
};

MGB_DEF_ENUM_CLASS_BIT_OPR(VarNode::Flag)
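Elsewhere in this diff the flag is set by the convolution preprocess step and queried by the memory manager, using the existing add_flag/contain_flag API. A minimal sketch, assuming var is a VarNode* holding a preprocessed weight:

    // mark a preprocessed weight var as no longer needed
    var->add_flag(VarNode::Flag::MEMORY_NO_NEED);

    // the memory manager later frees only flagged vars whose storage is still valid
    if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
        var->dev_tensor_valid() && !var->dev_tensor().empty()) {
        // release the underlying device storage here
    }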
@@ -1920,4 +1920,236 @@ TEST(TestGraph, NaiveRecord2NCHW44) {
    func->execute().wait();
}

namespace {
template <typename DnnOp, typename... Args>
typename DnnOp::Algorithm* try_find_any_weight_preprocess_algo(
        DnnOp* dnn_op, const char* mgb_info, Maybe<bool>& found,
        Args&& ...args) {
    if (found.valid()) {
        if (found.val()) {
            return dnn_op->execution_policy().algorithm;
        } else {
            return nullptr;
        }
    }
    for (auto&& algo : dnn_op->get_all_algorithms(
                 std::forward<Args>(args)...)) {
        dnn_op->execution_policy().algorithm = algo;
        auto layouts = dnn_op->deduce_preprocessed_filter_layout(
                std::forward<Args>(args)...);
        if (layouts.empty()) continue;
        bool valid = false;
        for (auto&& l: layouts) {
            if (!l.is_empty()) {
                valid = true;
                break;
            }
        }
        if (valid) {
            found.emplace(true);
            return algo;
        }
    }
    found.emplace(false);
    mgb_log_warn("Can't find weight preprocess algo for op %s", mgb_info);
    return nullptr;
}

void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    graph->options().graph_opt.weight_preprocess = true;
    graph->options().comp_node_seq_record_level = record_level;
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
                .rename(name);
    };
    auto x = mkvar("x", {1, 32, 16, 16});
    // ConvBias test dense
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 0;
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
    auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1});
    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
    Maybe<bool> wp1, wp2;
    conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>()
            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
                return try_find_any_weight_preprocess_algo(
                        opr->cast_final_safe<opr::ConvBias>().megdnn_opr(),
                        opr->cname(), wp1,
                        opr->input(0)->layout(), opr->input(1)->layout(),
                        opr->input(2)->layout(), TensorLayout{},
                        opr->output(0)->layout());
            });
    // Convolution
    opr::Convolution::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 0;
    param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
    auto w2 = mkcvar("w2", {32, 32, 1, 1});
    auto y = opr::Convolution::make(conv1, w2, param_conv);
    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
                return try_find_any_weight_preprocess_algo(
                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
                        opr->cname(), wp2,
                        opr->input(0)->layout(), opr->input(1)->layout(),
                        opr->output(0)->layout());
            });
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    //! the first execution flags the no-need memory of the weight vars
    func->execute();
    //! the second execution frees the no-need memory of the weight vars
    func->execute();
    auto check = [&](SymbolVar v) {
        ASSERT_TRUE(v.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
        ASSERT_TRUE(v.node()->dev_tensor().empty());
        ASSERT_TRUE(v.node()->owner_opr()
                            ->cast_final_safe<opr::SharedDeviceTensor>()
                            .get_dev_tensor()
                            .empty());
    };
    ASSERT_TRUE(wp1.valid() && wp2.valid());
    if (wp1.val()) {
        check(w1);
    }
    if (wp2.val()) {
        check(w2);
    }
}
} // anonymous namespace

TEST(TestGraph, FreeMemoryInWeightPreprocess) {
    test_free_memory_in_weight_preprocess(0, CompNode::load("xpu0"));
}

TEST(TestGraph, RecordFreeMemoryInWeightPreprocess) {
    test_free_memory_in_weight_preprocess(1, CompNode::load("cpu0"));
}
namespace {
MGB_DEFINE_OPR_CLASS(HostValueReader, cg::SingleCNOutshapePureByInshapeOprBase) // {
    void scn_do_execute() override {
        auto&& hv = owner_graph()->static_infer_manager().infer_value(input(0));
        MGB_MARK_USED_VAR(hv);
    }

    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->dep_map()[input(0)] = NodeProp::DepType::HOST_VALUE;
        return ret;
    }

    void get_output_var_shape(
            const TensorShapeArray &,
            TensorShapeArray &out_shape) const override {
        out_shape.at(0) = {};
    }

public:
    HostValueReader(VarNode* inp)
            : Super{inp->owner_graph(), {}, "host_value_reader", {inp}} {
        add_input({inp});
        using F = VarNode::Flag;
        add_output(None)
                ->add_flag(F::ALLOW_EMPTY_SHAPE)
                .add_flag(F::VOLATILE_CONTENT);
    }

    static SymbolVar make(SymbolVar inp) {
        return inp.node()->owner_graph()->insert_opr(
                std::make_unique<HostValueReader>(inp.node()))->output(0);
    }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(HostValueReader);
} // anonymous namespace

TEST(TestGraph, FreeMemoryInWeightPreprocessWithValueInfer) {
    HostTensorGenerator<> gen;
    CompNode cn = CompNode::load("xpux");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt.weight_preprocess = true;
    graph->options().var_sanity_check_first_run = false;
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
                .rename(name);
    };
    auto x = mkvar("x", {1, 32, 16, 16});
    auto w = mkcvar("w", {32, 32, 1, 1});
    auto y = opr::Convolution::make(x, w);
    Maybe<bool> found;
    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
                return try_find_any_weight_preprocess_algo(
                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
                        opr->cname(), found,
                        opr->input(0)->layout(), opr->input(1)->layout(),
                        opr->output(0)->layout());
            });
    auto reader = HostValueReader::make(w);
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y), {reader, {}}});
    func->execute();
    // FIXME: failed on the second execution due to requiring the host value
    // of the empty tensor which was freed in weight preprocess
    func->execute();
    ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
    ASSERT_FALSE(w.node()->dev_tensor().empty());
    ASSERT_FALSE(w.node()->owner_opr()
                         ->cast_final_safe<opr::SharedDeviceTensor>()
                         .get_dev_tensor()
                         .empty());
}

TEST(TestGraph, FreeMemoryInWeightPreprocessWithMultiReader) {
    HostTensorGenerator<> gen;
    CompNode cn = CompNode::load("xpux");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt.weight_preprocess = true;
    graph->options().var_sanity_check_first_run = false;
    graph->options().graph_opt_level = 0;
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
                .rename(name);
    };
    auto x = mkvar("x", {1, 32, 16, 16});
    auto w = mkcvar("w", {32, 32, 1, 1});
    auto y = opr::Convolution::make(x, w);
    Maybe<bool> found;
    y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
            .setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
                return try_find_any_weight_preprocess_algo(
                        opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
                        opr->cname(), found,
                        opr->input(0)->layout(), opr->input(1)->layout(),
                        opr->output(0)->layout());
            });
    auto y1 = w * 2 + 1;
    HostTensorND host_y, host_y1;
    auto func = graph->compile({
            make_callback_copy(y, host_y), make_callback_copy(y1, host_y1)});
    func->execute();
    // FIXME: failed on the second execution due to calculating the expression
    // (w * 2 + 1) with an empty tensor
    func->execute();
    ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
    ASSERT_FALSE(w.node()->dev_tensor().empty());
    ASSERT_FALSE(w.node()->owner_opr()
                         ->cast_final_safe<opr::SharedDeviceTensor>()
                         .get_dev_tensor()
                         .empty());
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -138,39 +138,36 @@ public:
void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter(
        cg::OperatorNodeBase& opr) {
    if (!mixin_allow_weight_preprocess(opr))
    if (!mixin_allow_weight_preprocess(opr)) {
        return;
    }
    auto new_layout = deduce_preprocessed_filter_layout();
    size_t new_size = new_layout.size();
    //! no preprocessed layout means weight preprocess is not needed
    if (new_layout.empty()) {
        // Weight preprocess was needed before, but no longer needed.
        if (m_preprocessed_filter) {
            m_preprocessed_filter.reset();
            m_filter_storage.clear();
        return;
    }
    //! all layouts being empty means weight preprocess is not needed
    bool layout_valid = false;
    for (auto&& layout : new_layout) {
        if (!layout.is_empty()) {
            layout_valid = true;
        }
    }
    if (!layout_valid) {
        return;
    }
    bool should_update = false;
    size_t new_size = new_layout.size();
    if (!m_preprocessed_filter ||
        m_preprocessed_filter->tensors.size() != new_size) {
        should_update = true;
    } else {
    if (m_preprocessed_filter) {
        for (size_t i = 0; i < new_size; i++) {
            if (!new_layout[i].eq_layout(
                        m_preprocessed_filter->tensors[i].layout)) {
                should_update = true;
                break;
            }
            mgb_assert(new_layout[i].eq_layout(
                               m_preprocessed_filter->tensors[i].layout),
                       "weight preprocess layout changed, please keep input "
                       "shape unchanged when weight preprocess is enabled");
        }
    }
    if (!should_update)
        return;
    if (!m_preprocessed_filter) {
        m_preprocessed_filter.reset(new PreprocessedFilter{});
    }
    m_preprocessed_filter.reset(new PreprocessedFilter{});
    m_preprocessed_filter->tensors.resize(new_size);
    m_filter_storage.resize(new_size);
    m_preprocessed_filter->algorithm_id = nullptr;
@@ -327,6 +324,14 @@ void ConvolutionForward::scn_do_execute_preprocess() {
            input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
            output(0)->layout(), preprocessed_filter(),
            intl::get_megdnn_workspace_from_var(output().back()));
    //! flag input(1) as not needed later; it can be freed when no other
    //! var depends on its dev_value, host_value or shape
    auto receiver_info =
            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
        receiver_info.shape == 0) {
        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
    }
}

/* ==================== ConvolutionBackwardData ==================== */

@@ -959,6 +964,14 @@ void ConvBiasForward::scn_do_execute_preprocess() {
            input(0)->layout(), input(1)->dev_tensor().as_megdnn(), bias_layout,
            z_layout, output(0)->layout(), preprocessed_filter(),
            intl::get_megdnn_workspace_from_var(output().back()));
    //! flag input(1) as not needed later; it can be freed when no other
    //! var depends on its dev_value, host_value or shape
    auto receiver_info =
            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
        receiver_info.shape == 0) {
        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
    }
}

/* ===================== LocalShareForward ==================== */
@@ -142,8 +142,10 @@ void intl::DeviceTensorHolder::add_output(DType dtype) {
}

void intl::DeviceTensorHolder::record_execute_deps(ExecDependencyArray& deps) {
    deps.emplace_back(
            std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
    if (!output(0)->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
        deps.emplace_back(
                std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
    }
}

/* ===================== Host2DeviceCopy ===================== */

@@ -801,14 +803,19 @@ class intl::MultipleDeviceTensorHolderBase::DevValuesExecDep final
    SmallVector<DeviceTensorStorage> m_vals;

public:
    explicit DevValuesExecDep(const ValueArray& vals) {
        for (auto&& val : vals) {
            m_vals.emplace_back(std::move(val->storage()));
    explicit DevValuesExecDep(const ValueArray& vals,
                              MultipleDeviceTensorHolderBase* opr) {
        mgb_assert(vals.size() == opr->output().size(),
                   "the number of output values differs from the number of "
                   "output vars");
        for (size_t index = 0; index < vals.size(); index++) {
            if (!opr->output(index)->contain_flag(
                        VarNode::Flag::MEMORY_NO_NEED)) {
                m_vals.emplace_back(std::move(vals[index]->storage()));
            }
        }
    }
};

intl::MultipleDeviceTensorHolderBase::MultipleDeviceTensorHolderBase(
        ComputingGraph& graph, ValueArray values,
        const OperatorNodeConfig& config)

@@ -887,8 +894,7 @@ intl::MultipleDeviceTensorHolderBase::do_make_node_prop() const {
void intl::MultipleDeviceTensorHolderBase::record_execute_deps(
        ExecDependencyArray& deps) {
    deps.emplace_back(
            std::make_unique<DevValuesExecDep>(values()));
    deps.emplace_back(std::make_unique<DevValuesExecDep>(values(), this));
}

/* ===================== MultipleDeviceTensorHolder ===================== */
@@ -173,9 +173,15 @@ size_t AlgoChooser<Opr>::setup_algo(const ConvTensorLayouts& layouts,
        return 0;
    }

    ImplAlgo algo = nullptr;
    ExeContext ctx(layouts, megdnn_opr, mgb_opr, allow_weight_preprocess);
    auto algo = get_algo(ctx);
    if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
        algo = algo_choose_hook(mgb_opr);
    }
    if (!algo) {
        algo = get_algo(ctx);
    }
    size_t workspace = ctx.get_workspace_size_bytes(algo);
    mgb_log_debug(
            "%s: tensor layouts(%s %s, %s %s) -> (%s %s): algo=%s "

@@ -360,16 +366,29 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
        if (!m_allow_weight_preprocess)
            return;
        auto opr = _(m_megdnn_opr);
        auto layout = APPLY(opr->deduce_preprocessed_filter_layout(args...),
                            m_layouts);
        if (layout.empty())
        auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
                             m_layouts);
        //! no preprocessed layout means weight preprocess is not needed
        if (layouts.empty()) {
            return;
        }
        //! all layouts being empty means weight preprocess is not needed
        bool layout_valid = false;
        for (auto&& layout : layouts) {
            if (!layout.is_empty()) {
                layout_valid = true;
            }
        }
        if (!layout_valid) {
            return;
        }
        result = PreprocessFilter<Opr>{};
        auto& res = result.val();
        res.algorithm_id = nullptr;
        res.tensors.resize(layout.size());
        for (size_t i = 0; i < layout.size(); i++) {
            res.tensors[i] = megdnn::TensorND(nullptr, layout[i]);
        res.tensors.resize(layouts.size());
        for (size_t i = 0; i < layouts.size(); i++) {
            res.tensors[i] = megdnn::TensorND(nullptr, layouts[i]);
        }
    });
    return result;
@@ -25,6 +25,9 @@ namespace mixin {
class Convolution {
public:
    using ExecutionPolicy = megdnn::param::ExecutionPolicy;
    using Algorithm = megdnn::detail::Algorithm;
    using AlgoChooserHook =
            std::function<Algorithm*(const OperatorNodeBase*)>;

    const ExecutionPolicy& execution_policy() const {
        if (!m_policy_accessed) {

@@ -55,6 +58,16 @@ class Convolution {
    virtual std::pair<const void*, size_t> param_blob() const = 0;

    /*!
     * \brief register a hook to implement a custom algo chooser
     */
    void setup_algo_chooser(AlgoChooserHook&& func) {
        m_algo_chooser = func;
    }

    AlgoChooserHook algo_chooser() const {
        return m_algo_chooser;
    }

protected:
    ~Convolution();
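The tests in this diff register such a hook to force an algorithm that supports weight preprocess; AlgoChooser::setup_algo above falls back to the normal chooser when the hook returns nullptr. A minimal sketch, with the selection logic left as a placeholder:

    auto& conv = y.node()->owner_opr()->cast_final_safe<opr::Convolution>();
    conv.setup_algo_chooser(
            [&](const cg::OperatorNodeBase* opr) -> megdnn::detail::Algorithm* {
                // inspect opr (layouts, megdnn_opr(), ...) and return a
                // preferred algorithm; returning nullptr falls back to the
                // default get_algo() path in AlgoChooser::setup_algo
                return nullptr;
            });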
@@ -63,6 +76,8 @@ class Convolution {
    std::unique_ptr<AlgoChooserProfileCache> m_profile_cache;
    AlgoChooserHook m_algo_chooser;

    virtual void init_profile_cache() = 0;
    //! init output desc for conv backward data oprs; it handles both grad

@@ -99,6 +99,11 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
        return *m_dev_data;
    }

    void free_dev_data() {
        m_dev_data->reset(DeviceTensorStorage{m_dev_data->comp_node()},
                          m_dev_data->layout());
    }

    const std::shared_ptr<DeviceTensorND>& dev_data() const {
        return m_dev_data;
    }

@@ -122,6 +127,10 @@ public:
            const OperatorNodeConfig& config);

    const ValueArray& values() const { return m_values; }

    ValueArray& mutable_values() {
        return m_values;
    }

protected:
    ValueArray m_values;

@@ -292,7 +301,7 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
    static SymbolVar make_const(ComputingGraph& graph,
                                const HostTensorND& value,
                                const OperatorNodeConfig& config = {}) {
        return make(graph, value, false, config);
        return make(graph, value, true, config);
    }
};