Browse Source

feat(mgb/plugin): add param json func for indexing oprs

GitOrigin-RevId: b5becbbc02
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
af29fcb2e3
1 changed files with 58 additions and 0 deletions
  1. +58
    -0
      src/plugin/impl/opr_footprint.cpp

+ 58
- 0
src/plugin/impl/opr_footprint.cpp View File

@@ -512,6 +512,61 @@ REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter)
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData)
REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward)

template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
cg::OperatorNodeBase * opr) {
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();

auto pattern = json::Array::make();
for (size_t i = 0; i < param.pattern_len; i++)
pattern->add(json::NumberInt::make(param.pattern[i]));

return json::Object::make({
{"ndim", json::NumberInt::make(param.ndim)},
{"pattern", pattern},
});
}

template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
cg::OperatorNodeBase * opr) {
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param();

auto desc = json::Array::make();
for (size_t i = 0; i < param.nr_desc; i++) {
auto axisdesc = param.desc[i];
desc->add(
json::Object::make({
{"method", json::NumberInt::make(
static_cast<int32_t>(axisdesc.method))},
{"axisnum", json::NumberInt::make(axisdesc.axis.get_raw())},
}));
}

return json::Object::make({
{"nr_desc", json::NumberInt::make(param.nr_desc)},
{"desc", desc},
});
}

template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>(
cg::OperatorNodeBase * opr) {
auto desc = json::Array::make();
auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc();
for (auto &index : indices){
desc->add(
json::Object::make({
{"axis", json::NumberInt::make(index.axis.get_raw())},
{"begin", json::NumberInt::make(index.begin.node() != nullptr)},
{"end", json::NumberInt::make(index.end.node() != nullptr)},
{"step", json::NumberInt::make(index.step.node() != nullptr)},
{"idx", json::NumberInt::make(index.idx.node() != nullptr)},
}));
}

return desc;
}
#endif // MGB_ENABLE_JSON

} // namespace
@@ -573,6 +628,9 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::GroupLocal>();
add_single_param_json<opr::LRN>();
add_single_param_json<opr::Concat>();
add_single_param_json<opr::Dimshuffle>();
add_single_param_json<opr::AxisAddRemove>();
add_single_param_json<opr::Subtensor>();
add_single_param_json<opr::Reduce>();
add_single_param_json<opr::LocalShareForward>();
add_single_param_json<opr::LocalShareBackwardData>();


Loading…
Cancel
Save