You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.cpp 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #include "test/arm_common/fixture.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/benchmarker.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/task_record_check.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(ARM_COMMON, RNNCell) {
  9. Checker<RNNCell> checker(handle());
  10. using NonlineMode = param::RNNCell::NonlineMode;
  11. param::RNNCell param;
  12. for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH})
  13. for (size_t batch : {1, 4})
  14. for (size_t n : {3, 4, 5, 23, 100})
  15. for (size_t h : {5, 23, 100})
  16. for (size_t out : {3, 6, 25, 100}) {
  17. param.nonlineMode = mode;
  18. checker.set_param(param);
  19. checker.exec(
  20. {{batch, n},
  21. {out, n},
  22. {1, out},
  23. {batch, h},
  24. {out, h},
  25. {1, out},
  26. {}});
  27. checker.exec(
  28. {{batch, n},
  29. {out, n},
  30. {batch, out},
  31. {batch, h},
  32. {out, h},
  33. {batch, out},
  34. {}});
  35. }
  36. }
  37. TEST_F(ARM_COMMON, RNNCellRecord) {
  38. TaskRecordChecker<RNNCell> checker(0);
  39. using NonlineMode = param::RNNCell::NonlineMode;
  40. param::RNNCell param;
  41. for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) {
  42. param.nonlineMode = mode;
  43. checker.set_param(param);
  44. checker.exec({{1, 100}, {10, 100}, {1, 10}, {1, 100}, {10, 100}, {1, 10}, {}});
  45. checker.exec({{1, 34}, {15, 34}, {1, 15}, {1, 34}, {15, 34}, {1, 15}, {}});
  46. checker.exec({{1, 73}, {25, 73}, {1, 25}, {1, 73}, {25, 73}, {1, 25}, {}});
  47. }
  48. }
#if MEGDNN_WITH_BENCHMARK
// NOTE(review): this MEGDNN_WITH_BENCHMARK section is empty — presumably a
// placeholder for RNNCell benchmark tests; confirm whether it should be
// populated or removed.
#endif
  51. // vim: syntax=cpp.doxygen