You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they may include dashes ('-') and can be up to 35 characters long.

region_restricted_convolution.cpp 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #include "megdnn/oprs/nn.h"
  2. #include "src/common/utils.cuh"
  3. #include "src/common/utils.h"
  4. using namespace megdnn;
  5. namespace {
  6. template <typename Param>
  7. std::string get_errmsg(
  8. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  9. const Param& param) {
  10. MEGDNN_MARK_USED_VAR(src);
  11. MEGDNN_MARK_USED_VAR(filter);
  12. MEGDNN_MARK_USED_VAR(dst);
  13. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  14. megdnn_layout_msg(dst) + ", " + "is_nchw=" +
  15. std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
  16. "is_xcorr=" +
  17. std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
  18. "pad_h=" + std::to_string(param.pad_h) + ", " +
  19. "pad_w=" + std::to_string(param.pad_w) + ", " +
  20. "stride_h=" + std::to_string(param.stride_h) + ", " +
  21. "stride_w=" + std::to_string(param.stride_w) + ", " +
  22. "dilate_h=" + std::to_string(param.dilate_h) + ", " +
  23. "dilate_w=" + std::to_string(param.dilate_w);
  24. }
  25. } // namespace
  26. namespace megdnn {
// Deduce (or validate a pre-assigned) dst dtype from src/filter and check the
// dtype constraints specific to region-restricted convolution.
void RegionRestrictedConvolutionForward::deduce_dtype(
        DType src, DType filter, DType rin, DType rout, DType& dst) {
    // Shared forward rule: fills dst if unset, otherwise checks consistency.
    check_or_deduce_dtype_fwd(src, filter, dst);
    // Only floating-point data tensors are supported by this operator.
    megdnn_assert(
            src.category() == DTypeCategory::FLOAT &&
            filter.category() == DTypeCategory::FLOAT &&
            dst.category() == DTypeCategory::FLOAT,
            "only float type is supported for region_restricted_conv forward");
    // The input/output region maps must share one dtype: Int32 or Uint8.
    megdnn_assert(
            rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
            "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name());
}
  39. void RegionRestrictedConvolutionForward::deduce_layout(
  40. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& rin,
  41. const TensorLayout& rout, TensorLayout& dst) {
  42. MEGDNN_MARK_USED_VAR(rin);
  43. MEGDNN_MARK_USED_VAR(rout);
  44. deduce_layout_fwd(src, filter, dst);
  45. }
  46. RegionRestrictedConvolutionForward::CanonizedFilterMeta
  47. RegionRestrictedConvolutionForward::check_exec(
  48. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& rin,
  49. const TensorLayout& rout, const TensorLayout& dst, size_t workspace_in_bytes) {
  50. auto ret = check_layout_fwd(src, filter, dst);
  51. megdnn_assert(
  52. param().format == Param::Format::NCHW,
  53. "RegionRestrictedConv only support NCHW format mow.");
  54. megdnn_assert(
  55. param().stride_h == 1 && param().stride_w == 1,
  56. "RegionRestrictedConv only support stride 1.");
  57. #define err_msg(lhs, rhs) \
  58. megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
  59. err_msg(rin.shape[0], src.shape[0]);
  60. err_msg(rin.shape[1], src.shape[2]);
  61. err_msg(rin.shape[2], src.shape[3]);
  62. err_msg(rout.shape[0], dst.shape[0]);
  63. err_msg(rout.shape[1], dst.shape[2]);
  64. err_msg(rout.shape[2], dst.shape[3]);
  65. #undef err_msg
  66. auto required_workspace_in_bytes =
  67. get_workspace_in_bytes(src, filter, rin, rout, dst);
  68. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  69. return ret;
  70. }
  71. RegionRestrictedConvolutionBackwardData::CanonizedFilterMeta
  72. RegionRestrictedConvolutionBackwardData::check_exec(
  73. const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& rin,
  74. const TensorLayout& rout, const TensorLayout& grad, size_t workspace_in_bytes) {
  75. auto grad_fwd = grad;
  76. auto filter_fwd = filter;
  77. auto diff_fwd = diff;
  78. std::swap(grad_fwd.dtype, diff_fwd.dtype);
  79. grad_fwd.init_contiguous_stride();
  80. diff_fwd.init_contiguous_stride();
  81. auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
  82. #define err_msg(lhs, rhs) \
  83. megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
  84. err_msg(rin.shape[0], grad_fwd.shape[0]); // batch
  85. err_msg(rin.shape[1], grad_fwd.shape[2]); // ih
  86. err_msg(rin.shape[2], grad_fwd.shape[3]); // iw
  87. err_msg(rout.shape[0], diff_fwd.shape[0]); // batch
  88. err_msg(rout.shape[1], diff_fwd.shape[2]); // oh
  89. err_msg(rout.shape[2], diff_fwd.shape[3]); // ow
  90. #undef err_msg
  91. auto required_workspace_in_bytes =
  92. get_workspace_in_bytes(filter, diff, rin, rout, grad);
  93. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  94. return ret;
  95. }
  96. void RegionRestrictedConvolutionBackwardData::deduce_dtype(
  97. DType filter, DType diff, DType rin, DType rout, DType& grad) {
  98. // FIXME: infering dtype of grad via naive impl only support fp32
  99. // (lack of quantized dtype infering or others) may not suitable in the furture
  100. #if !MEGDNN_DISABLE_FLOAT16
  101. if (diff.enumv() == DTypeEnum::Float32 || diff.enumv() == DTypeEnum::Float16) {
  102. grad = diff;
  103. }
  104. #endif
  105. megdnn_assert(grad.valid(), "dtype of grad requires deducing of assigned");
  106. megdnn_assert(
  107. diff.category() == DTypeCategory::FLOAT &&
  108. filter.category() == DTypeCategory::FLOAT &&
  109. grad.category() == DTypeCategory::FLOAT,
  110. "only float type is supported for region_restricted_conv backward data");
  111. megdnn_assert(
  112. rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()),
  113. "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name());
  114. }
  115. void RegionRestrictedConvolutionBackwardData::deduce_layout(
  116. const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& rin,
  117. const TensorLayout& rout, TensorLayout& grad) {
  118. auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
  119. MEGDNN_MARK_USED_VAR(errmsg);
  120. megdnn_assert_contiguous(filter);
  121. megdnn_assert_contiguous(diff);
  122. megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
  123. megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
  124. deduce_dtype(filter.dtype, diff.dtype, rin.dtype, rout.dtype, grad.dtype);
  125. auto cflt = make_canonized_filter_meta(diff.ndim, filter);
  126. auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
  127. MEGDNN_MARK_USED_VAR(errmsg);
  128. auto i = (out - 1) * stride + filter;
  129. megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
  130. return i - pad * 2;
  131. };
  132. megdnn_assert(
  133. param().format == Param::Format::NCHW,
  134. "RegionRestrictedConvolutionBackwardData only support NCHW format mow.");
  135. size_t src_or_dst_c_pos = 1;
  136. size_t src_or_dst_spatial_start = 2;
  137. megdnn_assert(
  138. cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s", errmsg().c_str());
  139. grad.ndim = diff.ndim;
  140. grad[0] = diff[0];
  141. grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
  142. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  143. grad[i + src_or_dst_spatial_start] =
  144. deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  145. cflt.stride[i], cflt.padding[i]);
  146. }
  147. grad.format = diff.format;
  148. grad.init_contiguous_stride();
  149. }
  150. RegionRestrictedConvolutionBackwardFilter::CanonizedFilterMeta
  151. RegionRestrictedConvolutionBackwardFilter::check_exec(
  152. const TensorLayout& src, const TensorLayout& diff, const TensorLayout& rin,
  153. const TensorLayout& rout, const TensorLayout& grad, size_t workspace_in_bytes) {
  154. megdnn_assert(
  155. src.dtype.category() == DTypeCategory::FLOAT &&
  156. diff.dtype.category() == DTypeCategory::FLOAT &&
  157. grad.dtype.category() == DTypeCategory::FLOAT,
  158. "only float type is supported for conv backward filter");
  159. auto src_fwd = src;
  160. auto diff_fwd = diff;
  161. src_fwd.init_contiguous_stride();
  162. diff_fwd.init_contiguous_stride();
  163. auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
  164. megdnn_assert(
  165. param().format == Param::Format::NCHW,
  166. "RegionRestrictedConvolutionBackwardFilter only support NCHW format mow.");
  167. #define err_msg(lhs, rhs) \
  168. megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs);
  169. err_msg(rin.shape[0], src_fwd.shape[0]);
  170. err_msg(rin.shape[1], src_fwd.shape[2]);
  171. err_msg(rin.shape[2], src_fwd.shape[3]);
  172. err_msg(rout.shape[0], diff_fwd.shape[0]);
  173. err_msg(rout.shape[1], diff_fwd.shape[2]);
  174. err_msg(rout.shape[2], diff_fwd.shape[3]);
  175. #undef err_msg
  176. auto required_workspace_in_bytes =
  177. get_workspace_in_bytes(src, diff, rin, rout, grad);
  178. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  179. return ret;
  180. }
  181. } // namespace megdnn
  182. // vim: syntax=cpp.doxygen