You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deformable_ps_roi_pooling.cpp 3.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #include "megdnn/oprs/nn.h"
  2. #include "src/common/utils.h"
  3. namespace megdnn {
  4. void DeformablePSROIPoolingBase::deduce_layout_fwd(
  5. const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans,
  6. TensorLayout& out_data, TensorLayout& out_count) {
  7. megdnn_assert_contiguous(data);
  8. megdnn_assert_contiguous(rois);
  9. megdnn_assert_contiguous(trans);
  10. auto errmsg = [&]() {
  11. return std::string("data: ") + megdnn_layout_msg(data) +
  12. ", rois: " + megdnn_layout_msg(rois) +
  13. ", trans: " + megdnn_layout_msg(trans) +
  14. ", out_data: " + megdnn_layout_msg(out_data) +
  15. ", out_count: " + megdnn_layout_msg(out_count);
  16. };
  17. MEGDNN_MARK_USED_VAR(data);
  18. MEGDNN_MARK_USED_VAR(rois);
  19. MEGDNN_MARK_USED_VAR(trans);
  20. MEGDNN_MARK_USED_VAR(out_data);
  21. MEGDNN_MARK_USED_VAR(out_count);
  22. MEGDNN_MARK_USED_VAR(out_count);
  23. MEGDNN_MARK_USED_VAR(errmsg);
  24. megdnn_assert(
  25. data.dtype.enumv() == DTypeEnum::Float32,
  26. "DeformablePSROIPooling only support float32 input");
  27. megdnn_assert(data.ndim == 4_z, "invalid data shape, %s", errmsg().c_str());
  28. megdnn_assert(
  29. rois.ndim == 2_z && rois[1] == 5, "invalid rois shape, %s",
  30. errmsg().c_str());
  31. megdnn_assert(trans.ndim == 4_z, "invalid trans shape, %s", errmsg().c_str());
  32. if (!param().no_trans) {
  33. megdnn_assert(
  34. trans[1] == 2_z && trans[2] == param().pooled_h &&
  35. trans[3] == param().pooled_w,
  36. "invalid trans shape: %s", errmsg().c_str());
  37. }
  38. out_data = {{rois[0], data[1], param().pooled_h, param().pooled_w}, data.dtype};
  39. out_count = out_data;
  40. }
  41. void DeformablePSROIPoolingBase::check_layout_fwd(
  42. const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans,
  43. const TensorLayout& out_data, const TensorLayout& out_count,
  44. size_t workspace_in_bytes) {
  45. MEGDNN_MARK_USED_VAR(workspace_in_bytes);
  46. TensorLayout exp_out_data, exp_out_count;
  47. deduce_layout_fwd(data, rois, trans, exp_out_data, exp_out_count);
  48. megdnn_assert_eq_layout(out_data, exp_out_data);
  49. megdnn_assert_eq_layout(out_count, exp_out_count);
  50. }
  51. void DeformablePSROIPoolingForward::deduce_layout(
  52. const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans,
  53. TensorLayout& out_data, TensorLayout& out_count) {
  54. deduce_layout_fwd(data, rois, trans, out_data, out_count);
  55. }
  56. void DeformablePSROIPoolingForward::check_exec(
  57. const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans,
  58. const TensorLayout& out_data, const TensorLayout& out_count,
  59. size_t workspace_in_bytes) {
  60. check_layout_fwd(data, rois, trans, out_data, out_count, workspace_in_bytes);
  61. auto required_workspace_in_bytes =
  62. get_workspace_in_bytes(data, rois, trans, out_data, out_count);
  63. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  64. }
  65. void DeformablePSROIPoolingBackward::check_exec(
  66. const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans,
  67. const TensorLayout& out_diff, const TensorLayout& out_count,
  68. const TensorLayout& data_diff, const TensorLayout& trans_diff,
  69. size_t workspace_in_bytes) {
  70. check_layout_fwd(
  71. data_diff, rois, trans_diff, out_diff, out_count, workspace_in_bytes);
  72. megdnn_assert_eq_layout(data, data_diff);
  73. megdnn_assert_eq_layout(trans, trans_diff);
  74. auto required_workspace_in_bytes = get_workspace_in_bytes(
  75. data, rois, trans, out_diff, out_count, data_diff, trans_diff);
  76. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  77. }
  78. } // namespace megdnn
  79. // vim: syntax=cpp.doxygen