#include "test/cuda/fixture.h"

#include "test/common/checker.h"
#include "test/common/rng.h"

namespace megdnn {
namespace test {

/*!
 * Runs the LAMBUpdate operator through the generic Checker on the CUDA
 * handle, once for each dtype of tensor 3 (Float32, Float16, BFloat16).
 *
 * All other tensors (indices 0-2 and 4-6) are fixed to Float32. Results are
 * compared with a tolerance of 1e-3.
 */
TEST_F(CUDA, LAMBUpdate) {
    // Optimizer hyper-parameters shared by every run below.
    LAMBUpdate::Param param;
    param.beta_1 = 0.9;
    param.beta_2 = 0.999;
    param.eps = 1e-5;
    param.weight_decay = 0.4;
    param.lr = 1e-3;
    param.step = 1;
    param.bias_correction = true;
    param.always_adapt = false;

    // The operator type is spelled out explicitly: it cannot be inferred from
    // the handle argument, and it must match the LAMBUpdate::Param set below.
    Checker<LAMBUpdate> checker(handle_cuda());
    checker.set_epsilon(1e-3);

    // Values for tensors 0 and 1 are drawn from this generator; the remaining
    // tensors keep whatever generator the checker uses by default.
    // NOTE(review): presumably tensors 0/1 are the moment estimates, which
    // would explain the non-negative [0, 1] range -- confirm against the
    // LAMBUpdate operator signature.
    UniformFloatRNG rng0(0, 1);

    // Executes one check with tensor 3 using dtype `d`.
    auto run = [&](DType d) {
        checker.set_param(param)
                .set_rng(0, &rng0)
                .set_rng(1, &rng0)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_dtype(3, d)
                .set_dtype(4, dtype::Float32())
                .set_dtype(5, dtype::Float32())
                .set_dtype(6, dtype::Float32())
                // Four tensors of shape {2}; the last three are left empty.
                // NOTE(review): presumably the empty entries are outputs whose
                // layouts the checker deduces -- confirm in the checker.
                .execs({{2}, {2}, {2}, {2}, {}, {}, {}});
    };

    run(dtype::Float32());
    run(dtype::Float16());
    run(dtype::BFloat16());
}

}  // namespace test
}  // namespace megdnn