
adaptive_pooling.cpp

#include "hcc_detail/hcc_defs_prologue.h"
#include "test/rocm/fixture.h"

#include "megdnn/tensor_iter.h"
#include "test/common/adaptive_pooling.h"
#include "test/common/checker.h"
#include "src/common/utils.h"
#include "test/rocm/utils.h"
#include "test/rocm/benchmarker.h"

namespace megdnn {
namespace test {
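
// Forward test: run AdaptivePoolingForward on the ROCm handle for every
// predefined case (NCHW, float32) and compare against the reference backend.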
TEST_F(ROCM, ADAPTIVE_POOLING_FORWARD) {
    auto args = adaptive_pooling::get_args();
    using Format = param::AdaptivePooling::Format;
    DType dtype = dtype::Float32();
    for (auto&& arg : args) {
        auto param = arg.param;
        auto src = arg.ishape;
        auto dst = arg.oshape;
        param.format = Format::NCHW;
        Checker<AdaptivePooling> checker(handle_rocm());
        checker.set_epsilon(1e-2);
        checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
                TensorShapeArray{src, dst, {}});
    }
}
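
// Backward test: AdaptivePoolingBackward needs a valid forward output, so a
// tensor constraint first runs the forward operator on the device and writes
// its result into the checker's tensors before the backward check executes.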
TEST_F(ROCM, ADAPTIVE_POOLING_BACKWARD) {
    auto args = adaptive_pooling::get_args();
    for (auto&& arg : args) {
        Checker<AdaptivePoolingBackward> checker(handle_rocm());
        TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
        TensorLayout olayout = TensorLayout(arg.oshape, dtype::Float32());

        auto constraint = [this, arg](CheckerHelper::TensorValueArray& tensors_orig) {
            megdnn_assert(tensors_orig.size() == 4);
            auto opr = handle_rocm()->create_operator<AdaptivePoolingForward>();
            opr->param() = arg.param;

            // Allocate device tensors for the forward src/dst pair.
            auto tensors_rocm_storage = CheckerHelper::alloc_tensors(
                    handle_rocm(), {tensors_orig[0].layout, tensors_orig[1].layout}, 0);
            auto&& tensors_rocm = *tensors_rocm_storage;

            // Copy the host-side src tensor onto the device.
            auto span = tensors_rocm[0].layout.span();
            auto dst = static_cast<dt_byte*>(tensors_rocm[0].raw_ptr()) + span.low_byte;
            auto src = static_cast<const dt_byte*>(tensors_orig[0].raw_ptr()) +
                       span.low_byte;
            megdnn_memcpy_H2D(handle_rocm(), dst, src, span.dist_byte());

            // Run the forward operator with its required workspace.
            auto workspace_size = opr->get_workspace_in_bytes(
                    tensors_rocm[0].layout, tensors_rocm[1].layout);
            auto workspace_rocm = megdnn_malloc(handle_rocm(), workspace_size);
            Workspace workspace{static_cast<dt_byte*>(workspace_rocm), workspace_size};
            opr->exec(tensors_rocm[0], tensors_rocm[1], workspace);
            megdnn_free(handle_rocm(), workspace_rocm);

            // Copy the computed forward output back into the checker's dst tensor.
            span = tensors_rocm[1].layout.span();
            dst = static_cast<dt_byte*>(tensors_orig[1].raw_ptr()) + span.low_byte;
            src = static_cast<const dt_byte*>(tensors_rocm[1].raw_ptr()) +
                  span.low_byte;
            megdnn_memcpy_D2H(handle_rocm(), dst, src, span.dist_byte());
        };

        DType dtype = dtype::Float32();
        checker.set_tensors_constraint(constraint)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_dtype(3, dtype)
                .set_param(arg.param)
                .exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
    }
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen