You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

group_norm.cpp 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #include "megbrain/opr/dnn/group_norm.h"
  2. #include "megbrain/imperative/ops/autogen.h"
  3. #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
  4. #include "../blob_manager_impl.h"
  5. #include "../dnn_op_helper.h"
  6. #include "../op_trait.h"
  7. namespace mgb::imperative {
  8. namespace group_norm {
  9. cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  10. auto&& op = static_cast<const GroupNorm&>(def);
  11. size_t nr_inp = inputs.size();
  12. auto p = op.param();
  13. mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
  14. OperatorNodeConfig config{op.make_name()};
  15. if (nr_inp == 3) {
  16. return opr::GroupNorm::make(
  17. inputs[0], inputs[1], inputs[2], op.param(), config)[0]
  18. .node()
  19. ->owner_opr();
  20. } else {
  21. return opr::GroupNorm::make(inputs[0], op.param(), config)[0]
  22. .node()
  23. ->owner_opr();
  24. }
  25. }
  26. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  27. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  28. auto&& group_norm = def.cast_final_safe<GroupNorm>();
  29. size_t nr_inp = inputs.size();
  30. auto affine = group_norm.affine;
  31. mgb_assert(
  32. (nr_inp == 3 && affine) || (nr_inp == 1 && !affine),
  33. "num of inputs of pooling should be 1 or 3 but you give %zu",
  34. inputs.size());
  35. auto&& inp = inputs[0];
  36. auto& inp_cn = inp.comp_node;
  37. if (inp.layout.ndim == 0) {
  38. return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
  39. {TensorLayout{dtype::Float32()}, inp_cn, {}},
  40. {TensorLayout{dtype::Float32()}, inp_cn, {}}},
  41. false};
  42. }
  43. DnnOprHelper<megdnn::GroupNorm> dnn_opr(group_norm.param());
  44. auto&& [oup_layout, mean_layout, rstd_layout] =
  45. dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{});
  46. return {{{oup_layout, inp_cn, {}},
  47. {mean_layout, inp_cn, {}},
  48. {rstd_layout, inp_cn, {}}},
  49. true};
  50. }
  51. SmallVector<TensorPtr> apply_on_physical_tensor(
  52. const OpDef& def, const SmallVector<TensorPtr>& inputs,
  53. SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
  54. auto&& op_def = def.cast_final_safe<GroupNorm>();
  55. size_t nr_inp = inputs.size();
  56. auto p = op_def.param();
  57. mgb_assert(
  58. (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
  59. "num of inputs of groupnorm should be 1 or 3 but you give %zu",
  60. inputs.size());
  61. auto cn = inputs[0]->comp_node();
  62. DnnOprCaller<megdnn::GroupNorm> caller(cn, op_def.param());
  63. auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>(
  64. inputs[0]->layout(), TensorLayout{}, TensorLayout{});
  65. auto out = Tensor::make(oup_layout, cn);
  66. auto mean = Tensor::make(mean_layout, cn);
  67. auto rstd = Tensor::make(rstd_layout, cn);
  68. if (p.affine) {
  69. caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
  70. } else {
  71. megdnn::TensorND empty_dnn;
  72. caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd);
  73. }
  74. return {out, mean, rstd};
  75. }
// Register the GroupNorm op trait, wiring up graph lowering, fallible shape
// inference and eager execution to the functions defined above.
// NOTE(review): .fallback() presumably fills the remaining trait slots with
// default implementations — confirm against the OP_TRAIT_REG machinery.
OP_TRAIT_REG(GroupNorm, GroupNorm)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
  81. } // namespace group_norm
  82. } // namespace mgb::imperative