You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tile_repeat.h 1.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. #pragma once
  2. #include "megdnn/oprs.h"
  3. namespace megdnn {
  4. namespace test {
  5. namespace tile_repeat {
  6. struct Arg {
  7. TensorShape times, src, dst;
  8. Arg(TensorShape times, TensorShape src) : times(times), src(src) {
  9. dst = src;
  10. for (size_t i = 0; i < src.ndim; ++i) {
  11. dst[i] *= times[i];
  12. }
  13. }
  14. TileRepeatBase::Param param() {
  15. TileRepeatBase::Param param;
  16. param.times = times;
  17. return param;
  18. }
  19. };
  20. inline std::vector<Arg> get_args() {
  21. std::vector<Arg> args;
  22. args.emplace_back(TensorShape{3}, TensorShape{10000});
  23. args.emplace_back(TensorShape{1, 1}, TensorShape{200, 300});
  24. args.emplace_back(TensorShape{1, 3}, TensorShape{200, 300});
  25. args.emplace_back(TensorShape{2, 1}, TensorShape{200, 300});
  26. args.emplace_back(TensorShape{2, 3}, TensorShape{200, 300});
  27. for (unsigned mask = 0; mask < 32; ++mask) {
  28. auto b = [mask](unsigned bit) { return (mask >> bit) & 1; };
  29. args.emplace_back(
  30. TensorShape{b(0) + 1, b(1) + 1, b(2) + 1, b(3) + 1, b(4) + 1},
  31. TensorShape{3, 4, 5, 6, 7});
  32. }
  33. for (size_t i = 1; i < 10; ++i)
  34. for (size_t j = 1; j < 10; ++j) {
  35. args.emplace_back(TensorShape{i, j}, TensorShape{3, 4});
  36. }
  37. return args;
  38. }
  39. } // namespace tile_repeat
  40. } // namespace test
  41. } // namespace megdnn
  42. // vim: syntax=cpp.doxygen