You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_norm.cpp 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #include "megdnn/dtype.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/naive/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(NAIVE, GROUPNORM_FORWARD) {
  8. Checker<GroupNorm> checker(handle(), true);
  9. GroupNorm::Param param;
  10. param.affine = true;
  11. param.group = 3;
  12. checker.set_param(param).exect(
  13. Testcase{
  14. TensorValue(
  15. {2, 3, 2, 1}, dtype::Float32(),
  16. {3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587,
  17. 0.0711, -0.1169, 0.2509, -0.2393, 0.0876}), // input
  18. TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // hx
  19. TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // cx
  20. {},
  21. {},
  22. {}},
  23. Testcase{
  24. {},
  25. {},
  26. {},
  27. TensorValue(
  28. {2, 3, 2, 1}, dtype::Float32(),
  29. {1., -1., -1., 1., -1., 1., -1., 1., -0.9999, 0.9999,
  30. -0.9998, 0.9998}), // output
  31. TensorValue(
  32. {2, 3}, dtype::Float32(),
  33. {1.7135, -0.1645, -0.0107, -0.9938, 0.067,
  34. -0.0758}), // mean
  35. TensorValue(
  36. {2, 3}, dtype::Float32(),
  37. {2.5742, 0.1772, 1.6358, 1.1340, 0.0338, 0.0267}), // var
  38. });
  39. checker.set_param(param).exect(
  40. Testcase{
  41. TensorValue(
  42. {1, 3, 1, 2}, dtype::Float32(),
  43. {-2.4348, -1.7948, 0.5223, 0.0932, -0.2955,
  44. -0.0492}), // input
  45. TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // hx
  46. TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // cx
  47. {},
  48. {},
  49. {}},
  50. Testcase{
  51. {},
  52. {},
  53. {},
  54. TensorValue(
  55. {1, 3, 1, 2}, dtype::Float32(),
  56. {-0.9999, 0.9999, 0.9999, -0.9999, -0.9997,
  57. 0.9997}), // output
  58. TensorValue(
  59. {1, 3}, dtype::Float32(),
  60. {-2.1148, 0.3077, -0.1724}), // mean
  61. TensorValue(
  62. {1, 3}, dtype::Float32(), {0.1023, 0.0460, 0.0151}), // var
  63. });
  64. }
  65. } // namespace test
  66. } // namespace megdnn