You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.cpp 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. /**
  2. * \file dnn/test/naive/rnn.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/dtype.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/naive/fixture.h"
  15. namespace megdnn {
  16. namespace test {
  17. TEST_F(NAIVE, RNN_FORWARD) {
  18. Checker<RNN> checker(handle(), false);
  19. size_t batch_size = 2;
  20. size_t input_size = 3;
  21. size_t hidden_size = 2;
  22. size_t seq_len = 2;
  23. size_t gate_hidden_size = hidden_size;
  24. RNN::Param param;
  25. param.num_layers = 1;
  26. param.bidirectional = false;
  27. param.bias = false;
  28. param.hidden_size = hidden_size;
  29. param.nonlineMode = param::RNN::NonlineMode::RELU;
  30. checker.set_param(param).exect(
  31. Testcase{
  32. TensorValue(
  33. {seq_len, batch_size, input_size}, dtype::Float32(),
  34. {-0.66536, 0.08049, 0.12008, 0.63423, 1.37801, 0.02591,
  35. 0.09153, 0.82866, -1.70429, -1.26624, -0.06421,
  36. 0.35816}), // input
  37. TensorValue(
  38. {batch_size, hidden_size}, dtype::Float32(),
  39. {-3.19544, -1.24232, 1.99512, -0.25692}), // hx
  40. TensorValue(
  41. {gate_hidden_size, input_size + hidden_size},
  42. dtype::Float32(),
  43. {0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
  44. 0.35355, 0.35355, 0.35355, 0.35355}), // flattern weights
  45. {},
  46. {},
  47. {}},
  48. Testcase{
  49. {},
  50. {},
  51. {},
  52. TensorValue(
  53. {seq_len, batch_size, hidden_size}, dtype::Float32(),
  54. {0.0, 0.0, 1.3351, 1.3351, 0.0, 0.0, 0.6003,
  55. 0.6003}), // output
  56. TensorValue(
  57. {batch_size, hidden_size}, dtype::Float32(),
  58. {0.0, 0.0, 0.6003, 0.6003}), // hy
  59. TensorValue(
  60. {1, 2, 2, 2}, dtype::Float32(),
  61. {0.0, 0.0, 1.33512, 1.33512, 0.0, 0.0, 0.60031,
  62. 0.60031}) // reserve space
  63. });
  64. param.num_layers = 2;
  65. checker.set_param(param).exect(
  66. Testcase{
  67. TensorValue(
  68. {seq_len, batch_size, input_size}, dtype::Float32(),
  69. {-0.66536, 0.08049, 0.12008, 0.63423, 1.37801, 0.02591,
  70. 0.09153, 0.82866, -1.70429, -1.26624, -0.06421,
  71. 0.35816}), // input
  72. TensorValue(
  73. {2, batch_size, hidden_size}, dtype::Float32(),
  74. {-3.19544, -1.24232, 1.99512, -0.25692, -3.19544, -1.24232,
  75. 1.99512, -0.25692}), // hx
  76. TensorValue(
  77. {2, 9}, dtype::Float32(),
  78. {0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
  79. 0.35355, 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
  80. 0.35355, 0.35355, 0.35355, 0.35355, 0.35355,
  81. 0.35355}), // weights
  82. {},
  83. {},
  84. {}},
  85. Testcase{
  86. {},
  87. {},
  88. {},
  89. TensorValue(
  90. {seq_len, batch_size, hidden_size}, dtype::Float32(),
  91. {0.0, 0.0, 1.5586, 1.5586, 0.0, 0.0, 1.5266,
  92. 1.5266}), // output
  93. TensorValue(
  94. {2, batch_size, hidden_size}, dtype::Float32(),
  95. {0.0, 0.0, 0.6003, 0.6003, 0.0, 0.0, 1.5266,
  96. 1.5266}), // hy
  97. TensorValue(
  98. {2, 2, 2, 2}, dtype::Float32(),
  99. {0.0, 0.0, 1.33512, 1.33512, 0.0, 0.0, 0.60031, 0.60031,
  100. 0.0, 0.0, 1.55861, 1.55861, 0.0, 0.0, 1.52658,
  101. 1.52658}) // reserve space
  102. });
  103. }
  104. } // namespace test
  105. } // namespace megdnn