You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

layer_norm.cpp 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. namespace megdnn {
  4. namespace test {
  5. TEST_F(CUDA, LAYERNORM_FORWARD) {
  6. using Param = LayerNormForward::Param;
  7. Param param;
  8. param.affine = true;
  9. param.eps = 1e-6;
  10. param.normalized_dim = 1;
  11. Checker<LayerNormForward> checker(handle_cuda());
  12. checker.set_epsilon(1e-2);
  13. auto run = [&](DType d) {
  14. for (size_t n_slices : {10, 30})
  15. for (size_t slice_len : {10, 30}) {
  16. param.normalized_size = slice_len;
  17. checker.set_param(param)
  18. .set_dtype(0, d)
  19. .set_dtype(1, d)
  20. .set_dtype(2, d)
  21. .set_dtype(3, d)
  22. .set_dtype(4, dtype::Float32())
  23. .set_dtype(5, dtype::Float32())
  24. .execs({{n_slices, slice_len},
  25. {slice_len},
  26. {slice_len},
  27. {n_slices, slice_len},
  28. {n_slices},
  29. {n_slices}});
  30. }
  31. };
  32. run(dtype::Float32());
  33. run(dtype::Float16());
  34. run(dtype::BFloat16());
  35. }
  36. TEST_F(CUDA, LAYERNORM_BACKWARD) {
  37. using Param = LayerNormBackward::Param;
  38. Param param;
  39. param.affine = true;
  40. param.eps = 1e-6;
  41. param.normalized_dim = 1;
  42. Checker<LayerNormBackward> checker(handle_cuda());
  43. checker.set_epsilon(1e-1);
  44. auto run = [&](DType d) {
  45. for (size_t n_slices : {10, 30})
  46. for (size_t slice_len : {10, 30}) {
  47. param.normalized_size = slice_len;
  48. checker.set_param(param)
  49. .set_dtype(0, d)
  50. .set_dtype(1, d)
  51. .set_dtype(2, d)
  52. .set_dtype(3, dtype::Float32())
  53. .set_dtype(4, dtype::Float32())
  54. .set_dtype(5, d)
  55. .set_dtype(6, d)
  56. .set_dtype(7, d)
  57. .execs({{n_slices, slice_len},
  58. {n_slices, slice_len},
  59. {slice_len},
  60. {n_slices},
  61. {n_slices},
  62. {n_slices, slice_len},
  63. {slice_len},
  64. {slice_len}});
  65. }
  66. };
  67. run(dtype::Float32());
  68. run(dtype::Float16());
  69. run(dtype::BFloat16());
  70. }
  71. } // namespace test
  72. } // namespace megdnn
  73. // vim: syntax=cpp.doxygen