You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

roi_pooling.cpp 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. #include "megdnn/oprs.h"
  2. #include "src/common/utils.h"
  3. namespace megdnn {
  4. void ROIPoolingBase::check_layout_fwd(
  5. const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
  6. const TensorLayout& index) {
  7. // all should be contiguous
  8. megdnn_assert_contiguous(src);
  9. megdnn_assert_contiguous(rois);
  10. megdnn_assert_contiguous(dst);
  11. megdnn_assert_contiguous(index);
  12. auto errmsg = [&]() {
  13. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(rois) + ", " +
  14. megdnn_layout_msg(dst) + ", " + megdnn_layout_msg(index);
  15. };
  16. MEGDNN_MARK_USED_VAR(errmsg);
  17. // src
  18. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  19. auto C = src.shape[1];
  20. // rois
  21. megdnn_assert(rois.ndim == 2_z, "%s", errmsg().c_str());
  22. auto M = rois.shape[0];
  23. megdnn_assert(rois[1] == 5_z, "%s", errmsg().c_str());
  24. // dst
  25. megdnn_assert(dst[0] == M, "%s", errmsg().c_str());
  26. megdnn_assert(dst[1] == C, "%s", errmsg().c_str());
  27. // index
  28. megdnn_assert_eq_shape(index, dst);
  29. megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
  30. megdnn_assert(rois.dtype.category() == DTypeCategory::FLOAT);
  31. megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT);
  32. megdnn_assert(index.dtype == dtype::Int32());
  33. }
  34. void ROIPoolingForward::check_exec(
  35. const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst,
  36. const TensorLayout& index, size_t workspace_in_bytes) {
  37. check_layout_fwd(src, rois, dst, index);
  38. auto required_workspace_in_bytes = get_workspace_in_bytes(src, rois, dst, index);
  39. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  40. }
  41. void ROIPoolingBackward::check_exec(
  42. const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois,
  43. const TensorLayout& index, const TensorLayout& grad,
  44. size_t workspace_in_bytes) {
  45. check_layout_fwd(src, rois, diff, index);
  46. megdnn_assert_eq_layout(src, grad);
  47. auto required_workspace_in_bytes =
  48. get_workspace_in_bytes(diff, src, rois, index, grad);
  49. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  50. }
  51. } // namespace megdnn
  52. // vim: syntax=cpp.doxygen