You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tqt.cpp 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #include "test/common/tqt.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/cuda/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. using namespace tqt;
  8. TEST_F(CUDA, TQT) {
  9. std::vector<TestArg> args = get_args();
  10. auto dtype = dtype::Float32();
  11. for (auto&& arg : args) {
  12. auto param = arg.param;
  13. auto ishape = arg.ishape;
  14. auto scale_shape = arg.scale_shape;
  15. Checker<TQTForward> checker(handle_cuda());
  16. checker.set_param(param)
  17. .set_dtype(0, dtype)
  18. .set_dtype(1, dtype)
  19. .set_dtype(2, dtype)
  20. .execs({ishape, scale_shape, ishape});
  21. }
  22. // test noncontiguous layout
  23. for (auto&& arg : args) {
  24. auto param = arg.param;
  25. auto ishape = arg.ishape;
  26. auto sshape = arg.scale_shape;
  27. Checker<TQTForward> checker(handle_cuda());
  28. TensorLayout ilayout(
  29. ishape,
  30. {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
  31. (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
  32. dtype::Float32());
  33. checker.set_param(param).execl({ilayout, {sshape, dtype::Float32()}, ilayout});
  34. }
  35. }
  36. TEST_F(CUDA, TQT_BACKWARD) {
  37. std::vector<TestArg> args = get_args();
  38. auto dtype = dtype::Float32();
  39. for (auto&& arg : args) {
  40. auto param = arg.param;
  41. auto ishape = arg.ishape;
  42. auto scale_shape = arg.scale_shape;
  43. Checker<TQTBackward> checker(handle_cuda());
  44. checker.set_param(param)
  45. .set_dtype(0, dtype)
  46. .set_dtype(1, dtype)
  47. .set_dtype(2, dtype)
  48. .set_dtype(3, dtype)
  49. .set_dtype(4, dtype)
  50. .execs({ishape, ishape, scale_shape, ishape, ishape});
  51. }
  52. // test noncontiguous layout
  53. for (auto&& arg : args) {
  54. auto param = arg.param;
  55. auto ishape = arg.ishape;
  56. auto sshape = arg.scale_shape;
  57. Checker<TQTBackward> checker(handle_cuda());
  58. TensorLayout ilayout(
  59. ishape,
  60. {(long int)(ishape[1] * ishape[2] * ishape[3] * 2),
  61. (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1},
  62. dtype::Float32());
  63. checker.set_param(param).execl(
  64. {ilayout, ilayout, {sshape, dtype::Float32()}, ilayout, ilayout});
  65. }
  66. }
  67. } // namespace test
  68. } // namespace megdnn