You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

rnn.cpp 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. /**
  2. * \file dnn/test/naive/rnn.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/rnn.h"
  12. #include "megdnn/dtype.h"
  13. #include "megdnn/oprs.h"
  14. #include "test/common/checker.h"
  15. #include "test/naive/fixture.h"
  16. namespace megdnn {
  17. namespace test {
  18. /*TEST_F(NAIVE, RNN) {
  19. std::vector<rnn::TestArg> args = rnn::get_args();
  20. Checker<RNN> checker(handle());
  21. for (auto&& arg : args) {
  22. checker.set_param(arg.param)
  23. .set_dtype(0, dtype::Float32())
  24. .set_dtype(1, dtype::Float32())
  25. .set_dtype(2, dtype::Float32())
  26. .set_dtype(3, dtype::Float32())
  27. .set_dtype(4, dtype::Float32())
  28. .set_dtype(5, dtype::Float32())
  29. .execs({arg.input, arg.hx, arg.flatten_weights, {}, {}, {}});
  30. }
  31. }*/
  32. TEST_F(NAIVE, RNN_HAND_MADE) {
  33. Checker<RNN> checker(handle(), false);
  34. size_t batch_size = 2;
  35. size_t input_size = 3;
  36. size_t hidden_size = 2;
  37. size_t seq_len = 2;
  38. size_t gate_hidden_size = hidden_size;
  39. RNN::Param param;
  40. param.num_layers = 1;
  41. param.bidirectional = false;
  42. param.bias = false;
  43. param.hidden_size = hidden_size;
  44. param.nonlineMode = param::RNN::NonlineMode::RELU;
  45. checker.set_param(param).exect(
  46. Testcase{
  47. TensorValue(
  48. {seq_len, batch_size, input_size}, dtype::Float32(),
  49. {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
  50. TensorValue(
  51. {batch_size, hidden_size}, dtype::Float32(),
  52. {2, 1, 3, 5}), // hx
  53. TensorValue(
  54. {gate_hidden_size, input_size + hidden_size},
  55. dtype::Float32(),
  56. {3, 6, 1, 3, 2, 7, 9, 3, 5, 1}), // weights
  57. {},
  58. {},
  59. {}},
  60. Testcase{
  61. {},
  62. {},
  63. {},
  64. TensorValue(
  65. {seq_len, batch_size, hidden_size}, dtype::Float32(),
  66. {39, 39, 90, 84, 300, 216, 546, 366}), // output
  67. TensorValue(
  68. {batch_size, hidden_size}, dtype::Float32(),
  69. {21, 11, 42, 20}), // hy
  70. TensorValue(
  71. {1, 2, 2, 2}, dtype::Float32(),
  72. {2, 1, 3, 5, 21, 11, 42, 20}) // reserve space
  73. });
  74. }
  75. } // namespace test
  76. } // namespace megdnn