
// tile_repeat.cpp

#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include <numeric>

namespace megdnn {

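// Validate that src and dst are contiguous, share a dtype, and that every dst
// dimension equals the corresponding src dimension scaled by param().times.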
void TileRepeatBase::check_layout_fwd(
        const TensorLayout& src, const TensorLayout& dst) {
    auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
                  "times=" + param().times.to_string();
    auto errmsg_c = errmsg.c_str();
    MEGDNN_MARK_USED_VAR(errmsg_c);
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(dst);
    auto expected_ndim = param().times.ndim;
    megdnn_assert(expected_ndim == src.ndim, "%s", errmsg_c);
    megdnn_assert(expected_ndim == dst.ndim, "%s", errmsg_c);
    rep(i, expected_ndim) {
        megdnn_assert(dst.shape[i] == param().times[i] * src.shape[i], "%s", errmsg_c);
    }
    megdnn_assert(src.dtype == dst.dtype);
}
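
// Derive the output layout: same ndim and dtype as src, with each dimension
// multiplied by the corresponding entry of param().times.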
void TileRepeatBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
    dst.ndim = src.ndim;
    rep(i, src.ndim) { dst.shape[i] = src.shape[i] * param().times[i]; }
    dst.dtype = src.dtype;
    dst.init_contiguous_stride();
    check_layout_fwd(src, dst);
}
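
// Workspace policy: the operator is applied once per non-trivial (times > 1)
// axis. With at most one such axis the result can be written to dst directly;
// with two axes one intermediate buffer is needed; with three or more, two
// buffers are used alternately (ping-pong). The returned size is an upper
// bound derived from the number of elements in dst.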
size_t TileRepeatBase::get_workspace_in_bytes_fwd(
        const TensorShape& /* src */, const TensorShape& dst, const TensorShape& times,
        DType dtype) {
    size_t nr_workspace = 0;
    auto nr_reduces = count_not_ones_in_shape(times);
    if (nr_reduces == 0) {
        // case 1: no tile/repeat is needed at all, hence no workspace.
        nr_workspace = 0;
    } else if (nr_reduces == 1) {
        // case 2: only one tile/repeat is needed, so we can write to dst
        // directly and need no workspace.
        nr_workspace = 0;
    } else if (nr_reduces == 2) {
        // case 3: two tile/repeats are needed, so we need a single workspace.
        nr_workspace = 1;
    } else {
        // case 4: more than two tile/repeats are needed, so we need two
        // workspaces used in an alternating (ping-pong) fashion.
        nr_workspace = 2;
    }
    if (nr_workspace == 0) {
        return 0;
    } else {
        WorkspaceBundle workspaces{
                nullptr, {nr_workspace, dst.total_nr_elems() * dtype.size()}};
        return workspaces.total_size_in_bytes();
    }
}
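
// Collapse the shapes for Tile: an axis whose times factor is 1 needs no
// tiling of its own, so it is folded into the preceding axis (its src and dst
// extents are multiplied in). E.g. src=(2, 3, 4), times=(2, 1, 3),
// dst=(4, 3, 12) simplifies to src2=(6, 4), times2=(2, 3), dst2=(12, 12).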
void TileBase::simplify_shape(
        const TensorShape& src, const TensorShape& dst, const TensorShape& times,
        TensorShape& src2, TensorShape& dst2, TensorShape& times2) {
    size_t n = 0;
    for (size_t i = 0; i < src.ndim; ++i) {
        if (times.shape[i] == 1 && n > 0) {
            src2.shape[n - 1] *= src.shape[i];
            dst2.shape[n - 1] *= dst.shape[i];
        } else {
            src2.shape[n] = src.shape[i];
            dst2.shape[n] = dst.shape[i];
            times2.shape[n] = times.shape[i];
            ++n;
        }
    }
    src2.ndim = dst2.ndim = times2.ndim = n;
}
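
// Tile's workspace query first collapses the shapes via simplify_shape(),
// then defers to the shared TileRepeatBase computation.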
size_t TileBase::get_workspace_in_bytes_fwd(
        const TensorLayout& src_, const TensorLayout& dst_) {
    TensorShape src, dst, times;
    simplify_shape(src_, dst_, param().times, src, dst, times);
    return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype);
}

void TileForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    deduce_layout_fwd(src, dst);
}

void TileForward::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
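
// For the backward pass the roles are swapped: grad has the layout of the
// forward src and diff has the layout of the forward dst, so the forward
// layout check is reused with (grad, diff).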
void TileBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
    check_layout_fwd(grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
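
// Collapse the shapes for Repeat: a maximal run of axes with times == 1 is
// merged together with the first following non-trivial axis into a single
// axis that keeps that axis' times factor. E.g. src=(2, 3, 4, 5),
// times=(1, 1, 2, 3) simplifies to src2=(24, 5), times2=(2, 3), dst2=(48, 15).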
void RepeatBase::simplify_shape(
        const TensorShape& src, const TensorShape& /* dst */, const TensorShape& times,
        TensorShape& src2, TensorShape& dst2, TensorShape& times2) {
    auto n = 0u;
    size_t i = 0;
    while (i < times.ndim) {
        size_t j = i;
        while (j < times.ndim && times.shape[j] == 1)
            ++j;
        // Here: either j == times.ndim, or times.shape[j] != 1.
        if (j < times.ndim)
            ++j;
        src2.shape[n] = std::accumulate(
                src.shape + i, src.shape + j, 1_z, SafeMultiplies<size_t>());
        times2.shape[n] = times.shape[j - 1];
        dst2.shape[n] = src2.shape[n] * times2.shape[n];
        ++n;
        i = j;
    }
    src2.ndim = dst2.ndim = times2.ndim = n;
}

size_t RepeatBase::get_workspace_in_bytes_fwd(
        const TensorLayout& src_, const TensorLayout& dst_) {
    TensorShape src, dst, times;
    simplify_shape(src_, dst_, param().times, src, dst, times);
    return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype);
}
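
// RepeatForward mirrors TileForward: layout deduction and execution checks are
// shared with TileRepeatBase, differing only in how shapes are simplified.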
void RepeatForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    deduce_layout_fwd(src, dst);
}

void RepeatForward::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
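
// As in TileBackward, grad plays the role of the forward src and diff the
// role of the forward dst.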
void RepeatBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
    check_layout_fwd(grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen