You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #include "test/rocm/fixture.h"
  2. #include "megdnn/opr_param_defs.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/bn.h"
  5. #include "test/common/checker.h"
  6. #include "test/common/rng.h"
  7. #include "test/common/tensor.h"
  8. #include "test/common/workspace_wrapper.h"
  9. namespace megdnn {
  10. namespace test {
  11. TEST_F(ROCM, BN_FORWARD) {
  12. using namespace batch_normalization;
  13. std::vector<TestArg> args = get_args();
  14. Checker<BNForward> checker(handle_rocm());
  15. for (auto&& arg : args) {
  16. for (int i = 0; i < 8; ++i) {
  17. checker.set_dtype(i, dtype::Float32());
  18. }
  19. checker.set_dtype(0, arg.dtype);
  20. checker.set_dtype(8, arg.dtype);
  21. checker.set_epsilon(1e-3).set_param(arg.param);
  22. for (bool need_statistic : {false, true})
  23. checker.exec({
  24. arg.src,
  25. arg.param_shape, // bn_scale
  26. arg.param_shape, // bn_bias
  27. need_statistic ? arg.param_shape : TensorShape({0}), // mean
  28. need_statistic ? arg.param_shape : TensorShape({0}), // variance
  29. arg.param_shape, // batch_mean
  30. arg.param_shape, // batch_inv_variance
  31. {0}, // reserve
  32. arg.src // dst
  33. });
  34. }
  35. }
  36. TEST_F(ROCM, BN_BACKWARD) {
  37. using namespace batch_normalization;
  38. std::vector<TestArg> args = get_args();
  39. Checker<BNBackward> checker(handle_rocm());
  40. for (auto&& arg : args) {
  41. for (int i = 0; i < 9; ++i) {
  42. checker.set_dtype(i, dtype::Float32());
  43. }
  44. checker.set_dtype(0, arg.dtype) // x
  45. .set_dtype(1, arg.dtype) // dy
  46. .set_dtype(8, arg.dtype); // dx
  47. checker.set_epsilon(1e-3).set_param(arg.param).exec(
  48. {arg.src,
  49. arg.src,
  50. arg.param_shape,
  51. arg.param_shape,
  52. arg.param_shape,
  53. {0},
  54. arg.param_shape,
  55. arg.param_shape,
  56. arg.src});
  57. }
  58. }
  59. } // namespace test
  60. } // namespace megdnn
  61. // vim: syntax=cpp.doxygen