You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.h 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. /**
  2. * \file dnn/test/common/rnn.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <vector>
  13. #include "megdnn/basic_types.h"
  14. #include "megdnn/opr_param_defs.h"
  15. namespace megdnn {
  16. namespace test {
  17. namespace rnn {
  18. struct TestArg {
  19. param::RNN param;
  20. TensorShape input, hx, flatten_weights;
  21. TestArg(param::RNN param, TensorShape input, TensorShape hx,
  22. TensorShape flatten_weights)
  23. : param(param), input(input), hx(hx), flatten_weights(flatten_weights) {}
  24. };
  25. inline std::vector<TestArg> get_args() {
  26. std::vector<TestArg> args;
  27. size_t batch_size = 2;
  28. size_t input_size = 3;
  29. size_t hidden_size = 2;
  30. size_t seq_len = 2;
  31. size_t gate_hidden_size = hidden_size;
  32. param::RNN param;
  33. param.num_layers = 1;
  34. param.bidirectional = false;
  35. param.bias = false;
  36. param.hidden_size = hidden_size;
  37. param.nonlineMode = param::RNN::NonlineMode::RELU;
  38. args.emplace_back(
  39. param, TensorShape{seq_len, batch_size, input_size},
  40. TensorShape{batch_size, hidden_size},
  41. TensorShape{gate_hidden_size, input_size + hidden_size});
  42. return args;
  43. }
  44. } // namespace rnn
  45. } // namespace test
  46. } // namespace megdnn