You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

conv_bias.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. /**
  2. * \file dnn/src/common/conv_bias.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/common/conv_bias.h"
  13. #include "megdnn/oprs/nn.h"
  14. #include "src/common/utils.h"
  15. #include "src/common/opr_delegate.h"
  16. namespace megdnn {
  17. void ConvBiasForward::deduce_dtype(DType src, DType filter, DType /* bias */,
  18. DType /* z */, DType& dst) {
  19. check_or_deduce_dtype_fwd(src, filter, dst);
  20. }
  21. void ConvBiasForward::deduce_layout(const TensorLayout& src,
  22. const TensorLayout& filter,
  23. const TensorLayout& /* bias */,
  24. const TensorLayout& /* z */,
  25. TensorLayout& dst) {
  26. deduce_layout_fwd(src, filter, dst);
  27. }
  28. ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
  29. const TensorLayout& src, const TensorLayout& filter,
  30. const TensorLayout& bias, const TensorLayout& z,
  31. const TensorLayout& dst, size_t workspace_in_bytes,
  32. const PreprocessedFilter* preprocessed_filter) {
  33. megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) ||
  34. (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
  35. filter.dtype.enumv() == DTypeEnum::QuantizedS4));
  36. // check compatibility of bias's scale
  37. if (src.dtype.category() == DTypeCategory::QUANTIZED) {
  38. if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
  39. float scale_expected = mul_scale(src.dtype, filter.dtype);
  40. float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
  41. megdnn_assert(std::abs(scale_expected - scale_bias) < 1e-6,
  42. "scale_src: %f scale_filter: %f scale_bias: %f",
  43. get_scale(src.dtype), get_scale(filter.dtype),
  44. scale_bias);
  45. } else {
  46. megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
  47. }
  48. }
  49. auto ret = check_layout_fwd(src, filter, dst);
  50. megdnn_assert_contiguous(bias);
  51. auto required_workspace_in_bytes = get_workspace_in_bytes(
  52. src, filter, bias, z, dst, preprocessed_filter);
  53. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes,
  54. "worksapce have size of %zu, but need %zu",
  55. workspace_in_bytes, required_workspace_in_bytes);
  56. if (bias.ndim != 0) {
  57. //! bias.layout == dst.layout failed, no assert information
  58. auto check_eq = [](const TensorLayout& bias, const TensorLayout& dst) {
  59. if (dst.dtype.category() == DTypeCategory::QUANTIZED) {
  60. return bias.eq_shape(dst);
  61. } else {
  62. return bias.eq_layout(dst);
  63. }
  64. };
  65. if (check_eq(bias, dst))
  66. return ret;
  67. if (param().format == param::ConvBias::Format::NCHW ||
  68. param().format == param::ConvBias::Format::NCHW4_NCHW) {
  69. megdnn_assert(bias.shape[0] == 1);
  70. megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
  71. bias.to_string().c_str(), dst.to_string().c_str());
  72. megdnn_assert(bias.shape[2] == 1);
  73. megdnn_assert(bias.shape[3] == 1);
  74. } else if (param().format == param::ConvBias::Format::NHWC) {
  75. megdnn_assert(bias.shape[0] == 1);
  76. megdnn_assert(bias.shape[1] == 1);
  77. megdnn_assert(bias.shape[2] == 1);
  78. megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s",
  79. bias.to_string().c_str(), dst.to_string().c_str());
  80. } else if (param().format == param::ConvBias::Format::NCHW4 ||
  81. param().format == param::ConvBias::Format::NCHW44 ||
  82. param().format == param::ConvBias::Format::NCHW44_DOT ||
  83. param().format == param::ConvBias::Format::NCHW32_NCHW4) {
  84. megdnn_assert(bias.shape[0] == 1);
  85. megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
  86. bias.to_string().c_str(), dst.to_string().c_str());
  87. megdnn_assert(bias.shape[2] == 1);
  88. megdnn_assert(bias.shape[3] == 1);
  89. megdnn_assert(bias.shape[4] == 4);
  90. } else if (param().format == param::ConvBias::Format::NCHW8 ||
  91. param().format == param::ConvBias::Format::NCHW88 ) {
  92. megdnn_assert(bias.shape[0] == 1);
  93. megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
  94. bias.to_string().c_str(), dst.to_string().c_str());
  95. megdnn_assert(bias.shape[2] == 1);
  96. megdnn_assert(bias.shape[3] == 1);
  97. megdnn_assert(bias.shape[4] == 8);
  98. } else if (param().format == param::ConvBias::Format::NCHW32 ||
  99. param().format == param::ConvBias::Format::NCHW4_NCHW32) {
  100. megdnn_assert(bias.shape[0] == 1);
  101. megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
  102. bias.to_string().c_str(), dst.to_string().c_str());
  103. megdnn_assert(bias.shape[2] == 1);
  104. megdnn_assert(bias.shape[3] == 1);
  105. megdnn_assert(bias.shape[4] == 32);
  106. } else if (param().format == param::ConvBias::Format::CHWN4) {
  107. megdnn_assert(bias.shape[0] == dst.shape[0], "bias:%s, dst:%s",
  108. bias.to_string().c_str(), dst.to_string().c_str());
  109. megdnn_assert(bias.shape[1] == 1);
  110. megdnn_assert(bias.shape[2] == 1);
  111. megdnn_assert(bias.shape[3] == 1);
  112. megdnn_assert(bias.shape[4] == 4);
  113. } else if (param().format == param::ConvBias::Format::NCHW64) {
  114. megdnn_assert(bias.shape[0] == 1);
  115. megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
  116. bias.to_string().c_str(), dst.to_string().c_str());
  117. megdnn_assert(bias.shape[2] == 1);
  118. megdnn_assert(bias.shape[3] == 1);
  119. megdnn_assert(bias.shape[4] == 64);
  120. } else {
  121. megdnn_assert(param().format == param::ConvBias::Format::NHWCD4);
  122. megdnn_assert(bias.shape[0] == 1);
  123. megdnn_assert(bias.shape[1] == 1);
  124. megdnn_assert(bias.shape[2] == dst.shape[2], "bias:%s, dst:%s",
  125. bias.to_string().c_str(), dst.to_string().c_str());
  126. megdnn_assert(bias.shape[3] == 1);
  127. megdnn_assert(bias.shape[4] == 4);
  128. }
  129. }
  130. if (z.ndim != 0) {
  131. megdnn_assert(param().format != param::ConvBias::Format::NCHW4_NCHW32);
  132. megdnn_assert(param().format != param::ConvBias::Format::NCHW32_NCHW4);
  133. megdnn_assert(z.dtype.enumv() == dst.dtype.enumv());
  134. megdnn_assert(z.eq_shape(dst));
  135. }
  136. return ret;
  137. }
  138. template <typename T>
  139. struct NCHWParamTrait;
  140. template <typename T>
  141. struct NCHW44ParamTrait;
  142. std::string ConvBias::WinogradParam::to_string() const {
  143. return ssprintf("%u:%u:%u", channel_block_size, output_block_size,
  144. tile_size);
  145. }
  146. template <typename T>
  147. std::string ConvBias::algo_name(const std::string& base, const T& p,
  148. param::ConvBias::Format format) {
  149. if (format == param::ConvBias::Format::NCHW) {
  150. return ssprintf("%s:%s:%s", NCHWParamTrait<T>::category.c_str(),
  151. base.c_str(), p.to_string().c_str());
  152. } else if (format == param::ConvBias::Format::NCHW44) {
  153. return ssprintf("%s:%s:%s", NCHW44ParamTrait<T>::category.c_str(),
  154. base.c_str(), p.to_string().c_str());
  155. }
  156. megdnn_throw("Invalid format");
  157. return "";
  158. }
  159. #define FOREACH_CONV_BIAS_PARAM(cb) \
  160. cb(WinogradParam) cb(DirectParam) cb(MatmulParam) cb(DefaultParam)
  161. #define cb(pt) \
  162. template <> \
  163. struct NCHWParamTrait<ConvBias::pt> { \
  164. static const std::string category; \
  165. }; \
  166. template <> \
  167. struct NCHW44ParamTrait<ConvBias::pt> { \
  168. static const std::string category; \
  169. };
  170. FOREACH_CONV_BIAS_PARAM(cb)
  171. #undef cb
  172. #define cb(pt, ct) \
  173. const std::string NCHWParamTrait<ConvBias::pt>::category = ct; \
  174. const std::string NCHW44ParamTrait<ConvBias::pt>::category = ct
  175. cb(DirectParam, "DIRECT");
  176. cb(MatmulParam, "MATMUL");
  177. cb(DefaultParam, "DEFAULT");
  178. #undef cb
  179. const std::string NCHWParamTrait<ConvBias::WinogradParam>::category =
  180. "WINOGRAD";
  181. const std::string NCHW44ParamTrait<ConvBias::WinogradParam>::category =
  182. "WINOGRAD_NCHW44";
  183. #define cb(t) \
  184. template std::string ConvBias::algo_name<ConvBias::t>( \
  185. const std::string& base, const ConvBias::t& p, \
  186. param::ConvBias::Format format);
  187. FOREACH_CONV_BIAS_PARAM(cb)
  188. #undef cb
  189. ConvBias::WinogradParam ConvBias::parse_winograd_name(
  190. const std::string& algo_name) {
  191. ConvBias::WinogradParam ret = INVALID_WINOGRAD_PARAM;
  192. char base[128];
  193. char name[128];
  194. auto parse = [&](const std::string& algo_name,
  195. const std::string& pre) -> auto {
  196. memset(name, 0, 128);
  197. sscanf(algo_name.c_str(), "%[^:]:%[^:]:%u:%u:%u", name, base,
  198. &(ret.channel_block_size), &(ret.output_block_size),
  199. &(ret.tile_size));
  200. if (strcmp(name, pre.c_str())) {
  201. ret = INVALID_WINOGRAD_PARAM;
  202. return false;
  203. }
  204. if (ret.tile_size == 0 || ret.output_block_size == 0 ||
  205. ret.channel_block_size == 0) {
  206. ret = INVALID_WINOGRAD_PARAM;
  207. return false;
  208. }
  209. return true;
  210. };
  211. if (parse(algo_name, "WINOGRAD_NCHW44")) {
  212. return ret;
  213. } else {
  214. parse(algo_name, "WINOGRAD");
  215. return ret;
  216. }
  217. }
  218. constexpr ConvBias::WinogradParam ConvBias::INVALID_WINOGRAD_PARAM;
  219. void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
  220. const TensorND* conv_dst_tensor,
  221. const TensorND* dst_tensor,
  222. const TensorND* bias_tensor) {
  223. using NonlineMode = param::ConvBias::NonlineMode;
  224. switch (args.nonlineMode) {
  225. #define cb(_mode) \
  226. case NonlineMode::_mode: { \
  227. if (conv_dst_tensor->layout.dtype.category() != \
  228. DTypeCategory::QUANTIZED) { \
  229. auto nonlinear = handle->create_operator<ElemwiseForward>(); \
  230. if (bias_tensor->layout.ndim > 0) { \
  231. nonlinear->param().mode = \
  232. Elemwise::Param::Mode::FUSE_ADD_##_mode; \
  233. nonlinear->exec({*conv_dst_tensor, *bias_tensor}, \
  234. *dst_tensor); \
  235. } else { \
  236. nonlinear->param().mode = Elemwise::Param::Mode::_mode; \
  237. nonlinear->exec({*conv_dst_tensor}, *dst_tensor); \
  238. } \
  239. } else { \
  240. auto nonlinear = handle->create_operator<ElemwiseMultiType>(); \
  241. if (bias_tensor->layout.ndim > 0) { \
  242. nonlinear->param().mode = \
  243. ElemwiseMultiType::Param::Mode::QFUSE_ADD_##_mode; \
  244. nonlinear->exec({*conv_dst_tensor, *bias_tensor}, \
  245. *dst_tensor); \
  246. } else { \
  247. nonlinear->param().mode = \
  248. ElemwiseMultiType::Param::Mode::Q##_mode; \
  249. nonlinear->exec({*conv_dst_tensor}, *dst_tensor); \
  250. } \
  251. } \
  252. break; \
  253. }
  254. cb(RELU);
  255. cb(H_SWISH);
  256. #undef cb
  257. case NonlineMode::SIGMOID: {
  258. megdnn_assert(conv_dst_tensor->layout.dtype.category() !=
  259. DTypeCategory::QUANTIZED);
  260. auto nonlinear = handle->create_operator<ElemwiseForward>();
  261. if (bias_tensor->layout.ndim > 0) {
  262. nonlinear->param().mode =
  263. Elemwise::Param::Mode::FUSE_ADD_SIGMOID;
  264. nonlinear->exec({*conv_dst_tensor, *bias_tensor},
  265. *conv_dst_tensor);
  266. } else {
  267. nonlinear->param().mode = Elemwise::Param::Mode::SIGMOID;
  268. nonlinear->exec({*conv_dst_tensor}, *conv_dst_tensor);
  269. }
  270. break;
  271. }
  272. case NonlineMode::IDENTITY: {
  273. if (bias_tensor->layout.ndim > 0) {
  274. if (dst_tensor->layout.dtype.category() ==
  275. DTypeCategory::QUANTIZED) {
  276. auto nonlinear =
  277. handle->create_operator<ElemwiseMultiType>();
  278. nonlinear->param().mode =
  279. ElemwiseMultiType::Param::Mode::QADD;
  280. nonlinear->exec({*conv_dst_tensor, *bias_tensor},
  281. *dst_tensor);
  282. } else {
  283. auto nonlinear = handle->create_operator<Elemwise>();
  284. nonlinear->param().mode = Elemwise::Param::Mode::ADD;
  285. nonlinear->exec({*conv_dst_tensor, *bias_tensor},
  286. *dst_tensor);
  287. }
  288. } else {
  289. if (conv_dst_tensor->layout.dtype != dst_tensor->layout.dtype) {
  290. handle->create_operator<TypeCvt>()->exec({*conv_dst_tensor},
  291. *dst_tensor);
  292. }
  293. }
  294. break;
  295. }
  296. default:
  297. megdnn_assert(false);
  298. }
  299. }
  300. } // namespace megdnn
  301. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台