
convolution3d.cpp 11 kB

/**
 * \file dnn/src/common/convolution3d.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

using namespace megdnn;
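
// Helper for the assertions below: formats the src/filter/dst layouts and
// all conv3d parameters into one human-readable error message.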
namespace {
std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& dst,
                       const Convolution3D::Param& param) {
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);
    return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
           megdnn_layout_msg(dst) + ", " + megdnn_mangle("is_ncdhw=") +
           std::to_string(param.format == param::Convolution3D::Format::NCDHW) +
           ", " + megdnn_mangle("is_xcorr=") +
           std::to_string(
                   (param.mode == Convolution3D::Mode::CROSS_CORRELATION)) +
           ", " + megdnn_mangle("pad_d=") + std::to_string(param.pad_d) + ", " +
           megdnn_mangle("pad_h=") + std::to_string(param.pad_h) + ", " +
           megdnn_mangle("pad_w=") + std::to_string(param.pad_w) + ", " +
           megdnn_mangle("stride_d=") + std::to_string(param.stride_d) + ", " +
           megdnn_mangle("stride_h=") + std::to_string(param.stride_h) + ", " +
           megdnn_mangle("stride_w=") + std::to_string(param.stride_w) + ", " +
           megdnn_mangle("dilate_d=") + std::to_string(param.dilate_d) + ", " +
           megdnn_mangle("dilate_h=") + std::to_string(param.dilate_h) + ", " +
           megdnn_mangle("dilate_w=") + std::to_string(param.dilate_w);
}
}  // namespace
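
// Canonize the raw filter layout into a CanonizedFilterMeta. The filter
// layouts expected by the checks below are:
//   dense NCDHW: (oc, ic, fd, fh, fw)
//   dense NDHWC: (oc, fd, fh, fw, ic)
//   group conv:  the same with a leading `group` dimension
// `should_flip` records whether the kernel must be flipped: true convolution
// flips it, cross-correlation does not.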
Convolution3DBase::CanonizedFilterMeta
Convolution3DBase::make_canonized_filter_meta(
        size_t src_ndim, const TensorLayout& filter) const {
    megdnn_assert_contiguous(filter);
    auto img_ndim = src_ndim - 2;
    CanonizedFilterMeta ret;
    ret.dtype_enum = filter.dtype.enumv();
    ret.format = param().format;
    if (param().mode == Mode::CONVOLUTION) {
        ret.should_flip = true;
    } else {
        megdnn_assert(param().mode == Mode::CROSS_CORRELATION,
                      "invalid conv mode");
        ret.should_flip = false;
    }
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    MEGDNN_MARK_USED_VAR(flt_spatial_start);
    MEGDNN_MARK_USED_VAR(ocpg_pos);
    MEGDNN_MARK_USED_VAR(icpg_pos);
    if (param().sparse == Param::Sparse::DENSE) {
        megdnn_assert(filter.ndim == img_ndim + 2,
                      "bad filter ndim for dense convolution: "
                      "spatial_ndim=%zu filter_ndim=%zu",
                      img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(param().sparse == Param::Sparse::GROUP,
                      "invalid convolution sparse type");
        megdnn_assert(filter.ndim == img_ndim + 3,
                      "bad filter ndim for group convolution: "
                      "spatial_ndim=%zu filter_ndim=%zu",
                      img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;
    }
    if (param().format == Param::Format::NCDHW) {
        // filter should be (oc, ic, fd, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
        megdnn_assert(param().format == Param::Format::NDHWC,
                      "invalid conv tensor format");
        // filter should be (oc, fd, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 4;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 3,
            "only 3D convolution is supported, and input should be 5-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.stride[0] = this->param().stride_d;
    ret.stride[1] = this->param().stride_h;
    ret.stride[2] = this->param().stride_w;
    ret.padding[0] = this->param().pad_d;
    ret.padding[1] = this->param().pad_h;
    ret.padding[2] = this->param().pad_w;
    ret.dilation[0] = param().dilate_d;
    ret.dilation[1] = param().dilate_h;
    ret.dilation[2] = param().dilate_w;
    ret.ocpg = filter[flt_start + ocpg_pos];
    ret.icpg = filter[flt_start + icpg_pos];
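    // A kernel of size f with dilation d covers (f - 1) * d + 1 input
    // elements along that axis; e.g. f=3, d=2 touches 5 inputs.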
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(ret.dilation[i] > 0,
                      "invalid dilation on spatial dim %zu: %u", i,
                      ret.dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * ret.dilation[i] + 1;
    }
    return ret;
}
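
// Deduce the forward output layout from src and filter. Each spatial output
// size follows the usual convolution formula
//     out = floor((in + 2 * pad - dilated_filter) / stride) + 1,
// which is what infer_conv_shape computes from the dilated filter extent.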
Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd(
        const TensorLayout& src, const TensorLayout& filter,
        TensorLayout& dst) const {
    auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(filter);
    megdnn_assert(src.ndim >= 5_z, "%s", errmsg().c_str());
    megdnn_assert(src.dtype == filter.dtype, "%s", errmsg().c_str());
    if (param().data_type == Param::DataType::FLOAT) {
        megdnn_assert(src.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
                              || src.dtype == dtype::Float16()),
                      "invalid src dtype for conv: %s", src.dtype.name());
        dst.dtype = src.dtype;
    } else {
        megdnn_assert(param().data_type == Param::DataType::FLOAT_IO16xC32);
        MEGDNN_INC_FLOAT16(megdnn_assert(src.dtype == dtype::Float16(),
                                         "invalid src dtype for conv: %s",
                                         src.dtype.name()));
        MEGDNN_INC_FLOAT16(dst.dtype = dtype::Float16());
    }
    auto img_dim = src.ndim - 2;
    megdnn_assert(img_dim == 3, "this is the convolution for 3D image");
    megdnn_assert(filter.ndim == img_dim + 2 || filter.ndim == img_dim + 3,
                  "%s", errmsg().c_str());
    auto cflt = make_canonized_filter_meta(src.ndim, filter);
    size_t src_or_dst_c_pos = 0;
    size_t src_or_dst_spatial_start = 0;
    if (param().format == Param::Format::NCDHW) {
        src_or_dst_c_pos = 1;
        src_or_dst_spatial_start = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NDHWC,
                      "invalid conv format");
        src_or_dst_c_pos = 4;
        src_or_dst_spatial_start = 1;
    }
    megdnn_assert(cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
                  errmsg().c_str());
    dst.ndim = src.ndim;
    dst[0] = src[0];
    dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
    for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
        dst[i + src_or_dst_spatial_start] = infer_conv_shape(
                src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                cflt.stride[i], cflt.padding[i]);
    }
    dst.init_contiguous_stride();
    return cflt;
}
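
// Worked example for deduce_layout_fwd above (illustrative, not from this
// file): NCDHW src (1, 4, 8, 8, 8) with a dense filter (8, 4, 3, 3, 3),
// stride 1, pad 1, dilation 1 gives dst (1, 8, 8, 8, 8):
// (8 + 2*1 - 3) / 1 + 1 = 8 on each spatial axis, and the channel dim
// becomes ocpg * group = 8.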

Convolution3DBase::CanonizedFilterMeta Convolution3DBase::check_layout_fwd(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) const {
    TensorLayout dst_expected;
    auto ret = deduce_layout_fwd(src, filter, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    return ret;
}

void Convolution3DForward::deduce_layout(const TensorLayout& src,
                                         const TensorLayout& filter,
                                         TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}

Convolution3DBase::CanonizedFilterMeta Convolution3DForward::check_exec(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_in_bytes) {
    auto ret = check_layout_fwd(src, filter, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_in_bytes) {
    megdnn_assert(param().data_type == Param::DataType::FLOAT,
                  "only float type is supported for conv backward");
    auto ret = check_layout_fwd(grad, filter, diff);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(filter, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}
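
// Deduce the layout of the input gradient from the filter and the output
// gradient (diff). This inverts the forward shape formula: on each spatial
// axis, in = (out - 1) * stride + dilated_filter - 2 * pad, as computed by
// the `deduce` lambda below. Note that NCDHW layout is assumed here
// (channel at dim 1, spatial dims from dim 2).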
void Convolution3DBackwardData::deduce_layout(const TensorLayout& filter,
                                              const TensorLayout& diff,
                                              TensorLayout& grad) {
    megdnn_assert(param().data_type == Param::DataType::FLOAT,
                  "only float type is supported for conv backward");
    auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
    megdnn_assert(filter.ndim == 5_z || filter.ndim == 6_z, "%s",
                  errmsg().c_str());
    megdnn_assert(diff.ndim == 5_z, "%s", errmsg().c_str());
    megdnn_assert(filter.dtype == diff.dtype, "%s", errmsg().c_str());
    auto cflt = make_canonized_filter_meta(diff.ndim, filter);
    megdnn_assert(cflt.ocpg * cflt.group == diff[1], "%s", errmsg().c_str());
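    // Invert out = (in + 2 * pad - flt) / stride + 1. Because the forward
    // division takes the floor, several input sizes can map to the same
    // output size; this recovers the smallest valid one, e.g. out=4,
    // stride=1, flt=3, pad=1 gives in = 3 + 3 - 2 = 4.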
    auto deduce = [&errmsg](size_t out, size_t filter, size_t stride,
                            size_t pad) {
        MEGDNN_MARK_USED_VAR(errmsg);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
        return i - pad * 2;
    };
    grad.ndim = diff.ndim;
    grad[0] = diff[0];
    grad[1] = cflt.group * cflt.icpg;
    grad.dtype = diff.dtype;
    for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
        grad[i + 2] = deduce(diff[i + 2], cflt.dilated_spatial[i],
                             cflt.stride[i], cflt.padding[i]);
    }
    grad.init_contiguous_stride();
}
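
// The filter-gradient check reuses the forward layout check with the roles
// rearranged: (src, grad, diff) here play (src, filter, dst) of the forward
// pass, since the filter gradient has the filter's shape.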
Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_in_bytes) {
    megdnn_assert(param().data_type == Param::DataType::FLOAT,
                  "only float type is supported for conv backward");
    auto ret = check_layout_fwd(src, grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

// vim: syntax=cpp.doxygen
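
For context, a minimal sketch of how the deduce_layout entry point above is typically driven from client code. This is hypothetical and not part of this file: it assumes an already-created megdnn::Handle (the function name deduce_example and the concrete shapes are illustrative, reusing the worked example above), and relies on the default Convolution3D param (dense, NCDHW, cross-correlation, stride 1, dilation 1).

// Hypothetical usage sketch; assumes an existing megdnn::Handle* `handle`.
#include "megdnn/oprs/nn.h"

void deduce_example(megdnn::Handle* handle) {
    using namespace megdnn;
    // Create a forward conv3d operator and set symmetric padding of 1.
    auto opr = handle->create_operator<Convolution3DForward>();
    opr->param().pad_d = opr->param().pad_h = opr->param().pad_w = 1;
    TensorLayout src{{1, 4, 8, 8, 8}, dtype::Float32()};      // NCDHW
    TensorLayout filter{{8, 4, 3, 3, 3}, dtype::Float32()};   // (oc, ic, fd, fh, fw)
    TensorLayout dst;
    opr->deduce_layout(src, filter, dst);  // dst deduced as (1, 8, 8, 8, 8)
}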
