You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn_cell.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. /**
  2. * \file dnn/test/naive/rnncell.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/dtype.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/naive/fixture.h"
  15. namespace megdnn {
  16. namespace test {
  17. TEST_F(NAIVE, RNNCELL) {
  18. Checker<RNNCell> checker(handle(), false);
  19. for (size_t batch : {1, 4})
  20. for (size_t inp : {3, 4, 5, 23, 100})
  21. for (size_t hidden : {3, 6, 25, 100}) {
  22. checker.exec(
  23. {{batch, inp},
  24. {hidden, inp},
  25. {1, hidden},
  26. {batch, hidden},
  27. {hidden, hidden},
  28. {1, hidden},
  29. {}});
  30. }
  31. size_t batch_size = 2;
  32. size_t input_size = 3;
  33. size_t hidden_size = 2;
  34. RNNCell::Param param;
  35. param.nonlineMode = param::RNNCell::NonlineMode::TANH;
  36. checker.set_param(param).exect(
  37. Testcase{
  38. TensorValue(
  39. {batch_size, input_size}, dtype::Float32(),
  40. {1, 2, 3, 4, 5, 6}), // input
  41. TensorValue(
  42. {hidden_size, input_size}, dtype::Float32(),
  43. {0.3535, 0.3535, 0.3535, 0.3535, 0.3535,
  44. 0.3535}), // weight_ih
  45. TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}), // bias_ih
  46. TensorValue(
  47. {batch_size, hidden_size}, dtype::Float32(),
  48. {1, 2, 3, 4}), // hx
  49. TensorValue(
  50. {hidden_size, hidden_size}, dtype::Float32(),
  51. {0.3535, 0.3535, 0.3535, 0.3535}), // weight_hh
  52. TensorValue({1, hidden_size}, dtype::Float32(), {0, 0}), // bias_hh
  53. {}},
  54. Testcase{
  55. {},
  56. {},
  57. {},
  58. {},
  59. {},
  60. {},
  61. TensorValue(
  62. {batch_size, hidden_size}, dtype::Float32(),
  63. {0.9966, 0.9966, 1.0, 1.0}), // dst
  64. });
  65. batch_size = 2;
  66. input_size = 2;
  67. hidden_size = 1;
  68. param.nonlineMode = param::RNNCell::NonlineMode::RELU;
  69. checker.set_param(param).exect(
  70. Testcase{
  71. TensorValue(
  72. {batch_size, input_size}, dtype::Float32(),
  73. {1, 2, 3, 4}), // input
  74. TensorValue(
  75. {hidden_size, input_size}, dtype::Float32(),
  76. {0.3535, 0.3535}), // weight_ih
  77. TensorValue(
  78. {1, hidden_size}, dtype::Float32(), {0.3535}), // bias_ih
  79. TensorValue(
  80. {batch_size, hidden_size}, dtype::Float32(),
  81. {-1, -2}), // hx
  82. TensorValue(
  83. {hidden_size, hidden_size}, dtype::Float32(),
  84. {0.3535}), // weight_hh
  85. TensorValue(
  86. {1, hidden_size}, dtype::Float32(), {0.3535}), // bias_hh
  87. {}},
  88. Testcase{
  89. {},
  90. {},
  91. {},
  92. {},
  93. {},
  94. {},
  95. TensorValue(
  96. {batch_size, hidden_size}, dtype::Float32(),
  97. {1.414, 2.4745}), // hy
  98. });
  99. }
  100. } // namespace test
  101. } // namespace megdnn