rnn.cpp

#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

// Forward test for the naive RNN kernel: a single-layer, unidirectional ReLU
// RNN without bias, followed by the same inputs run through two stacked layers.
TEST_F(NAIVE, RNN_FORWARD) {
    Checker<RNN> checker(handle(), false);
    size_t batch_size = 2;
    size_t input_size = 3;
    size_t hidden_size = 2;
    size_t seq_len = 2;
    size_t gate_hidden_size = hidden_size;
    RNN::Param param;
    param.num_layers = 1;
    param.bidirectional = false;
    param.bias = false;
    param.hidden_size = hidden_size;
    param.nonlineMode = param::RNN::NonlineMode::RELU;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {seq_len, batch_size, input_size}, dtype::Float32(),
                            {-0.66536, 0.08049, 0.12008, 0.63423, 1.37801, 0.02591,
                             0.09153, 0.82866, -1.70429, -1.26624, -0.06421,
                             0.35816}),  // input
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {-3.19544, -1.24232, 1.99512, -0.25692}),  // hx
                    TensorValue(
                            {gate_hidden_size, input_size + hidden_size},
                            dtype::Float32(),
                            {0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
                             0.35355, 0.35355, 0.35355, 0.35355}),  // flattened weights
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {seq_len, batch_size, hidden_size}, dtype::Float32(),
                            {0.0, 0.0, 1.3351, 1.3351, 0.0, 0.0, 0.6003,
                             0.6003}),  // output
                    TensorValue(
                            {batch_size, hidden_size}, dtype::Float32(),
                            {0.0, 0.0, 0.6003, 0.6003}),  // hy
                    TensorValue(
                            {1, 2, 2, 2}, dtype::Float32(),
                            {0.0, 0.0, 1.33512, 1.33512, 0.0, 0.0, 0.60031,
                             0.60031})  // reserve space
            });

    // Same inputs with two stacked layers: hx and the flattened weights gain a
    // per-layer leading dimension.
    param.num_layers = 2;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {seq_len, batch_size, input_size}, dtype::Float32(),
                            {-0.66536, 0.08049, 0.12008, 0.63423, 1.37801, 0.02591,
                             0.09153, 0.82866, -1.70429, -1.26624, -0.06421,
                             0.35816}),  // input
                    TensorValue(
                            {2, batch_size, hidden_size}, dtype::Float32(),
                            {-3.19544, -1.24232, 1.99512, -0.25692, -3.19544, -1.24232,
                             1.99512, -0.25692}),  // hx
                    TensorValue(
                            {2, 9}, dtype::Float32(),
                            {0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
                             0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
                             0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
                             0.35355}),  // flattened weights (both layers)
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {seq_len, batch_size, hidden_size}, dtype::Float32(),
                            {0.0, 0.0, 1.5586, 1.5586, 0.0, 0.0, 1.5266,
                             1.5266}),  // output
                    TensorValue(
                            {2, batch_size, hidden_size}, dtype::Float32(),
                            {0.0, 0.0, 0.6003, 0.6003, 0.0, 0.0, 1.5266,
                             1.5266}),  // hy
                    TensorValue(
                            {2, 2, 2, 2}, dtype::Float32(),
                            {0.0, 0.0, 1.33512, 1.33512, 0.0, 0.0, 0.60031, 0.60031,
                             0.0, 0.0, 1.55861, 1.55861, 0.0, 0.0, 1.52658,
                             1.52658})  // reserve space
            });
}

}  // namespace test
}  // namespace megdnn
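
For reference, the expected tensors above are consistent with the standard Elman recurrence h_t = ReLU(W_ih * x_t + W_hh * h_{t-1}) with no bias. Because every weight entry is the same constant w = 0.35355 (about 1/sqrt(8)), each hidden unit reduces to w times the sum of the current input row and the previous hidden row. The following is a minimal standalone sketch, not part of the test and independent of the megdnn headers, that reproduces the 1.3351 entry (batch 1 at t = 0) of the single-layer case:

#include <algorithm>
#include <cstdio>

int main() {
    const float w = 0.35355f;  // shared value of every entry in the flattened weights
    // t = 0, batch 1: second row of the {seq_len, batch_size, input_size} input
    float x_sum = 0.63423f + 1.37801f + 0.02591f;
    // corresponding row of hx
    float h_sum = 1.99512f + (-0.25692f);
    // ReLU RNN step; with uniform weights each hidden unit gets the same value
    float h = std::max(0.0f, w * (x_sum + h_sum));
    std::printf("%.5f\n", h);  // ~1.3351, matching the expected output and reserve space
    return 0;
}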