You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_remap.cpp 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/tensor_remap.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, TENSOR_REMAP_FORWARD) {
  7. Checker<IndexingRemapForward> checker(handle_cuda());
  8. TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
  9. checker.set_dtype(1, dtype::Int32());
  10. for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
  11. checker.set_dtype(0, dt);
  12. checker.set_dtype(2, dt);
  13. using namespace tensor_remap;
  14. {
  15. MapRNG rng(src);
  16. checker.set_rng(1, &rng).execs({src, map, {}});
  17. }
  18. {
  19. NonoverlappingMapRNG rng(src);
  20. checker.set_rng(1, &rng).execs({src, map, {}});
  21. }
  22. }
  23. }
  24. TEST_F(CUDA, TENSOR_REMAP_BACKWARD) {
  25. Checker<IndexingRemapBackward> checker(handle_cuda());
  26. checker.set_dtype(1, dtype::Int32());
  27. TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
  28. checker.set_dtype(1, dtype::Int32());
  29. for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
  30. checker.set_dtype(0, dt);
  31. checker.set_dtype(2, dt);
  32. using namespace tensor_remap;
  33. {
  34. MapRNG rng(src);
  35. checker.set_rng(1, &rng).execs({dst, map, src});
  36. }
  37. {
  38. NonoverlappingMapRNG rng(src);
  39. checker.set_rng(1, &rng).execs({dst, map, src});
  40. }
  41. }
  42. }
  43. } // namespace test
  44. } // namespace megdnn
  45. // vim: syntax=cpp.doxygen