
pooling.cpp 3.3 kB

/**
 * \file imperative/src/impl/ops/pooling.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative {
namespace {
namespace pooling {

// Graph-mode lowering: build an opr::Pooling node from the op definition.
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& pool = static_cast<const Pooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
}

// Shape inference: deduce the output layout from the input layout and the
// pooling parameters; returns false when the input shape is not yet known.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 1, "num of inputs of pooling should be 1 but you give %zu",
            inputs.size());
    auto&& op_def = def.cast_final_safe<Pooling>();
    auto&& inp = inputs[0];
    auto& inp_cn = inp.comp_node;
    if (inp.layout.ndim == 0) {
        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false};
    }
    TensorLayout oup_layout;
    megdnn::Pooling::deduce_layout_impl(inp.layout, op_def.param(), oup_layout);
    return {{{oup_layout, inp_cn, {}}}, true};
}

// Eager execution: allocate the output tensor, pick an algorithm (and its
// workspace), then run the MegDNN pooling kernel on the input.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    mgb_assert(
            inputs.size() == 1, "num of inputs of pooling should be 1 but you give %zu",
            inputs.size());
    auto&& op_def = def.cast_final_safe<Pooling>();
    auto cn = inputs[0]->comp_node();
    megdnn::TensorND inp_tensornd = inputs[0]->dnn_tensor();

    DnnOprCaller<megdnn::Pooling> caller(cn);
    auto&& dnn_opr = caller.op;
    dnn_opr->param() = op_def.param();

    // Re-deduce the output layout if it has not been validated upstream.
    TensorLayout& oup_layout = output_descs[0].layout;
    if (!validated) {
        megdnn::Pooling::deduce_layout_impl(
                inp_tensornd.layout, op_def.param(), oup_layout);
    }
    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);

    size_t wk_size = setup_algo<megdnn::Pooling>(
            {inp_tensornd.layout, oup_layout}, dnn_opr.get(), 0, false, false, cn,
            op_def.policy(), false);
    megdnn::Workspace dnn_wk;
    if (wk_size) {
        TensorLayout w_layout({wk_size}, dtype::Byte());
        dnn_wk = caller.create_workspace(w_layout);
    }

    dnn_opr->exec(inp_tensornd, out_devtensor.as_megdnn(), dnn_wk);
    return {Tensor::make(out_devtensor)};
}

// Register the handlers above as the imperative traits of the Pooling op.
OP_TRAIT_REG(Pooling, Pooling)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace pooling
}  // namespace
}  // namespace mgb::imperative
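
For context, a minimal usage sketch of how the op registered here is reached in practice (assuming a standard MegEngine install; the tensor shape and pooling parameters below are purely illustrative): in eager mode, a call to the Python functional API dispatches through the imperative Pooling op implemented in this file.

    import numpy as np
    import megengine
    import megengine.functional as F

    # A 1x3x8x8 float32 input; 2x2 max pooling with stride 2 halves the spatial size.
    x = megengine.Tensor(np.random.rand(1, 3, 8, 8).astype("float32"))
    y = F.max_pool2d(x, kernel_size=2, stride=2)
    print(y.shape)  # expected: (1, 3, 4, 4)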