
padding.cpp

  1. #include "megbrain/graph/symbol_var.h"
  2. #include "megbrain/imperative/ops/autogen.h"
  3. #include "megbrain/imperative/physical_tensor.h"
  4. #include "megbrain/imperative/proxy_graph_detail.h"
  5. #include "megbrain/opr/basic_arith.h"
  6. #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
  7. #include "megbrain/opr/io.h"
  8. #include "megbrain/opr/tensor_manip.h"
  9. #include "megdnn/dtype.h"
  10. #include "../blob_manager_impl.h"
  11. #include "../dnn_op_helper.h"
  12. #include "../op_trait.h"
  13. namespace mgb {
  14. namespace imperative {
  15. namespace {
  16. namespace padding {
  17. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  18. auto&& op = static_cast<const Padding&>(def);
  19. mgb_assert(inputs.size() == 1);
  20. return opr::Padding::make(inputs[0], op.param());
  21. }
  22. SmallVector<TensorPtr> apply_on_physical_tensor(
  23. const OpDef& def, const SmallVector<TensorPtr>& inputs,
  24. SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
  25. auto comp_node = inputs[0]->comp_node();
  26. auto&& op_def = def.cast_final_safe<Padding>();
  27. DnnOprCaller<megdnn::Padding> dnn_op(comp_node, op_def.param());
  28. auto dst = [&] {
  29. if (validated) {
  30. return output_descs[0].layout;
  31. } else {
  32. return dnn_op.deduce_layout(inputs[0]->layout());
  33. }
  34. }();
  35. auto out = Tensor::make(dst, comp_node);
  36. dnn_op.exec(inputs[0], out);
  37. return {out};
  38. }
  39. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  40. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  41. auto&& op_def = def.cast_final_safe<Padding>();
  42. auto&& inp = inputs[0];
  43. if (inp.layout.ndim == 0) {
  44. return {{{TensorLayout{inp.layout.dtype}, inp.comp_node, {}}}, false};
  45. }
  46. DnnOprHelper<megdnn::Padding> dnn_op(op_def.param());
  47. auto oup_layout = dnn_op.deduce_layout(inp.layout);
  48. return {{{oup_layout, inp.comp_node}}, true};
  49. }
  50. OP_TRAIT_REG(Padding, Padding, opr::Padding)
  51. .apply_on_var_node(apply_on_var_node)
  52. .apply_on_physical_tensor(apply_on_physical_tensor)
  53. .infer_output_attrs_fallible(infer_output_attrs_fallible)
  54. .fallback();
  55. } // namespace padding
  56. } // namespace
  57. } // namespace imperative
  58. } // namespace mgb
  59. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
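The deduce_layout calls above delegate the output-shape computation to the megdnn Padding operator. As a rough, standalone illustration of what that computation amounts to (hypothetical helper name, not part of MegEngine or megdnn): each output extent is the input extent plus the front and back padding chosen for that dimension.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical sketch: a padded output extent is the input extent
// plus the front and back padding for that dimension.
std::vector<size_t> deduce_padded_shape(
        const std::vector<size_t>& in,
        const std::vector<size_t>& front,
        const std::vector<size_t>& back) {
    std::vector<size_t> out(in.size());
    for (size_t i = 0; i < in.size(); ++i) {
        out[i] = in[i] + front[i] + back[i];
    }
    return out;
}

int main() {
    // Pad a (2, 3) tensor with one element on each side of dim 1.
    auto out = deduce_padded_shape({2, 3}, {0, 1}, {0, 1});
    for (auto d : out) std::cout << d << ' ';  // prints: 2 5
}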