You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

rnn.cpp 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. /**
  2. * \file dnn/test/arm_common/rnn.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/arm_common/fixture.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/benchmarker.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/task_record_check.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. TEST_F(ARM_COMMON, RNNCell) {
  19. Checker<RNNCell> checker(handle());
  20. using NonlineMode = param::RNNCell::NonlineMode;
  21. param::RNNCell param;
  22. for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH})
  23. for (size_t batch : {1, 4})
  24. for (size_t n : {3, 4, 5, 23, 100})
  25. for (size_t h : {5, 23, 100})
  26. for (size_t out : {3, 6, 25, 100}) {
  27. param.nonlineMode = mode;
  28. checker.set_param(param);
  29. checker.exec(
  30. {{batch, n},
  31. {out, n},
  32. {1, out},
  33. {batch, h},
  34. {out, h},
  35. {1, out},
  36. {}});
  37. checker.exec(
  38. {{batch, n},
  39. {out, n},
  40. {batch, out},
  41. {batch, h},
  42. {out, h},
  43. {batch, out},
  44. {}});
  45. }
  46. }
  47. TEST_F(ARM_COMMON, RNNCellRecord) {
  48. TaskRecordChecker<RNNCell> checker(0);
  49. using NonlineMode = param::RNNCell::NonlineMode;
  50. param::RNNCell param;
  51. for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) {
  52. param.nonlineMode = mode;
  53. checker.set_param(param);
  54. checker.exec({{1, 100}, {10, 100}, {1, 10}, {1, 100}, {10, 100}, {1, 10}, {}});
  55. checker.exec({{1, 34}, {15, 34}, {1, 15}, {1, 34}, {15, 34}, {1, 15}, {}});
  56. checker.exec({{1, 73}, {25, 73}, {1, 25}, {1, 73}, {25, 73}, {1, 25}, {}});
  57. }
  58. }
  59. #if MEGDNN_WITH_BENCHMARK
  60. #endif
  61. // vim: syntax=cpp.doxygen