You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

cumsum.cpp 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, CUMSUM) {
  7. Checker<Cumsum> checker(handle_cuda());
  8. struct TestArg {
  9. param::Cumsum param;
  10. TensorShape shape;
  11. TestArg(param::Cumsum param, TensorShape shape) : param(param), shape(shape) {}
  12. };
  13. std::vector<TestArg> args, args_int32;
  14. for (auto shape :
  15. TensorShapeArray{{10000}, {33000, 33}, {100, 100, 100}, {30, 30, 30, 30}}) {
  16. for (size_t axis = 0; axis < shape.ndim; ++axis) {
  17. args.emplace_back(param::Cumsum(axis, true, true), shape);
  18. args.emplace_back(param::Cumsum(axis, true, false), shape);
  19. args.emplace_back(param::Cumsum(axis, false, true), shape);
  20. args.emplace_back(param::Cumsum(axis, false, false), shape);
  21. }
  22. }
  23. for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}, {100000}}) {
  24. args.emplace_back(param::Cumsum(0, true, true), shape);
  25. args.emplace_back(param::Cumsum(0, true, false), shape);
  26. args.emplace_back(param::Cumsum(0, false, true), shape);
  27. args.emplace_back(param::Cumsum(0, false, false), shape);
  28. }
  29. for (auto shape : TensorShapeArray{
  30. {1},
  31. {10},
  32. {100},
  33. {1000},
  34. {10000},
  35. {100000},
  36. {1000000},
  37. {1050000},
  38. {2100000}}) {
  39. args_int32.emplace_back(param::Cumsum(0, true, true), shape);
  40. args_int32.emplace_back(param::Cumsum(0, true, false), shape);
  41. args_int32.emplace_back(param::Cumsum(0, false, true), shape);
  42. args_int32.emplace_back(param::Cumsum(0, false, false), shape);
  43. }
  44. for (auto arg : args) {
  45. checker.set_param(arg.param);
  46. checker.set_epsilon(1e-2);
  47. checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}});
  48. checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}});
  49. checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}});
  50. }
  51. for (auto arg : args_int32) {
  52. checker.set_param(arg.param);
  53. checker.set_epsilon(1e-2);
  54. checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}});
  55. }
  56. }
  57. } // namespace test
  58. } // namespace megdnn
  59. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}