You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lamb.cpp 1.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #include "test/cuda/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/rng.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CUDA, LAMBUpdate) {
  7. LAMBUpdate::Param param;
  8. param.beta_1 = 0.9;
  9. param.beta_2 = 0.999;
  10. param.eps = 1e-5;
  11. param.weight_decay = 0.4;
  12. param.lr = 1e-3;
  13. param.step = 1;
  14. param.bias_correction = true;
  15. param.always_adapt = false;
  16. Checker<LAMBUpdate> checker(handle_cuda());
  17. checker.set_epsilon(1e-3);
  18. UniformFloatRNG rng0(0, 1);
  19. auto run = [&](DType d) {
  20. checker.set_param(param)
  21. .set_rng(0, &rng0)
  22. .set_rng(1, &rng0)
  23. .set_dtype(0, dtype::Float32())
  24. .set_dtype(1, dtype::Float32())
  25. .set_dtype(2, dtype::Float32())
  26. .set_dtype(3, d)
  27. .set_dtype(4, dtype::Float32())
  28. .set_dtype(5, dtype::Float32())
  29. .set_dtype(6, dtype::Float32())
  30. .execs({{2}, {2}, {2}, {2}, {}, {}, {}});
  31. };
  32. run(dtype::Float32());
  33. run(dtype::Float16());
  34. run(dtype::BFloat16());
  35. }
  36. } // namespace test
  37. } // namespace megdnn