You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/opr_param_defs.h"
  3. #include "megdnn/oprs.h"
  4. #include "src/cuda/batch_normalization/opr_impl.h"
  5. #include "src/cuda/utils.h"
  6. #include "test/common/bn.h"
  7. #include "test/common/checker.h"
  8. #include "test/common/rng.h"
  9. #include "test/common/tensor.h"
  10. #include "test/common/workspace_wrapper.h"
  11. namespace megdnn {
  12. namespace test {
  13. TEST_F(CUDA, BN_FORWARD_BACKWARD) {
  14. using namespace batch_normalization;
  15. using cuda::cudnn_handle;
  16. using cuda::batch_normalization::BNTensorDescHolder;
  17. using cuda::batch_normalization::get_reserve_size;
  18. std::vector<TestArg> args = batch_normalization::get_args();
  19. Checker<BNForward> checker(handle_cuda());
  20. Checker<BNBackward> checker_bwd(handle_cuda());
  21. for (auto&& arg : args) {
  22. auto tensor_desc = BNTensorDescHolder(
  23. {arg.src, arg.dtype}, arg.param.param_dim, arg.param.fwd_mode);
  24. auto reserve = get_reserve_size(cudnn_handle(handle_cuda()), tensor_desc);
  25. // Forward
  26. for (int i = 0; i < 9; ++i) {
  27. checker.set_dtype(i, dtype::Float32());
  28. }
  29. checker.set_dtype(0, arg.dtype);
  30. checker.set_dtype(7, dtype::Byte());
  31. checker.set_dtype(8, arg.dtype);
  32. checker.set_bypass(7);
  33. checker.set_epsilon(1e-3).set_param(arg.param);
  34. for (bool need_statistic : {false, true})
  35. checker.exec({
  36. arg.src,
  37. arg.param_shape, // bn_scale
  38. arg.param_shape, // bn_bias
  39. need_statistic ? arg.param_shape : TensorShape({0}), // mean
  40. need_statistic ? arg.param_shape : TensorShape({0}), // variance
  41. arg.param_shape, // batch_mean
  42. arg.param_shape, // batch_inv_variance
  43. {reserve}, // reserve
  44. arg.src // dst
  45. });
  46. // Backward
  47. for (int i = 0; i < 9; ++i) {
  48. checker_bwd.set_dtype(i, dtype::Float32());
  49. }
  50. checker_bwd
  51. .set_dtype(0, arg.dtype) // x
  52. .set_dtype(1, arg.dtype) // dy
  53. .set_dtype(5, dtype::Byte()) // reserve
  54. .set_dtype(8, arg.dtype) // dx
  55. .set_bypass(5);
  56. checker_bwd.set_epsilon(1e-3).set_param(arg.param).exec(
  57. {arg.src,
  58. arg.src,
  59. arg.param_shape,
  60. arg.param_shape,
  61. arg.param_shape,
  62. {reserve},
  63. arg.param_shape,
  64. arg.param_shape,
  65. arg.src});
  66. }
  67. }
  68. } // namespace test
  69. } // namespace megdnn
  70. // vim: syntax=cpp.doxygen