You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

layer_norm.cpp 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
/**
 * \file dnn/test/cuda/layer_norm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
  12. #include "test/cuda/fixture.h"
  13. #include "test/common/checker.h"
  14. namespace megdnn {
  15. namespace test {
  16. TEST_F(CUDA, LAYERNORM_FORWARD) {
  17. using Param = LayerNormForward::Param;
  18. Param param;
  19. param.affine = true;
  20. param.eps = 1e-6;
  21. param.normalized_dim = 1;
  22. Checker<LayerNormForward> checker(handle_cuda());
  23. checker.set_epsilon(1e-2);
  24. auto run = [&](DType d) {
  25. for (size_t n_slices : {10, 30})
  26. for (size_t slice_len : {10, 30}) {
  27. param.normalized_size = slice_len;
  28. checker.set_param(param)
  29. .set_dtype(0, d)
  30. .set_dtype(1, d)
  31. .set_dtype(2, d)
  32. .set_dtype(3, d)
  33. .set_dtype(4, dtype::Float32())
  34. .set_dtype(5, dtype::Float32())
  35. .execs({{n_slices, slice_len},
  36. {slice_len},
  37. {slice_len},
  38. {n_slices, slice_len},
  39. {n_slices},
  40. {n_slices}});
  41. }
  42. };
  43. run(dtype::Float32());
  44. run(dtype::Float16());
  45. run(dtype::BFloat16());
  46. }
  47. TEST_F(CUDA, LAYERNORM_BACKWARD) {
  48. using Param = LayerNormBackward::Param;
  49. Param param;
  50. param.affine = true;
  51. param.eps = 1e-6;
  52. param.normalized_dim = 1;
  53. Checker<LayerNormBackward> checker(handle_cuda());
  54. checker.set_epsilon(1e-1);
  55. auto run = [&](DType d) {
  56. for (size_t n_slices : {10, 30})
  57. for (size_t slice_len : {10, 30}) {
  58. param.normalized_size = slice_len;
  59. checker.set_param(param)
  60. .set_dtype(0, d)
  61. .set_dtype(1, d)
  62. .set_dtype(2, d)
  63. .set_dtype(3, dtype::Float32())
  64. .set_dtype(4, dtype::Float32())
  65. .set_dtype(5, d)
  66. .set_dtype(6, d)
  67. .set_dtype(7, d)
  68. .execs({{n_slices, slice_len},
  69. {n_slices, slice_len},
  70. {slice_len},
  71. {n_slices},
  72. {n_slices},
  73. {n_slices, slice_len},
  74. {slice_len},
  75. {slice_len}});
  76. }
  77. };
  78. run(dtype::Float32());
  79. run(dtype::Float16());
  80. run(dtype::BFloat16());
  81. }
  82. } // namespace test
  83. } // namespace megdnn
  84. // vim: syntax=cpp.doxygen