You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

lstmcell.cpp 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #include "megdnn/dtype.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/naive/fixture.h"
  5. namespace megdnn {
  6. namespace test {
  7. TEST_F(NAIVE, LSTMCELL) {
  8. Checker<LSTMCell> checker(handle(), true);
  9. for (size_t batch : {1, 4})
  10. for (size_t n : {3, 4, 5, 23, 100})
  11. for (size_t out : {3, 6, 25, 100}) {
  12. checker.exec(
  13. {{batch, n},
  14. {out * 4, n},
  15. {1, out * 4},
  16. {batch, out},
  17. {out * 4, out},
  18. {1, out * 4},
  19. {batch, out},
  20. {},
  21. {},
  22. {}});
  23. }
  24. size_t batch_size = 2;
  25. size_t input_size = 3;
  26. size_t hidden_size = 2;
  27. checker.exect(
  28. Testcase{
  29. TensorValue(
  30. {batch_size, input_size}, dtype::Float32(),
  31. {1, 2, 3, 4, 5, 6}), // input
  32. TensorValue(
  33. {4 * hidden_size, input_size}, dtype::Float32(),
  34. {
  35. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  36. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  37. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  38. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  39. }), // weight_ih
  40. TensorValue(
  41. {4 * hidden_size}, dtype::Float32(),
  42. {0, 0, 0, 0, 0, 0, 0, 0}), // bias_ih
  43. TensorValue(
  44. {batch_size, hidden_size}, dtype::Float32(),
  45. {1, 2, 3, 4}), // hx
  46. TensorValue(
  47. {4 * hidden_size, hidden_size}, dtype::Float32(),
  48. {0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  49. 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  50. 0.3535, 0.3535}), // weight_hh
  51. TensorValue(
  52. {4 * hidden_size}, dtype::Float32(),
  53. {0, 0, 0, 0, 0, 0, 0, 0}), // bias_hh
  54. TensorValue(
  55. {batch_size, hidden_size}, dtype::Float32(),
  56. {2, 3, 4, 5}), // cx
  57. {},
  58. {},
  59. {}},
  60. Testcase{
  61. {},
  62. {},
  63. {},
  64. {},
  65. {},
  66. {},
  67. {},
  68. TensorValue(
  69. {batch_size, hidden_size}, dtype::Float32(),
  70. {0.9541, 0.9593, 0.9995, 0.9996}), // hy
  71. TensorValue(
  72. {batch_size, hidden_size}, dtype::Float32(),
  73. {2.8771, 3.8373, 4.9979, 5.9975}), // cy
  74. TensorValue(
  75. {batch_size, 4 * hidden_size}, dtype::Float32(),
  76. {3.18198, 3.18198, 7.7781, 7.7781, 3.18198, 3.18198,
  77. 7.77817, 7.77817, 3.18198, 3.18198, 7.77817, 7.77817,
  78. 3.18198, 3.18198, 7.77817, 7.77817}), // cy
  79. });
  80. batch_size = 2;
  81. input_size = 2;
  82. hidden_size = 1;
  83. checker.exect(
  84. Testcase{
  85. TensorValue(
  86. {batch_size, input_size}, dtype::Float32(),
  87. {1, 2, 3, 4}), // input
  88. TensorValue(
  89. {4 * hidden_size, input_size}, dtype::Float32(),
  90. {0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  91. 0.3535}), // weight_ih
  92. TensorValue(
  93. {4 * hidden_size}, dtype::Float32(),
  94. {0.3535, 0.3535, 0.3535, 0.3535}), // bias_ih
  95. TensorValue(
  96. {batch_size, hidden_size}, dtype::Float32(), {1, 2}), // hx
  97. TensorValue(
  98. {4 * hidden_size, hidden_size}, dtype::Float32(),
  99. {0.3535, 0.3535, 0.3535, 0.3535}), // weight_hh
  100. TensorValue(
  101. {4 * hidden_size}, dtype::Float32(),
  102. {0.3535, 0.3535, 0.3535, 0.3535}), // bias_hh
  103. TensorValue(
  104. {batch_size, hidden_size}, dtype::Float32(), {4, 5}), // cx
  105. {},
  106. {},
  107. {}},
  108. Testcase{
  109. {},
  110. {},
  111. {},
  112. {},
  113. {},
  114. {},
  115. {},
  116. TensorValue(
  117. {batch_size, hidden_size}, dtype::Float32(),
  118. {0.8927, 0.9799}), // hy
  119. TensorValue(
  120. {batch_size, hidden_size}, dtype::Float32(),
  121. {4.4393, 5.8788}), // cy
  122. TensorValue(
  123. {batch_size, 4 * hidden_size}, dtype::Float32(),
  124. {2.1210, 3.8885, 2.1210, 3.8885, 2.1210, 3.8885, 2.1210,
  125. 3.8885}), // gates
  126. });
  127. }
  128. } // namespace test
  129. } // namespace megdnn