/** * \file src/opr/test/training/optimizer.cpp * * This file is part of MegBrain, a deep learning framework developed by Megvii. * * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * */ #include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/indexing.h" #include "megbrain/opr/io.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/tensor.h" #include "megbrain/test/helper.h" #include "megbrain/opr/training/optimizer.h" #include "megbrain/opr/training/utils.h" using namespace mgb; using namespace optimizer; namespace { class Device2HostCallback { public: Device2HostCallback(std::shared_ptr host) : m_host{host} {} void operator()(const DeviceTensorND& device) { m_host->copy_from(device).sync(); } private: std::shared_ptr m_host; }; template void assign_value(std::shared_ptr& tensor, std::vector& values) { ASSERT_EQ(values.size(), tensor->layout().total_nr_elems()); auto ptr = tensor->ptr(); for (size_t i = 0, it = tensor->layout().total_nr_elems(); i < it; i += 1) { ptr[i] = values.at(i); } } class OptimizerTest : public ::testing::Test { public: void verify( std::shared_ptr optimizer, std::shared_ptr weight, std::shared_ptr grad, std::shared_ptr truth, int execute_times); protected: std::shared_ptr optimizer; std::shared_ptr graph; }; void OptimizerTest::verify( std::shared_ptr optimizer, std::shared_ptr weight, std::shared_ptr grad, std::shared_ptr truth, int execute_times) { graph = cg::ComputingGraph::make(); SymbolVar symbol_weight = opr::SharedDeviceTensor::make(*graph, *weight); SymbolVar symbol_grad = opr::SharedDeviceTensor::make(*graph, *grad); cg::ComputingGraph::OutputSpec spec; spec.push_back( {optimizer->make(symbol_weight, symbol_grad, graph), Device2HostCallback(weight)}); auto func = graph->compile(spec); for (int i = 0; i < execute_times; i++) { func->execute(); } auto weight_ptr = weight->ptr(); auto truth_ptr = truth->ptr(); for (size_t i = 0, it = weight->shape().total_nr_elems(); i < it; i += 1) { ASSERT_NEAR(weight_ptr[i], truth_ptr[i], 0.001f); } } } // namespace TEST_F(OptimizerTest, SGD) { auto weight = TensorGen::constant({1}, 0.30542f); auto grad = TensorGen::constant({1}, -1.81453f); auto truth = TensorGen::constant({1}, 1.04673f); int execute_times = 10; std::shared_ptr sgd = std::make_shared(0.01f, 5e-2f, 0.9f); verify(sgd, weight, grad, truth, execute_times); } TEST_F(OptimizerTest, AdamTest) { auto weight = TensorGen::constant({1}, 1.62957f); auto grad = TensorGen::constant({1}, 1.02605f); auto truth = TensorGen::constant({1}, 1.52969f); int execute_times = 10; std::shared_ptr adam = std::make_shared(0.01f, 0.9f); verify(adam, weight, grad, truth, execute_times); }