You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 863 B

12345678910111213141516171819202122232425
  1. #pragma once
  2. #include "megdnn/oprs.h"
  3. #include "src/cuda/cudnn_wrapper.h"
  4. namespace megdnn {
  5. namespace cuda {
  6. class LAMBUpdateImpl final : public LAMBUpdate {
  7. public:
  8. using LAMBUpdate::LAMBUpdate;
  9. void exec(
  10. _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1,
  11. _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad,
  12. _megdnn_tensor_out m_t, _megdnn_tensor_out v_t,
  13. _megdnn_tensor_out new_param, _megdnn_workspace workspace) override;
  14. size_t get_workspace_in_bytes(
  15. const TensorLayout& m_t_1, const TensorLayout& v_t_1,
  16. const TensorLayout& lamb_param, const TensorLayout& grad,
  17. const TensorLayout& m_t, const TensorLayout& v_t,
  18. const TensorLayout& new_param) override {
  19. return m_t.access_bytes();
  20. };
  21. };
  22. } // namespace cuda
  23. } // namespace megdnn