You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lamb.cpp 912 B

12345678910111213141516171819202122232425
  1. #include "megdnn/oprs.h"
  2. #include "src/common/utils.h"
  3. namespace megdnn {
  4. void LAMBUpdate::deduce_layout(
  5. const TensorLayout& m_t_1, const TensorLayout& v_t_1,
  6. const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t,
  7. TensorLayout& v_t, TensorLayout& new_param) {
  8. m_t = m_t_1;
  9. v_t = v_t_1;
  10. new_param = lamb_param;
  11. MEGDNN_MARK_USED_VAR(grad);
  12. }
  13. void LAMBUpdate::check_exec(
  14. const TensorLayout& m_t_1, const TensorLayout& v_t_1,
  15. const TensorLayout& lamb_param, const TensorLayout& grad,
  16. const TensorLayout& m_t, const TensorLayout& v_t, const TensorLayout& new_param,
  17. size_t workspace_in_bytes) {
  18. auto required_workspace_in_bytes =
  19. get_workspace_in_bytes(m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param);
  20. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  21. }
  22. } // namespace megdnn