diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp index 220605a9..57414eaf 100644 --- a/python_module/src/cpp/opr_defs.cpp +++ b/python_module/src/cpp/opr_defs.cpp @@ -47,16 +47,10 @@ SymbolVarArray _Opr::param_pack_split( } auto cn = src.node()->comp_node(); - auto offsets_val = megdnn::ParamPackConcat::gen_offsets( + auto offsets = megdnn::ParamPackConcat::gen_offsets( shapearr, cn.get_mem_addr_alignment(), src.dtype().size()); - if (config.has_comp_node_set()) { - cn = config.get_single_comp_node(); - } - HostTensorND hv{cn, TensorShape{{offsets_val.size()}}, dtype::Int32{}}; - memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int)); - auto offsets = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); - return mgb::opr::ParamPackSplit::make(src, offsets, offsets_val, shapearr, config); + return mgb::opr::ParamPackSplit::make(src, offsets, shapearr, config); } #if MGB_ENABLE_OPR_MM diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index f493e241..ba93d3d3 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -13,6 +13,7 @@ #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/param_defs.h" #include "megbrain/opr/utility.h" +#include "megbrain/opr/io.h" #include "megbrain/graph/event.h" #include "megbrain/comp_node_env.h" #include "megbrain/utils/arith_helper.h" @@ -1434,15 +1435,13 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){ /* f{{{ ======================= ParamPackSplit ======================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); -ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* offsets, - const std::vector offsets_val, +ParamPackSplit::ParamPackSplit(VarNode* src, + const std::vector offsets, TensorShapeArray& shapes, const OperatorNodeConfig& config) - : Super{src->owner_graph(), config, "ParamPackSplit", {src, offsets}}, - m_shapes(shapes), m_offsets(offsets_val) { - mgb_assert(src->comp_node() == offsets->comp_node()); + : Super{src->owner_graph(), config, "ParamPackSplit", {src}}, + m_shapes(shapes), m_offsets(offsets) { add_input({src}); - add_input({offsets}); for (size_t i = 0; i < shapes.size(); i++) { mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!"); @@ -1456,14 +1455,13 @@ void ParamPackSplit::add_input_layout_constraint(){ } SymbolVarArray ParamPackSplit::make(const SymbolVar& src, - const SymbolVar& offsets, - const std::vector offsets_val, + const std::vector offsets, TensorShapeArray shapes, const OperatorNodeConfig& config) { auto&& out = src.node() ->owner_graph() ->insert_opr(std::make_unique( - src.node(), offsets.node(), offsets_val, + src.node(), offsets, shapes, config)) ->output(); @@ -1499,7 +1497,7 @@ void ParamPackSplit::init_output_static_infer_desc() { using namespace std::placeholders; auto&& mgr = owner_graph()->static_infer_manager(); - DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}; + DepVal shp_deps{{input(0), DepType::SHAPE}}; for (size_t i = 0; i < output().size(); i++) { auto ov = output(i); @@ -1519,9 +1517,17 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { } grad.emplace_back(gval); } + auto offsets_val = opr.get_offsets(); + auto cn = opr.input(0)->comp_node(); + if (opr.config().has_comp_node_set()) { + cn = opr.config().get_single_comp_node(); + } + HostTensorND hv{cn, TensorShape{offsets_val.size()}, dtype::Int32{}}; + memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int)); + auto offsets = opr::ImmutableTensor::make(*opr.input(0)->owner_graph(), hv); return ParamPackConcat::make( - grad, opr.input(1), opr.get_offsets(), + grad, offsets, offsets_val, OperatorNodeConfig{}.follow_comp_node(opr.input(0))) .node(); } diff --git a/src/opr/impl/tensor_manip.sereg.h b/src/opr/impl/tensor_manip.sereg.h index ff8c1615..b6a57460 100644 --- a/src/opr/impl/tensor_manip.sereg.h +++ b/src/opr/impl/tensor_manip.sereg.h @@ -162,7 +162,7 @@ namespace opr { auto &&offsets = opr.get_offsets(); auto &&shape = opr.get_output_shapes(); - return ParamPackSplit::make(inputs[0], inputs[1], offsets, shape, config).at(0). + return ParamPackSplit::make(inputs[0], offsets, shape, config).at(0). node()->owner_opr(); } diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index 2bc23ff3..4d80558a 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ b/src/opr/include/megbrain/opr/tensor_manip.h @@ -600,12 +600,11 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { void add_input_layout_constraint() override; public: - ParamPackSplit(VarNode* src, VarNode* offsets, - const std::vector offsets_val, + ParamPackSplit(VarNode* src, const std::vector offsets, TensorShapeArray& shapes, const OperatorNodeConfig& config); - static SymbolVarArray make(const SymbolVar& src, const SymbolVar& offsets, - const std::vector offsets_val, + static SymbolVarArray make(const SymbolVar& src, + const std::vector offsets, TensorShapeArray shapes, const OperatorNodeConfig& config = {}); diff --git a/src/opr/test/tensor_manip.cpp b/src/opr/test/tensor_manip.cpp index f3694c21..510506bb 100644 --- a/src/opr/test/tensor_manip.cpp +++ b/src/opr/test/tensor_manip.cpp @@ -1952,9 +1952,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) { .comp_node(cn) .resize({offsets_val.size()}) .ptr()); - auto sym_offsets = opr::SharedDeviceTensor::make( - *inputs[0].node()->owner_graph(), offsets); - auto out = opr::ParamPackSplit::make(inputs[0], sym_offsets, offsets_val, + auto out = opr::ParamPackSplit::make(inputs[0], offsets_val, shapes); mgb_assert(out.size() == nr_out); typename Checker::SymOutArray ret;