You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn_cell.cpp 3.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #include "megdnn/dtype.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/naive/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(NAIVE, RNNCELL) {
  8. Checker<RNNCell> checker(handle(), false);
  9. for (size_t batch : {1, 4})
  10. for (size_t inp : {3, 4, 5, 23, 100})
  11. for (size_t hidden : {3, 6, 25, 100}) {
  12. checker.exec(
  13. {{batch, inp},
  14. {hidden, inp},
  15. {1, hidden},
  16. {batch, hidden},
  17. {hidden, hidden},
  18. {1, hidden},
  19. {}});
  20. }
  21. size_t batch_size = 2;
  22. size_t input_size = 3;
  23. size_t hidden_size = 2;
  24. RNNCell::Param param;
  25. param.nonlineMode = param::RNNCell::NonlineMode::TANH;
  26. checker.set_param(param).exect(
  27. Testcase{
  28. TensorValue(
  29. {batch_size, input_size}, dtype::Float32(),
  30. {1, 2, 3, 4, 5, 6}), // input
  31. TensorValue(
  32. {hidden_size, input_size}, dtype::Float32(),
  33. {0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  34. 0.3535}), // weight_ih
  35. TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}), // bias_ih
  36. TensorValue(
  37. {batch_size, hidden_size}, dtype::Float32(),
  38. {1, 2, 3, 4}), // hx
  39. TensorValue(
  40. {hidden_size, hidden_size}, dtype::Float32(),
  41. {0.3535, 0.3535, 0.3535, 0.3535}), // weight_hh
  42. TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}), // bias_hh
  43. {}},
  44. Testcase{
  45. {},
  46. {},
  47. {},
  48. {},
  49. {},
  50. {},
  51. TensorValue(
  52. {batch_size, hidden_size}, dtype::Float32(),
  53. {0.9966, 0.9966, 1.0, 1.0}), // dst
  54. });
  55. batch_size = 2;
  56. input_size = 2;
  57. hidden_size = 1;
  58. param.nonlineMode = param::RNNCell::NonlineMode::RELU;
  59. checker.set_param(param).exect(
  60. Testcase{
  61. TensorValue(
  62. {batch_size, input_size}, dtype::Float32(),
  63. {1, 2, 3, 4}), // input
  64. TensorValue(
  65. {hidden_size, input_size}, dtype::Float32(),
  66. {0.3535, 0.3535}), // weight_ih
  67. TensorValue(
  68. {1, hidden_size}, dtype::Float32(), {0.3535}), // bias_ih
  69. TensorValue(
  70. {batch_size, hidden_size}, dtype::Float32(),
  71. {-1, -2}), // hx
  72. TensorValue(
  73. {hidden_size, hidden_size}, dtype::Float32(),
  74. {0.3535}), // weight_hh
  75. TensorValue(
  76. {1, hidden_size}, dtype::Float32(), {0.3535}), // bias_hh
  77. {}},
  78. Testcase{
  79. {},
  80. {},
  81. {},
  82. {},
  83. {},
  84. {},
  85. TensorValue(
  86. {batch_size, hidden_size}, dtype::Float32(),
  87. {1.414, 2.4745}), // hy
  88. });
  89. }
  90. } // namespace test
  91. } // namespace megdnn