You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pooling.cpp 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. /**
  2. * \file src/opr/impl/dnn/pooling.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/dnn/pooling.h"
  12. #include "../internal/megdnn_opr_wrapper.inl"
  13. #include "megbrain/graph/grad_impl.h"
  14. #include "megbrain/opr/search_policy/algo_chooser.h"
  15. #include "../search_policy/workspace_need_limit_getter.inl"
  16. using namespace mgb;
  17. using namespace opr;
  18. MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);
  19. PoolingForward::PoolingForward(
  20. VarNode* i0, const Param& param, const ExecutionPolicy& policy,
  21. const OperatorNodeConfig& config)
  22. : Super(OperatorNodeBaseCtorParam{i0->owner_graph(), config, "pooling", {i0}}) {
  23. init_megdnn_opr(*this, param);
  24. add_input({i0});
  25. m_policy = policy;
  26. intl::MegDNNOprInitPostCtor<PoolingForward>::apply(*this);
  27. }
  28. SymbolVar PoolingForward::make(
  29. SymbolVar i0, const Param& param, const ExecutionPolicy& policy,
  30. const OperatorNodeConfig& config) {
  31. intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0});
  32. return i0.insert_single_output_opr<PoolingForward>(
  33. i0.node(), param, policy, config);
  34. }
  35. void PoolingForward::init_output_static_infer_desc() {
  36. Super::set_nr_managed_outputs(this->output().size() - 1);
  37. Super::Super::init_output_static_infer_desc();
  38. init_output_static_infer_desc_workspace(
  39. intl::AutoAddWorkspaceNeedLimitGetter<megdnn::PoolingForward>::val);
  40. }
  41. size_t PoolingForward::get_workspace_size_bytes(
  42. const TensorShapeArray& input_shapes,
  43. const TensorShapeArray& output_shapes) const {
  44. return AlgoChooser<megdnn::PoolingForward>::setup_algo(
  45. {TensorLayout{input_shapes[0], input(0)->dtype(), input(0)->format()},
  46. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  47. megdnn_opr(), this, false);
  48. }
  #if MGB_ENABLE_GRAD
  /**
   * Gradient of PoolingForward with respect to its input (index 0 is the only
   * index accepted). The result is a PoolingBackward operator fed with the
   * forward input, the forward output and the incoming output gradient.
   */
  MGB_IMPL_OPR_GRAD(PoolingForward) {
      mgb_assert(wrt_idx == 0);
      return PoolingBackward::make(
                     opr.input(0), opr.output(0), out_grad[0], opr.param())
              .node();
  }
  #endif
  57. MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward);
  58. PoolingBackward::PoolingBackward(
  59. VarNode* i0, VarNode* i1, VarNode* i2, const Param& param,
  60. const ExecutionPolicy& policy, const OperatorNodeConfig& config)
  61. : Super(
  62. OperatorNodeBaseCtorParam{
  63. i0->owner_graph(), config, "pooling_bwd", {i0}},
  64. 0, true) {
  65. init_megdnn_opr(*this, param);
  66. add_input({i0, i1, i2});
  67. m_policy = policy;
  68. intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this);
  69. }
  70. SymbolVar PoolingBackward::make(
  71. SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param,
  72. const ExecutionPolicy& policy, const OperatorNodeConfig& config) {
  73. intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(param, {&i0, &i1, &i2});
  74. return i0.insert_single_output_opr<PoolingBackward>(
  75. i0.node(), i1.node(), i2.node(), param, policy, config);
  76. }
  77. size_t PoolingBackward::get_workspace_size_bytes(
  78. const TensorShapeArray& input_shapes,
  79. const TensorShapeArray& output_shapes) const {
  80. return AlgoChooser<megdnn::PoolingBackward>::setup_algo(
  81. {TensorLayout{input_shapes[0], input(0)->dtype(), input(0)->format()},
  82. {input_shapes[1], input(1)->dtype(), input(1)->format()},
  83. {input_shapes[2], input(2)->dtype(), input(2)->format()},
  84. {output_shapes[0], output(0)->dtype(), output(0)->format()}},
  85. megdnn_opr(), this, false);
  86. }
  87. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}