
adaptive_pooling.cpp

#include "megbrain/opr/dnn/adaptive_pooling.h"

#include "../internal/megdnn_opr_wrapper.inl"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/utility.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"

using namespace mgb;
using namespace opr;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingForward);
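
/*
 * AdaptivePoolingForward wraps the megdnn AdaptivePoolingForward kernel.
 * It takes two inputs: the feature map (input 0) and a tensor whose value
 * holds the target output spatial shape (input 1). That value is consumed
 * during static shape inference; the kernel itself only reads input 0.
 */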

AdaptivePoolingForward::AdaptivePoolingForward(
        VarNode* src, VarNode* out_shape, const Param& param,
        const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  src->owner_graph(), config, "adaptive_pooling", {src, out_shape}}) {
    init_megdnn_opr(*this, param);
    add_input({src, out_shape});
    outshape_by_symvar_enable(1, 1);
}

SymbolVar AdaptivePoolingForward::make(
        SymbolVar src, SymbolVar out_shape, const Param& param,
        const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<AdaptivePoolingForward>(
            src.node(), out_shape.node(), param, config);
}

void AdaptivePoolingForward::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto src = shpinfo.shape_inp_shp.at(0);
    mgb_assert(
            (src.ndim == 4 || src.ndim == 5) && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
            "shape mismatch for AdaptivePooling: src=%s, out2d=%s",
            src.to_string().c_str(), oshp2d.to_string().c_str());

    auto param_format = param().format;
    // a single-element target shape means a square (OH == OW) output
    bool tshp1n = oshp2d.ndim == 1;
    if (param_format == Param::Format::NCHW) {
        dest.ndim = 4;
        dest.shape[0] = src.shape[0];
        dest.shape[1] = src.shape[1];
        dest.shape[2] = oshp2d.shape[0];
        dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
    } else if (param_format == Param::Format::NHWC) {
        dest.ndim = 4;
        dest.shape[0] = src.shape[0];
        dest.shape[1] = oshp2d.shape[0];
        dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
        dest.shape[3] = src.shape[3];
    } else if (
            param_format == Param::Format::NCHW44 ||
            param_format == Param::Format::NCHW88) {
        // blocked layouts keep the trailing channel-block dimension of src
        dest.ndim = 5;
        dest.shape[0] = src.shape[0];
        dest.shape[1] = src.shape[1];
        dest.shape[2] = oshp2d.shape[0];
        dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
        dest.shape[4] = src.shape[4];
    } else {
        mgb_throw(
                MegBrainError, "AdaptivePooling does not support format %d",
                (int)param_format);
    }
}

size_t AdaptivePoolingForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], this->input(0)->dtype(), this->input(0)->format()},
            {output_shapes[0], this->output(0)->dtype(), this->output(0)->format()});
}

void AdaptivePoolingForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

void AdaptivePoolingForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void AdaptivePoolingForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) {
    if (wrt_idx == 0) {
        // wrt src
        SymbolVar grad = AdaptivePoolingBackward::make(
                opr.input(0), opr.input(1), opr.output(0), out_grad[0], opr.param());
        return grad.node();
    } else {
        mgb_assert(wrt_idx == 1);
        return InvalidGrad::make(opr, wrt_idx);
    }
}
#endif
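
/*
 * AdaptivePoolingBackward takes (src, out_shape, dst, diff) as inputs and
 * produces the gradient with respect to src; out_shape is kept as a graph
 * input but is not passed to the megdnn kernel.
 */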

MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingBackward);

AdaptivePoolingBackward::AdaptivePoolingBackward(
        VarNode* src, VarNode* out_shape, VarNode* dst, VarNode* diff,
        const Param& param, const OperatorNodeConfig& config)
        : Super(
                  OperatorNodeBaseCtorParam{
                          src->owner_graph(), config, "adaptive_pooling_bwd", {src}},
                  0, true) {
    init_megdnn_opr(*this, param);
    add_input({src, out_shape, dst, diff});
}

SymbolVar AdaptivePoolingBackward::make(
        SymbolVar src, SymbolVar out_shape, SymbolVar dst, SymbolVar diff,
        const Param& param, const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<AdaptivePoolingBackward>(
            src.node(), out_shape.node(), dst.node(), diff.node(), param, config);
}

void AdaptivePoolingBackward::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), input(2)->dev_tensor().as_megdnn(),
            input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

size_t AdaptivePoolingBackward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()},
            {input_shapes[2], input(2)->dtype(), input(2)->format()},
            {input_shapes[3], input(3)->dtype(), input(3)->format()},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}
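
/*
 * Usage sketch (illustrative only, not part of this file): one way to wire
 * the forward opr into a graph, assuming the generic I/O helpers
 * Host2DeviceCopy and ImmutableTensor from megbrain/opr/io.h. The shapes,
 * the comp node string and the target size below are made up for the example.
 *
 *     auto graph = ComputingGraph::make();
 *     auto cn = CompNode::load("xpux");
 *
 *     // input feature map, NCHW
 *     auto host_x = std::make_shared<HostTensorND>(
 *             cn, TensorShape{1, 4, 16, 16}, dtype::Float32());
 *
 *     // target output spatial shape (OH, OW), consumed by value via input(1)
 *     HostTensorND host_tshp{cn, TensorShape{2}, dtype::Int32()};
 *     host_tshp.ptr<dt_int32>()[0] = 7;
 *     host_tshp.ptr<dt_int32>()[1] = 7;
 *
 *     auto x = opr::Host2DeviceCopy::make(*graph, host_x);
 *     auto tshp = opr::ImmutableTensor::make(*graph, host_tshp);
 *     auto y = opr::AdaptivePoolingForward::make(x, tshp, {}, {});  // default param
 *
 *     HostTensorND host_y;
 *     auto func = graph->compile(
 *             {{y, [&](DeviceTensorND& dv) { host_y.copy_from(dv).sync(); }}});
 *     func->execute();  // host_y ends up with shape (1, 4, 7, 7)
 */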

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}