
lsq.cpp

  1. #include "test/common/lsq.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/cuda/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. using namespace lsq;
  8. TEST_F(CUDA, LSQ) {
  9. std::vector<TestArg> args = get_args();
  10. auto dtype = dtype::Float32();
  11. for (auto&& arg : args) {
  12. auto param = arg.param;
  13. auto ishape = arg.ishape;
  14. auto scale_shape = arg.scale_shape;
  15. auto zeropoint_shape = arg.zeropoint_shape;
  16. auto gradscale_shape = arg.gradscale_shape;
  17. Checker<LSQForward> checker(handle_cuda());
  18. checker.set_param(param)
  19. .set_dtype(0, dtype)
  20. .set_dtype(1, dtype)
  21. .set_dtype(2, dtype)
  22. .set_dtype(3, dtype)
  23. .set_dtype(4, dtype)
  24. .execs({ishape, scale_shape, zeropoint_shape, gradscale_shape, ishape});
  25. }
  26. // test noncontiguous layout
  27. for (auto&& arg : args) {
  28. auto param = arg.param;
  29. auto ishape = arg.ishape;
  30. auto sshape = arg.scale_shape;
  31. auto zeropoint_shape = arg.zeropoint_shape;
  32. auto gradscale_shape = arg.gradscale_shape;
  33. Checker<LSQForward> checker(handle_cuda());
  34. TensorLayout ilayout(
  35. ishape,
  36. {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
  37. (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
  38. dtype::Float32());
  39. checker.set_param(param).execl(
  40. {ilayout,
  41. {sshape, dtype::Float32()},
  42. {zeropoint_shape, dtype::Float32()},
  43. {gradscale_shape, dtype::Float32()},
  44. ilayout});
  45. }
  46. }
  47. TEST_F(CUDA, LSQ_BACKWARD) {
  48. std::vector<TestArg> args = get_args();
  49. auto dtype = dtype::Float32();
  50. for (auto&& arg : args) {
  51. auto param = arg.param;
  52. auto ishape = arg.ishape;
  53. auto scale_shape = arg.scale_shape;
  54. auto zeropoint_shape = arg.zeropoint_shape;
  55. auto gradscale_shape = arg.gradscale_shape;
  56. Checker<LSQBackward> checker(handle_cuda());
  57. checker.set_param(param)
  58. .set_dtype(0, dtype)
  59. .set_dtype(1, dtype)
  60. .set_dtype(2, dtype)
  61. .set_dtype(3, dtype)
  62. .set_dtype(4, dtype)
  63. .set_dtype(5, dtype)
  64. .set_dtype(6, dtype)
  65. .execs({ishape, ishape, scale_shape, zeropoint_shape, gradscale_shape,
  66. ishape, ishape});
  67. }
  68. // test noncontiguous layout
  69. for (auto&& arg : args) {
  70. auto param = arg.param;
  71. auto ishape = arg.ishape;
  72. auto sshape = arg.scale_shape;
  73. auto zeropoint_shape = arg.zeropoint_shape;
  74. auto gradscale_shape = arg.gradscale_shape;
  75. Checker<LSQBackward> checker(handle_cuda());
  76. TensorLayout ilayout(
  77. ishape,
  78. {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
  79. (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
  80. dtype::Float32());
  81. checker.set_param(param).execl(
  82. {ilayout,
  83. ilayout,
  84. {sshape, dtype::Float32()},
  85. {zeropoint_shape, dtype::Float32()},
  86. {gradscale_shape, dtype::Float32()},
  87. ilayout,
  88. ilayout});
  89. }
  90. }
  91. } // namespace test
  92. } // namespace megdnn
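
Note: the tests above rely on test/common/lsq.h providing a TestArg struct that bundles the LSQ (Learned Step Size Quantization) op parameter with the four input shapes, plus a get_args() factory returning 4-D cases (the noncontiguous-layout loops index ishape[1] through ishape[3]). The sketch below is a hypothetical reconstruction of that interface under those assumptions, not the actual header: the qmin/qmax values and the concrete shapes are illustrative only; the field names match those used in the tests.

// Hypothetical sketch of the helper interface assumed by the tests above;
// the real test/common/lsq.h in MegDNN may differ.
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
#include <vector>

namespace megdnn {
namespace test {
namespace lsq {

struct TestArg {
    param::LSQ param;             // quantization range (qmin/qmax) for the op
    TensorShape ishape;           // input/output tensor shape (4-D)
    TensorShape scale_shape;      // learned step-size shape, typically {1}
    TensorShape zeropoint_shape;  // zero-point shape, typically {1}
    TensorShape gradscale_shape;  // gradient-scale shape, typically {1}
};

inline std::vector<TestArg> get_args() {
    std::vector<TestArg> args;
    param::LSQ p;
    p.qmin = -127;  // assumed int8-style quantization range
    p.qmax = 127;
    // 4-D shapes so ishape[1..3], used for the noncontiguous strides, exist
    args.push_back({p, TensorShape{2, 3, 8, 8}, {1}, {1}, {1}});
    args.push_back({p, TensorShape{4, 16, 7, 7}, {1}, {1}, {1}});
    return args;
}

}  // namespace lsq
}  // namespace test
}  // namespace megdnn

The scale, zero-point and gradient-scale shapes are {1} because LSQ learns a single per-tensor step size, which is also why both test cases reuse the same input shape for the output and gradient tensors.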