|
|
@@ -512,6 +512,61 @@ REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter) |
|
|
|
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData) |
|
|
|
REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward) |
|
|
|
|
|
|
|
template <> |
|
|
|
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( |
|
|
|
cg::OperatorNodeBase * opr) { |
|
|
|
auto param = opr->cast_final_safe<opr::Dimshuffle>().param(); |
|
|
|
|
|
|
|
auto pattern = json::Array::make(); |
|
|
|
for (size_t i = 0; i < param.pattern_len; i++) |
|
|
|
pattern->add(json::NumberInt::make(param.pattern[i])); |
|
|
|
|
|
|
|
return json::Object::make({ |
|
|
|
{"ndim", json::NumberInt::make(param.ndim)}, |
|
|
|
{"pattern", pattern}, |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>( |
|
|
|
cg::OperatorNodeBase * opr) { |
|
|
|
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param(); |
|
|
|
|
|
|
|
auto desc = json::Array::make(); |
|
|
|
for (size_t i = 0; i < param.nr_desc; i++) { |
|
|
|
auto axisdesc = param.desc[i]; |
|
|
|
desc->add( |
|
|
|
json::Object::make({ |
|
|
|
{"method", json::NumberInt::make( |
|
|
|
static_cast<int32_t>(axisdesc.method))}, |
|
|
|
{"axisnum", json::NumberInt::make(axisdesc.axis.get_raw())}, |
|
|
|
})); |
|
|
|
} |
|
|
|
|
|
|
|
return json::Object::make({ |
|
|
|
{"nr_desc", json::NumberInt::make(param.nr_desc)}, |
|
|
|
{"desc", desc}, |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>( |
|
|
|
cg::OperatorNodeBase * opr) { |
|
|
|
auto desc = json::Array::make(); |
|
|
|
auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc(); |
|
|
|
for (auto &index : indices){ |
|
|
|
desc->add( |
|
|
|
json::Object::make({ |
|
|
|
{"axis", json::NumberInt::make(index.axis.get_raw())}, |
|
|
|
{"begin", json::NumberInt::make(index.begin.node() != nullptr)}, |
|
|
|
{"end", json::NumberInt::make(index.end.node() != nullptr)}, |
|
|
|
{"step", json::NumberInt::make(index.step.node() != nullptr)}, |
|
|
|
{"idx", json::NumberInt::make(index.idx.node() != nullptr)}, |
|
|
|
})); |
|
|
|
} |
|
|
|
|
|
|
|
return desc; |
|
|
|
} |
|
|
|
#endif // MGB_ENABLE_JSON |
|
|
|
|
|
|
|
} // namespace |
|
|
@@ -573,6 +628,9 @@ void OprFootprint::init_all_footprints() { |
|
|
|
add_single_param_json<opr::GroupLocal>(); |
|
|
|
add_single_param_json<opr::LRN>(); |
|
|
|
add_single_param_json<opr::Concat>(); |
|
|
|
add_single_param_json<opr::Dimshuffle>(); |
|
|
|
add_single_param_json<opr::AxisAddRemove>(); |
|
|
|
add_single_param_json<opr::Subtensor>(); |
|
|
|
add_single_param_json<opr::Reduce>(); |
|
|
|
add_single_param_json<opr::LocalShareForward>(); |
|
|
|
add_single_param_json<opr::LocalShareBackwardData>(); |
|
|
|