You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

accuracy_shake.cpp 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #include "test/x86/fixture.h"
  2. #include "megdnn/opr_param_defs.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/accuracy_shake_checker.h"
  5. #include "test/common/convolution.h"
  6. #include "test/common/rng.h"
  7. #include "test/common/tensor.h"
  8. #include "test/common/workspace_wrapper.h"
  9. namespace megdnn {
  10. namespace test {
  11. TEST_F(X86, SHAKE_CONV_BIAS_FORWARD) {
  12. AccuracyShakeChecker<ConvBiasForward> checker(handle());
  13. NormalRNG default_rng;
  14. checker.set_dtype(0, dtype::Float32())
  15. .set_dtype(1, dtype::Float32())
  16. .set_dtype(2, dtype::Float32())
  17. .set_rng(0, &default_rng)
  18. .set_rng(1, &default_rng);
  19. checker.set_before_exec_callback(AlgoGenerator<ConvBiasForward>("X86"));
  20. // convolution
  21. checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {}, {}, {}});
  22. // convbias without z
  23. checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {}, {}});
  24. // convbias with z
  25. checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {6, 64, 30, 30}, {}});
  26. // group
  27. ConvBias::Param param;
  28. param.sparse = ConvBias::Param::Sparse::GROUP;
  29. checker.set_param(param);
  30. checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {}, {}, {}});
  31. checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {}, {}});
  32. checker.exec(
  33. {{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {6, 64, 30, 30}, {}});
  34. }
  35. TEST_F(X86, SHAKE_CONV_BIAS_FORWARD_INT8) {
  36. AccuracyShakeChecker<ConvBiasForward> checker(handle());
  37. UniformIntRNG rng{-50, 50};
  38. checker.set_dtype(0, dtype::QuantizedS8(2.5f))
  39. .set_dtype(1, dtype::QuantizedS8(2.5f))
  40. .set_dtype(2, dtype::QuantizedS32(6.25f))
  41. .set_dtype(3, dtype::QuantizedS32(6.25f))
  42. .set_dtype(4, {})
  43. .set_rng(0, &rng)
  44. .set_rng(1, &rng)
  45. .set_rng(2, &rng);
  46. checker.set_before_exec_callback(AlgoGenerator<ConvBiasForward>("X86"));
  47. // convolution
  48. checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {}, {}, {}});
  49. // convbias without z
  50. checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {}, {}});
  51. // convbias with z
  52. checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {6, 64, 30, 30}, {}});
  53. // group
  54. ConvBias::Param param;
  55. param.sparse = ConvBias::Param::Sparse::GROUP;
  56. checker.set_param(param);
  57. checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {}, {}, {}});
  58. checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {}, {}});
  59. checker.exec(
  60. {{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {6, 64, 30, 30}, {}});
  61. }
  62. TEST_F(X86, SHAKE_MATRIX_MUL_FORWARD) {
  63. AccuracyShakeChecker<MatrixMul> checker(handle());
  64. checker.set_dtype(0, dtype::Float32())
  65. .set_dtype(1, dtype::Float32())
  66. .set_dtype(2, dtype::Float32())
  67. .exec({{20, 100}, {100, 60}, {}});
  68. }
  69. TEST_F(X86, SHAKE_MATRIX_MUL_6x16_FORWARD) {
  70. AccuracyShakeChecker<MatrixMul> checker(handle());
  71. checker.set_before_exec_callback(AlgoGenerator<MatrixMul>("X86_F32_6x16"));
  72. checker.set_dtype(0, dtype::Float32())
  73. .set_dtype(1, dtype::Float32())
  74. .set_dtype(2, dtype::Float32())
  75. .exec({{20, 100}, {100, 60}, {}});
  76. }
  77. } // namespace test
  78. } // namespace megdnn
  79. // vim: syntax=cpp.doxygen