From 37b67c9b2be43997b11a5676e831c2387b2eaa87 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 10 Apr 2020 18:21:45 +0800
Subject: [PATCH] refactor(dnn/parampack): reduce param pack memory use

GitOrigin-RevId: a802a14e8dbb2b291f05862bd9f0a12622d57f0c
---
 dnn/src/common/param_pack.cpp               | 24 ++++++++--------
 dnn/src/cuda/param_pack/opr_impl.cpp        | 24 ++++++++--------
 dnn/src/cuda/param_pack/param_pack.cu       | 35 ++++++++++++++---------
 dnn/src/cuda/param_pack/param_pack.cuh      |  5 ++--
 dnn/src/naive/param_pack/opr_impl.cpp       | 44 +++++++++++++++--------------
 src/opr/impl/tensor_manip.cpp               | 28 ++++++++++--------
 src/opr/impl/tensor_manip.sereg.h           | 42 +++++++++++----------------
 src/opr/include/megbrain/opr/tensor_manip.h | 25 ++++++++++------
 src/opr/test/tensor_manip.cpp               |  2 +-
 9 files changed, 120 insertions(+), 109 deletions(-)

diff --git a/dnn/src/common/param_pack.cpp b/dnn/src/common/param_pack.cpp
index bd5e5f77..4eb9de4d 100644
--- a/dnn/src/common/param_pack.cpp
+++ b/dnn/src/common/param_pack.cpp
@@ -15,18 +15,16 @@
 using namespace megdnn;
 
 void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated,
-                                          const TensorLayout& table,
+                                          const TensorLayout& offsets,
                                           const TensorLayout& parts) {
-    megdnn_assert(table.dtype == dtype::Int32{}, "bad dtype: %s",
-                  table.dtype.name());
-    megdnn_assert(concated.ndim == 1 && table.ndim == 1 && parts.ndim == 1 &&
-                          concated.stride[0] == 1 && table.stride[0] == 1 &&
+    megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s",
+                  offsets.dtype.name());
+    megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 &&
+                          concated.stride[0] == 1 && offsets.stride[0] == 1 &&
                           parts.stride[0] == 1,
-                  "bad layout: concated=%s table=%s parts=%s",
-                  concated.to_string().c_str(), table.to_string().c_str(),
+                  "bad layout: concated=%s offsets=%s parts=%s",
+                  concated.to_string().c_str(), offsets.to_string().c_str(),
                   parts.to_string().c_str());
-    megdnn_assert(table.shape[0] == concated.shape[0] * 2,
-                  "concated=%zu table=%zu", concated.shape[0], table.shape[0]);
 }
 
 std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
@@ -46,11 +44,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
         return v + ((alignment - mod) & (alignment - 1));
     };
 
-    std::vector<dt_int32> offsets(shapes.size());
+    std::vector<dt_int32> offsets(shapes.size() << 1);
     size_t offset = 0;
     for (size_t i = 0; i < shapes.size(); i++) {
-        offsets[i] = offset;
-        offset = get_aligned(offset) + shapes[i].total_nr_elems();
+        offset = get_aligned(offset);
+        offsets[i * 2] = offset;
+        offset += shapes[i].total_nr_elems();
+        offsets[i * 2 + 1] = offset;
     }
     return offsets;
 }
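The hunk above is where the memory saving comes from: gen_offsets no longer feeds a per-element (outer, inner) lookup table of 2 * total_nr_elems entries but emits one aligned [begin, end) pair per input tensor, i.e. 2 * nr_tensors entries, and check_exec drops the old table-size assertion accordingly. A minimal standalone sketch of the same scheme (illustrative names; alignment is taken directly in elements instead of being derived from the dtype size as the real operator does):

// Standalone sketch of the new offset scheme (illustrative only; not part of the patch).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors ParamPackConcatSplitBase::gen_offsets: one aligned [begin, end) pair per
// tensor, so the auxiliary data grows with the number of tensors, not elements.
std::vector<int32_t> gen_offsets(const std::vector<size_t>& sizes,
                                 size_t alignment /* in elements, power of two */) {
    auto get_aligned = [alignment](size_t v) {
        auto mod = v & (alignment - 1);
        return v + ((alignment - mod) & (alignment - 1));
    };
    std::vector<int32_t> offsets(sizes.size() * 2);
    size_t offset = 0;
    for (size_t i = 0; i < sizes.size(); ++i) {
        offset = get_aligned(offset);
        offsets[i * 2] = offset;       // begin of tensor i inside the packed buffer
        offset += sizes[i];
        offsets[i * 2 + 1] = offset;   // end (exclusive); gap to next begin is padding
    }
    return offsets;
}

int main() {
    // Two tensors of 3 and 5 elements with 4-element alignment.
    for (int32_t v : gen_offsets({3, 5}, 4))
        std::printf("%d ", v);
    std::printf("\n");
}

With sizes {3, 5} and 4-element alignment this prints "0 3 4 9": element 3 of the packed buffer is alignment padding between the two parts.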
diff --git a/dnn/src/cuda/param_pack/opr_impl.cpp b/dnn/src/cuda/param_pack/opr_impl.cpp
index ab167735..fb521eae 100644
--- a/dnn/src/cuda/param_pack/opr_impl.cpp
+++ b/dnn/src/cuda/param_pack/opr_impl.cpp
@@ -24,7 +24,7 @@ size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs,
 
 template <typename T>
 void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
-                                        _megdnn_tensor_in table,
+                                        _megdnn_tensor_in offsets,
                                         _megdnn_tensor_out dst,
                                         _megdnn_workspace workspace) {
     size_t inp_size = srcs.layout.shape[0],
@@ -35,25 +35,25 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
     megdnn_assert_internal(src_cpu);
     auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr);
 
-    auto table_outer_gpu = table.ptr<int32_t>(),
-         table_inner_gpu = table_outer_gpu + out_size;
+    auto offsets_gpu = offsets.ptr<int32_t>();
 
     cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size,
                                cudaMemcpyHostToDevice, stream));
 
-    param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), out_size,
-                                table_outer_gpu, table_inner_gpu, stream);
+    param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size,
+                                offsets_gpu, stream);
 }
 
-void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
+void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
+                               _megdnn_tensor_in offsets,
                                _megdnn_tensor_out dst,
                                _megdnn_workspace workspace) {
-    check_exec(dst.layout, table.layout, srcs.layout);
-#define cb(DType)                                          \
-    if (dst.layout.dtype == DType()) {                     \
-        using ctype = typename DTypeTrait<DType>::ctype;   \
-        exec_internal<ctype>(srcs, table, dst, workspace); \
-        return;                                            \
+    check_exec(dst.layout, offsets.layout, srcs.layout);
+#define cb(DType)                                            \
+    if (dst.layout.dtype == DType()) {                       \
+        using ctype = typename DTypeTrait<DType>::ctype;     \
+        exec_internal<ctype>(srcs, offsets, dst, workspace); \
+        return;                                              \
     }
     MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
     megdnn_throw("bad type");
diff --git a/dnn/src/cuda/param_pack/param_pack.cu b/dnn/src/cuda/param_pack/param_pack.cu
index 03e98509..1939d002 100644
--- a/dnn/src/cuda/param_pack/param_pack.cu
+++ b/dnn/src/cuda/param_pack/param_pack.cu
@@ -19,17 +19,24 @@ namespace param_pack {
 
 template <typename T>
 __global__ void concat_kernel(const T** srcs, T* dst,
-                              const int32_t* table_outer,
-                              const int32_t* table_inner,
+                              const int32_t* offsets,
+                              size_t srcs_size,
                               size_t total_size) {
     size_t addr = threadIdx.x + blockIdx.x * blockDim.x;
     if (addr < total_size) {
-        int32_t i = table_outer[addr];
-        int32_t idx = table_inner[addr];
-        if (idx != -1)
-            dst[addr] = srcs[i][idx];
-        else
+        size_t l = 0, r = srcs_size - 1, mid;
+        while (l < r) {
+            mid = (l + r) >> 1;
+            if (offsets[(mid << 1) + 1] > addr) {
+                r = mid;
+            } else {
+                l = mid + 1;
+            }
+        }
+        if (addr < offsets[l << 1])
             dst[addr] = 0;
+        else
+            dst[addr] = srcs[l][addr - offsets[l << 1]];
     }
 }
 
@@ -59,20 +66,20 @@ void split_proxy(const T* src, T** dsts, size_t total_size,
 
 template <typename T>
-void concat_proxy(const T** srcs, T* dst, size_t total_size,
-                  const int32_t* table_outer,
-                  const int32_t* table_inner, cudaStream_t stream) {
+void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
+                  const int32_t* offsets,
+                  cudaStream_t stream) {
     size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS);
     concat_kernel<T><<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
-            srcs, dst, table_outer, table_inner, total_size);
+            srcs, dst, offsets, srcs_size, total_size);
     after_kernel_launch();
 }
 
 #define INST(T)                                                    \
-    template void concat_proxy<T>(const T**, T*, size_t,           \
-                                  const int32_t*, const int32_t*,  \
+    template void concat_proxy<T>(const T**, T*, size_t, size_t,   \
+                                  const int32_t*,                  \
                                   cudaStream_t);                   \
-    template void split_proxy<T>(const T*, T**, size_t,            \
+    template void split_proxy<T>(const T*, T**, size_t,            \
                                  const int32_t*, const int32_t*,   \
                                  cudaStream_t);
 #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
diff --git a/dnn/src/cuda/param_pack/param_pack.cuh b/dnn/src/cuda/param_pack/param_pack.cuh
index 4946f05b..53dc3e9c 100644
--- a/dnn/src/cuda/param_pack/param_pack.cuh
+++ b/dnn/src/cuda/param_pack/param_pack.cuh
@@ -25,9 +25,8 @@ void split_proxy(const T* src, T** dsts, size_t total_size,
                  cudaStream_t stream);
 
 template <typename T>
-void concat_proxy(const T** srcs, T* dst, size_t total_size,
-                  const int32_t* table_outer,
-                  const int32_t* table_inner, cudaStream_t stream);
+void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
+                  const int32_t* offsets, cudaStream_t stream);
 
 }  // namespace param_pack
 }  // namespace cuda
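Instead of reading a precomputed per-element table, the rewritten concat_kernel binary-searches the 2 * nr_tensors offset pairs to map each output index to a (tensor, element) pair, and zero-fills indices that fall into alignment gaps. A host-side model of that lookup (illustrative only; the kernel above performs the same search per thread):

// Host-side model of the per-element lookup done by the new concat_kernel.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// offsets holds {begin0, end0, begin1, end1, ...} as produced by gen_offsets;
// addr is an index into the packed destination buffer.
template <typename T>
T load_packed_element(const std::vector<std::vector<T>>& srcs,
                      const std::vector<int32_t>& offsets, size_t addr) {
    size_t l = 0, r = srcs.size() - 1;
    // Find the first part whose end offset lies past addr.
    while (l < r) {
        size_t mid = (l + r) / 2;
        if (static_cast<size_t>(offsets[mid * 2 + 1]) > addr)
            r = mid;
        else
            l = mid + 1;
    }
    if (addr < static_cast<size_t>(offsets[l * 2]))
        return T(0);  // alignment padding between parts reads as zero
    return srcs[l][addr - offsets[l * 2]];
}

int main() {
    std::vector<std::vector<float>> srcs = {{1, 2, 3}, {4, 5, 6, 7, 8}};
    std::vector<int32_t> offsets = {0, 3, 4, 9};  // two parts, 4-element alignment
    assert(load_packed_element(srcs, offsets, 2) == 3.f);
    assert(load_packed_element(srcs, offsets, 3) == 0.f);  // padding slot
    assert(load_packed_element(srcs, offsets, 8) == 8.f);
}

The search is O(log nr_tensors) per output element, traded against no longer materializing a table with one entry per packed element.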
diff --git a/dnn/src/naive/param_pack/opr_impl.cpp b/dnn/src/naive/param_pack/opr_impl.cpp
index 374b1f6b..8b15ce9b 100644
--- a/dnn/src/naive/param_pack/opr_impl.cpp
+++ b/dnn/src/naive/param_pack/opr_impl.cpp
@@ -54,38 +54,40 @@ void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
 }
 
 template <typename T>
-void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, int32_t* table,
+void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
+                                        int32_t* offsets,
                                         _megdnn_tensor_out dst,
                                         _megdnn_workspace) {
-    size_t out_size = dst.layout.total_nr_elems();
-
     auto srcs_ptr = static_cast<const T**>(srcs.raw_ptr);
     auto dst_ptr = dst.ptr<T>();
-    auto table_outer = table, table_inner = table_outer + out_size;
-
-    for (size_t j = 0; j < out_size; j++) {
-        int32_t i = table_outer[j];
-        int32_t idx = table_inner[j];
-        if (idx != -1)
-            dst_ptr[j] = srcs_ptr[i][idx];
-        else
-            dst_ptr[j] = 0;
+    int32_t last_pos = 0;
+    for (size_t i = 0; i < srcs.layout[0]; i++) {
+        int32_t begin = offsets[i * 2], end = offsets[i * 2 + 1];
+        while (last_pos < begin) {
+            dst_ptr[last_pos] = 0;
+            last_pos++;
+        }
+        for (int32_t j = 0; j < end - begin; j++) {
+            dst_ptr[begin + j] = srcs_ptr[i][j];
+        }
+        last_pos = end;
     }
 }
 
-void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
+void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
+                               _megdnn_tensor_in offsets,
                                _megdnn_tensor_out dst,
                                _megdnn_workspace workspace) {
-    check_exec(dst.layout, table.layout, srcs.layout);
-    auto table_ptr = table.ptr<int32_t>();
+    check_exec(dst.layout, offsets.layout, srcs.layout);
+    auto offsets_ptr = offsets.ptr<int32_t>();
 
-#define cb(DType)                                                       \
-    if (dst.layout.dtype == DType()) {                                  \
-        using ctype = typename DTypeTrait<DType>::ctype;                \
-        MEGDNN_DISPATCH_CPU_KERN_OPR(                                   \
-                exec_internal<ctype>(srcs, table_ptr, dst, workspace)); \
-        return;                                                         \
+#define cb(DType)                                                         \
+    if (dst.layout.dtype == DType()) {                                    \
+        using ctype = typename DTypeTrait<DType>::ctype;                  \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                     \
+                exec_internal<ctype>(srcs, offsets_ptr, dst, workspace)); \
+        return;                                                           \
     }
     MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
     megdnn_throw("bad type");
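The naive backend now walks the parts in order, zero-filling the gap before each begin offset and then copying the part contiguously. A standalone sketch of the resulting layout (illustrative values; reuses the {0, 3, 4, 9} offsets from the example above):

// Standalone check of the packing layout the naive backend produces.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // offsets = {begin0, end0, begin1, end1} for parts of 3 and 5 elements
    // with 4-element alignment, as gen_offsets would emit.
    std::vector<int32_t> offsets = {0, 3, 4, 9};
    std::vector<std::vector<float>> srcs = {{1, 2, 3}, {4, 5, 6, 7, 8}};

    std::vector<float> dst(offsets.back());
    int32_t last_pos = 0;
    for (size_t i = 0; i < srcs.size(); ++i) {
        int32_t begin = offsets[i * 2], end = offsets[i * 2 + 1];
        while (last_pos < begin)          // zero-fill alignment padding
            dst[last_pos++] = 0;
        for (int32_t j = 0; j < end - begin; ++j)
            dst[begin + j] = srcs[i][j];
        last_pos = end;
    }

    assert(dst[3] == 0);                  // the padding element between the two parts
    for (float v : dst)
        std::printf("%g ", v);            // prints: 1 2 3 0 4 5 6 7 8
    std::printf("\n");
}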
diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp
index 76f2ba33..d29a2524 100644
--- a/src/opr/impl/tensor_manip.cpp
+++ b/src/opr/impl/tensor_manip.cpp
@@ -1339,8 +1339,10 @@ void Concat::init_output_comp_node() {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat);
 
 ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table,
+                                 const std::vector<dt_int32> offsets_val,
                                  const OperatorNodeConfig& config)
-        : Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp) {
+        : Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp),
+          m_offsets(offsets_val) {
     CompNode cn = inp[0]->comp_node();
     add_input({inp[0]});
     for (size_t i = 1; i < inp.size(); i++) {
@@ -1361,14 +1363,16 @@ void ParamPackConcat::add_input_layout_constraint(){
     }
 }
 
-SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar> &inp,
-        const SymbolVar &table, const OperatorNodeConfig& config) {
+SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar>& inp,
+                                const SymbolVar& offsets,
+                                const std::vector<dt_int32> offsets_val,
+                                const OperatorNodeConfig& config) {
     VarNodeArray array(inp.size());
     for (size_t i = 0; i < inp.size(); i++) {
         array[i] = inp[i].node();
     }
-    return inp.front().
-            insert_single_output_opr<ParamPackConcat>(array, table.node(), config);
+    return inp.front().insert_single_output_opr<ParamPackConcat>(
+            array, offsets.node(), offsets_val, config);
 }
 
 void ParamPackConcat::scn_do_execute() {
@@ -1379,13 +1383,13 @@ void ParamPackConcat::scn_do_execute() {
     for (size_t i = 0; i < inputs.size() - 1; i++) {
         ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
     }
-    auto table = inputs.back()->dev_tensor().as_megdnn();
+    auto offsets = inputs.back()->dev_tensor().as_megdnn();
     megdnn::TensorND srcs(
             ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32()));
 
     auto&& dst = output(0)->dev_tensor().as_megdnn();
-    m_opr->exec(srcs, table, dst, get_megdnn_workspace_from_var(output(1)));
+    m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1)));
 }
 
 void ParamPackConcat::init_output_dtype() {
@@ -1396,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){
     using namespace cg::static_infer;
     auto &&mgr = owner_graph()->static_infer_manager();
 
-    auto infer_out = [](TensorShape &dest, const InpVal &inp) {
-        dest = {inp.val.back().shape().total_nr_elems()/2};
+    auto infer_out = [this](TensorShape &dest, const InpVal &inp) {
+        dest = {m_offsets.back()};
         return true;
     };
     DepVal shp_deps;
@@ -1480,10 +1484,10 @@ void ParamPackSplit::init_output_dtype() {
 }
 
 void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
-    mgb_assert(m_offsets.size() == output().size());
+    mgb_assert(m_offsets.size() == output().size() * 2);
     for (size_t i = 0; i < output().size(); i++) {
         auto layout = output(i)->layout();
-        auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]);
+        auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]);
         m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly(
                 input(0), spec);
         mgb_assert(m_mem_fwd_success[i]);
@@ -1524,7 +1528,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
     }
 
     return ParamPackConcat::make(
-                   grad, opr.input(1),
+                   grad, opr.input(1), opr.get_offsets(),
                    OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
             .node();
 }
diff --git a/src/opr/impl/tensor_manip.sereg.h b/src/opr/impl/tensor_manip.sereg.h
index 4e09bcbf..ff8c1615 100644
--- a/src/opr/impl/tensor_manip.sereg.h
+++ b/src/opr/impl/tensor_manip.sereg.h
@@ -32,31 +32,6 @@ namespace serialization {
             public OprMakerVariadic<opr::Concat>{};
 
     template<>
-    struct OprLoadDumpImpl<opr::ParamPackConcat, 0>
-    {
-        using ParamPackConcat = opr::ParamPackConcat;
-        using Param = opr::ParamPackConcat::Param;
-
-        static void dump(OprDumpContext &ctx,
-                const cg::OperatorNodeBase &opr_) {
-            auto &&opr = opr_.cast_final_safe<ParamPackConcat>();
-            ctx.write_param<Param>(opr.param());
-        }
-
-        static cg::OperatorNodeBase* load(
-                OprLoadContext &ctx, const cg::VarNodeArray &inputs,
-                const OperatorNodeConfig &config) {
-            auto param = ctx.read_param<Param>();
-            mgb_assert(!inputs.empty());
-            SymbolVarArray ivar{inputs.size() - 1};
-            for (size_t i = 0; i < inputs.size() - 1; ++ i)
-                ivar[i] = inputs[i];
-            return ParamPackConcat::make(ivar, inputs.back(),
-                    param, config).node()->owner_opr();
-        }
-    };
-
-    template<>
     struct OprLoadDumpImpl<opr::Split, 0> {
         using Split = opr::Split;
         using Options = Split::Options;
@@ -151,7 +126,6 @@ namespace opr {
     MGB_SEREG_OPR(Dimshuffle, 1);
     MGB_SEREG_OPR(AxisAddRemove, 1);
     MGB_SEREG_OPR(Concat, 0);
-    MGB_SEREG_OPR(ParamPackConcat, 0);
     using GetVarShapeV1 = opr::GetVarShape;
     MGB_SEREG_OPR(GetVarShapeV1, 0);
     using ReshapeV1 = opr::Reshape;
@@ -193,6 +167,22 @@ namespace opr {
     }
     MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split);
+
+    cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config){
+        auto &&opr = opr_.cast_final_safe<ParamPackConcat>();
+        auto &&offsets = opr.get_offsets();
+
+        SymbolVarArray ivar{inputs.size() - 1};
+        for (size_t i = 0; i < inputs.size() - 1; ++i)
+            ivar[i] = inputs[i];
+        return ParamPackConcat::make(ivar, inputs.back(), offsets, config).
+                node()->owner_opr();
+    }
+
+    MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat);
 
     MGB_SEREG_OPR(RelayoutFormat, 1);
     MGB_SEREG_OPR(WinogradFilterPreprocess, 1);
 } // namespace opr
diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h
index b267fb9c..c81f34bb 100644
--- a/src/opr/include/megbrain/opr/tensor_manip.h
+++ b/src/opr/include/megbrain/opr/tensor_manip.h
@@ -539,6 +539,7 @@ MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // {
 MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // {
     //! input pointer buffer
     SmallVector<void*> m_inp_ptr;
+    std::vector<dt_int32> m_offsets;
     intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr;
 
     void add_input_layout_constraint() override;
@@ -554,15 +555,23 @@ public:
         return {};
     }
 
-    ParamPackConcat(VarNodeArray &inp, VarNode *table,
-            const OperatorNodeConfig &config);
-    static SymbolVar make(const SmallVector<SymbolVar> &inp,
-            const SymbolVar &table, const OperatorNodeConfig &config = {});
+    ParamPackConcat(VarNodeArray& inp, VarNode* offsets,
+                    const std::vector<dt_int32> offsets_val,
+                    const OperatorNodeConfig& config);
+    static SymbolVar make(const SmallVector<SymbolVar>& inp,
+                          const SymbolVar& offsets,
+                          const std::vector<dt_int32> offsets_val,
+                          const OperatorNodeConfig& config = {});
+
+    static SymbolVar make(const SmallVector<SymbolVar>& inp,
+                          const SymbolVar& offsets,
+                          const std::vector<dt_int32> offsets_val, const Param&,
+                          const OperatorNodeConfig& config) {
+        return make(inp, offsets, offsets_val, config);
+    }
 
-    static SymbolVar make(const SmallVector<SymbolVar> &inp,
-            const SymbolVar &table, const Param &,
-            const OperatorNodeConfig &config) {
-        return make(inp, table, config);
+    const std::vector<dt_int32>& get_offsets() const {
+        return m_offsets;
     }
 };
diff --git a/src/opr/test/tensor_manip.cpp b/src/opr/test/tensor_manip.cpp
index 45635aca..55864658 100644
--- a/src/opr/test/tensor_manip.cpp
+++ b/src/opr/test/tensor_manip.cpp
@@ -1906,7 +1906,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
     memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8);
     auto table = opr::Host2DeviceCopy::make(*graph, host_table);
 
-    auto z = opr::ParamPackConcat::make(srcs, table);
+    auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen);
 
     HostTensorND host_z;
    auto func = graph->compile({make_callback_copy(z, host_z)});