You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536
  #pragma once

  #include <vector>

  #include "megdnn/basic_types.h"
  #include "megdnn/opr_param_defs.h"
  4. namespace megdnn {
  5. namespace test {
  6. namespace lamb {
  7. struct TestArg {
  8. param::LAMBUpdate param;
  9. TensorShape src;
  10. TestArg(param::LAMBUpdate param, TensorShape src) : param(param), src(src) {}
  11. };
  12. inline std::vector<TestArg> get_args() {
  13. std::vector<TestArg> args;
  14. param::LAMBUpdate cur_param;
  15. cur_param.beta_1 = 0.9;
  16. cur_param.beta_2 = 0.999;
  17. cur_param.eps = 1e-8;
  18. cur_param.weight_decay = 0;
  19. cur_param.lr = 6.25e-5;
  20. cur_param.bias_correction = true;
  21. cur_param.always_adapt = false;
  22. args.emplace_back(
  23. cur_param, TensorShape{
  24. 1280,
  25. });
  26. args.emplace_back(cur_param, TensorShape{1280, 1280});
  27. args.emplace_back(cur_param, TensorShape{1280, 3, 224, 224});
  28. return args;
  29. }
  30. } // namespace lamb
  31. } // namespace test
  32. } // namespace megdnn