You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

winograd_helper.h 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. /**
  2. * \file dnn/src/common/winograd/winograd_helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include <vector>
  14. #include "megdnn/dtype.h"
  15. #include "megdnn/oprs.h"
  16. namespace megdnn {
  17. namespace winograd {
  18. using NonlineMode = ::megdnn::ConvBias::Param::NonlineMode;
  19. using BiasMode = ConvBiasForward::BiasMode;
  20. /**
  21. * \brief Strategy helper, contains some helper function for debug kernel
  22. * implementation
  23. *
  24. * \warning The layout should be NCHW
  25. */
  26. template <typename ctype, typename dst_type, typename input_filter_compute_type,
  27. typename output_compute_type,
  28. param::ConvBias::Format layout = param::ConvBias::Format::NCHW,
  29. param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
  30. class StrategyHelper {
  31. public:
  32. static void filter(const ctype* filter,
  33. input_filter_compute_type* filter_transform_buf,
  34. input_filter_compute_type* transform_mid_buf, size_t OC,
  35. size_t IC, size_t oc_start, size_t oc_end, size_t m,
  36. size_t r, const std::vector<float>& interp_points,
  37. DType dtype, float rescale = 1.0f);
  38. static void input(const ctype* input,
  39. input_filter_compute_type* input_transform_buf,
  40. input_filter_compute_type* transform_mid_buf,
  41. int ih_start, int iw_start, size_t IH, size_t IW,
  42. size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
  43. size_t m, size_t r,
  44. const std::vector<float>& interp_points, DType dtype,
  45. float rescale = 1.0f);
  46. static void
  47. output(const output_compute_type* output_transform_buf,
  48. const output_compute_type* bias, dst_type* output,
  49. output_compute_type* transform_mid_buf, BiasMode bmode,
  50. NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
  51. size_t OH, size_t OW, size_t oc_start, size_t oc_index,
  52. size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
  53. const std::vector<float>& interp_points, DType dtype,
  54. float input_filter_scale = 1.0f, // input_scale * filter_scale
  55. float input_filter_rescale = 1.0f, // input_rescale * filter_rescale
  56. float rescale = 1.0f);
  57. };
  58. } // namespace winograd
  59. } // namespace megdnn
  60. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)