You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lamb.cpp 1.1 kB

123456789101112131415161718192021222324252627282930313233
  1. #include "test/common/lamb.h"
  2. #include "megdnn/dtype.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/checker.h"
  5. #include "test/naive/fixture.h"
  6. using namespace megdnn;
  7. using namespace test;
  8. TEST_F(NAIVE, LAMBUpdate) {
  9. Checker<LAMBUpdate> checker(handle(), false);
  10. LAMBUpdate::Param param;
  11. param.beta_1 = 0;
  12. param.beta_2 = 0;
  13. param.eps = 0;
  14. param.weight_decay = 0;
  15. param.lr = 1;
  16. param.step = 1;
  17. param.bias_correction = true;
  18. param.always_adapt = false;
  19. TensorND m_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
  20. TensorND v_t_1 = TensorValue({2}, dtype::Float32(), {1, 1});
  21. TensorND param_lamb = TensorValue({2}, dtype::Float32(), {1, 1});
  22. TensorND grad = TensorValue({2}, dtype::Float16(), {1, 1});
  23. TensorND m_t = TensorValue({2}, dtype::Float32(), {1, 1});
  24. TensorND v_t = TensorValue({2}, dtype::Float32(), {1, 1});
  25. TensorND new_param = TensorValue({2}, dtype::Float32(), {0, 0});
  26. checker.set_param(param).exect(
  27. Testcase{m_t_1, v_t_1, param_lamb, grad, {}, {}, {}},
  28. Testcase{{}, {}, {}, {}, m_t, v_t, new_param});
  29. }