|
|
@@ -13,6 +13,7 @@ |
|
|
|
#include "megbrain/opr/basic_arith.h" |
|
|
|
#include "megbrain/opr/param_defs.h" |
|
|
|
#include "megbrain/opr/utility.h" |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/graph/event.h" |
|
|
|
#include "megbrain/comp_node_env.h" |
|
|
|
#include "megbrain/utils/arith_helper.h" |
|
|
@@ -1434,15 +1435,13 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){ |
|
|
|
/* f{{{ ======================= ParamPackSplit ======================= */ |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); |
|
|
|
ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* offsets, |
|
|
|
const std::vector<dt_int32> offsets_val, |
|
|
|
ParamPackSplit::ParamPackSplit(VarNode* src, |
|
|
|
const std::vector<dt_int32> offsets, |
|
|
|
TensorShapeArray& shapes, |
|
|
|
const OperatorNodeConfig& config) |
|
|
|
: Super{src->owner_graph(), config, "ParamPackSplit", {src, offsets}}, |
|
|
|
m_shapes(shapes), m_offsets(offsets_val) { |
|
|
|
mgb_assert(src->comp_node() == offsets->comp_node()); |
|
|
|
: Super{src->owner_graph(), config, "ParamPackSplit", {src}}, |
|
|
|
m_shapes(shapes), m_offsets(offsets) { |
|
|
|
add_input({src}); |
|
|
|
add_input({offsets}); |
|
|
|
|
|
|
|
for (size_t i = 0; i < shapes.size(); i++) { |
|
|
|
mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!"); |
|
|
@@ -1456,14 +1455,13 @@ void ParamPackSplit::add_input_layout_constraint(){ |
|
|
|
} |
|
|
|
|
|
|
|
SymbolVarArray ParamPackSplit::make(const SymbolVar& src, |
|
|
|
const SymbolVar& offsets, |
|
|
|
const std::vector<dt_int32> offsets_val, |
|
|
|
const std::vector<dt_int32> offsets, |
|
|
|
TensorShapeArray shapes, |
|
|
|
const OperatorNodeConfig& config) { |
|
|
|
auto&& out = src.node() |
|
|
|
->owner_graph() |
|
|
|
->insert_opr(std::make_unique<ParamPackSplit>( |
|
|
|
src.node(), offsets.node(), offsets_val, |
|
|
|
src.node(), offsets, |
|
|
|
shapes, config)) |
|
|
|
->output(); |
|
|
|
|
|
|
@@ -1499,7 +1497,7 @@ void ParamPackSplit::init_output_static_infer_desc() { |
|
|
|
using namespace std::placeholders; |
|
|
|
auto&& mgr = owner_graph()->static_infer_manager(); |
|
|
|
|
|
|
|
DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}; |
|
|
|
DepVal shp_deps{{input(0), DepType::SHAPE}}; |
|
|
|
|
|
|
|
for (size_t i = 0; i < output().size(); i++) { |
|
|
|
auto ov = output(i); |
|
|
@@ -1519,9 +1517,17 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { |
|
|
|
} |
|
|
|
grad.emplace_back(gval); |
|
|
|
} |
|
|
|
auto offsets_val = opr.get_offsets(); |
|
|
|
auto cn = opr.input(0)->comp_node(); |
|
|
|
if (opr.config().has_comp_node_set()) { |
|
|
|
cn = opr.config().get_single_comp_node(); |
|
|
|
} |
|
|
|
HostTensorND hv{cn, TensorShape{offsets_val.size()}, dtype::Int32{}}; |
|
|
|
memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int)); |
|
|
|
auto offsets = opr::ImmutableTensor::make(*opr.input(0)->owner_graph(), hv); |
|
|
|
|
|
|
|
return ParamPackConcat::make( |
|
|
|
grad, opr.input(1), opr.get_offsets(), |
|
|
|
grad, offsets, offsets_val, |
|
|
|
OperatorNodeConfig{}.follow_comp_node(opr.input(0))) |
|
|
|
.node(); |
|
|
|
} |
|
|
|