You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fake_quant.cpp 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #include "test/common/fake_quant.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/cuda/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. using namespace fake_quant;
  8. TEST_F(CUDA, FAKE_QUANT) {
  9. std::vector<TestArg> args = get_args();
  10. auto dtype = dtype::Float32();
  11. UniformFloatRNG rng(-1.0f, 1.0f);
  12. const auto nan = std::numeric_limits<float>::quiet_NaN();
  13. UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
  14. for (auto&& arg : args) {
  15. auto param = arg.param;
  16. auto ishape = arg.ishape;
  17. auto scale_shape = arg.scale_shape;
  18. auto zeropoint_shape = arg.zeropoint_shape;
  19. Checker<FakeQuantForward> checker(handle_cuda());
  20. checker.set_param(param)
  21. .set_dtype(0, dtype)
  22. .set_dtype(1, dtype)
  23. .set_dtype(2, dtype)
  24. .set_dtype(3, dtype)
  25. .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
  26. checker.set_allow_invalid_check(true);
  27. checker.set_rng(0, &rng1);
  28. checker.set_param(param)
  29. .set_dtype(0, dtype)
  30. .set_dtype(1, dtype)
  31. .set_dtype(2, dtype)
  32. .set_dtype(3, dtype)
  33. .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
  34. checker.set_rng(0, &rng);
  35. checker.set_allow_invalid_check(false);
  36. }
  37. // test noncontiguous layout
  38. for (auto&& arg : args) {
  39. auto param = arg.param;
  40. auto ishape = arg.ishape;
  41. auto scale_shape = arg.scale_shape;
  42. auto zeropoint_shape = arg.zeropoint_shape;
  43. Checker<FakeQuantForward> checker(handle_cuda());
  44. TensorLayout ilayout(
  45. ishape,
  46. {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
  47. (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
  48. dtype::Float32());
  49. checker.set_param(param).execl(
  50. {ilayout,
  51. {scale_shape, dtype::Float32()},
  52. {zeropoint_shape, dtype::Float32()},
  53. ilayout});
  54. checker.set_allow_invalid_check(true);
  55. checker.set_rng(0, &rng1);
  56. checker.set_param(param).execl(
  57. {ilayout,
  58. {scale_shape, dtype::Float32()},
  59. {zeropoint_shape, dtype::Float32()},
  60. ilayout});
  61. checker.set_rng(0, &rng);
  62. checker.set_allow_invalid_check(false);
  63. }
  64. }
  65. TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
  66. std::vector<TestArg> args = get_args();
  67. auto dtype = dtype::Float32();
  68. UniformFloatRNG rng(-1.0f, 1.0f);
  69. const auto nan = std::numeric_limits<float>::quiet_NaN();
  70. UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
  71. for (auto&& arg : args) {
  72. auto param = arg.param;
  73. auto ishape = arg.ishape;
  74. auto scale_shape = arg.scale_shape;
  75. auto zeropoint_shape = arg.zeropoint_shape;
  76. Checker<FakeQuantBackward> checker(handle_cuda());
  77. checker.set_param(param)
  78. .set_dtype(0, dtype)
  79. .set_dtype(1, dtype)
  80. .set_dtype(2, dtype)
  81. .set_dtype(3, dtype)
  82. .set_dtype(4, dtype)
  83. .execs(TensorShapeArray{
  84. ishape, ishape, scale_shape, zeropoint_shape, ishape});
  85. checker.set_allow_invalid_check(true);
  86. checker.set_rng(0, &rng1);
  87. checker.set_param(param)
  88. .set_dtype(0, dtype)
  89. .set_dtype(1, dtype)
  90. .set_dtype(2, dtype)
  91. .set_dtype(3, dtype)
  92. .set_dtype(4, dtype)
  93. .execs(TensorShapeArray{
  94. ishape, ishape, scale_shape, zeropoint_shape, ishape});
  95. checker.set_rng(0, &rng);
  96. checker.set_allow_invalid_check(false);
  97. }
  98. // test noncontiguous layout
  99. for (auto&& arg : args) {
  100. auto param = arg.param;
  101. auto ishape = arg.ishape;
  102. auto scale_shape = arg.scale_shape;
  103. auto zeropoint_shape = arg.zeropoint_shape;
  104. Checker<FakeQuantBackward> checker(handle_cuda());
  105. TensorLayout ilayout(
  106. ishape,
  107. {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
  108. (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
  109. dtype::Float32());
  110. checker.set_param(param).execl(
  111. {ilayout,
  112. ilayout,
  113. {scale_shape, dtype::Float32()},
  114. {zeropoint_shape, dtype::Float32()},
  115. ilayout});
  116. checker.set_allow_invalid_check(true);
  117. checker.set_rng(0, &rng1);
  118. checker.set_param(param).execl(
  119. {ilayout,
  120. ilayout,
  121. {scale_shape, dtype::Float32()},
  122. {zeropoint_shape, dtype::Float32()},
  123. ilayout});
  124. checker.set_rng(0, &rng);
  125. checker.set_allow_invalid_check(false);
  126. }
  127. }
  128. } // namespace test
  129. } // namespace megdnn