
broadcast.cpp 10 kB

/**
 * \file imperative/src/impl/ops/broadcast.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/graph/helper.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {
namespace broadcast {
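
// Rebuild a Broadcast OpDef from an existing graph operator node; the cast
// only validates that the node really is an opr::Broadcast.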
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    node_->cast_final_safe<opr::Broadcast>();
    return Broadcast::make();
}
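
// Lower the imperative op onto the computing graph; the two inputs are the
// source tensor and the target-shape tensor.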
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    OperatorNodeConfig config{op.make_name()};
    return opr::Broadcast::make(inputs[0], inputs[1], config);
}
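
// Numpy-style broadcast check: each trailing dimension of src must either be
// 1 or match the corresponding dimension of the target shape.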
bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape) {
    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
    if (src_ndim > tar_ndim) {
        return false;
    }
    size_t min_ndim = src_ndim;
    for (size_t i = 0; i < min_ndim; ++i) {
        if (src_shape[src_ndim - i - 1] != 1 &&
            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
            return false;
        }
    }
    return true;
}
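
// Shape inference. With a single input the target shape comes from the op
// parameter; with two inputs it is read from the host value of the second
// input. The bool in the return value reports whether the shape is final.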
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    auto&& src = inputs[0];
    TensorShape out_shape;
    if (nr_inp == 1) {
        out_shape.ndim = op.shape.size();
        for (size_t i = 0; i < out_shape.ndim; ++i) {
            out_shape[i] = op.shape[i];
        }
    } else {
        auto&& tshp = inputs[1];
        if (tshp.layout.ndim == 0 || tshp.value.empty()) {
            out_shape.ndim = 0;
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        mgb_assert(
                tshp.layout.ndim == 1,
                "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
                tshp.layout.ndim);
        size_t target_ndim = tshp.layout.shape[0];
        out_shape.ndim = target_ndim;
        auto* ptr = tshp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_shape[i] = ptr[i];
        }
    }
    mgb_assert(
            valid_broadcast(src.layout, out_shape),
            "the input shape %s can not be broadcasted to target shape %s",
            src.layout.to_string().c_str(), out_shape.to_string().c_str());
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
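
// Eager execution. Broadcast never copies data: the result shares the input
// blob and only carries the broadcast layout derived from the source layout.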
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    TensorShape tshp;
    auto&& src = inputs[0];
    auto slayout = src->layout();
    if (nr_inp == 1) {
        tshp.ndim = op.shape.size();
        for (size_t i = 0; i < tshp.ndim; ++i) {
            tshp[i] = op.shape[i];
        }
    } else {
        auto&& tshp_nd = inputs[1];
        cg::copy_tensor_value_to_shape(
                tshp, tshp_nd->get_value().proxy_to_default_cpu());
    }
    TensorLayout tlayout = slayout.broadcast(tshp);
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
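
// Broadcast places no requirement on its input layouts, so the constraint
// callbacks are left empty.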
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace broadcast

namespace reshape {
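
// Lower Reshape onto the computing graph; unlike Broadcast, the Reshape param
// (which may carry the axis to be deduced) is forwarded as well.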
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Reshape&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
}
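
// Shape inference for Reshape. Returns "not validated" whenever the source
// shape or the target-shape value is still unknown; otherwise the axis marked
// -1 (if any) is deduced from the total number of elements.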
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    auto&& src = inputs[0];
    TensorShape out_shape;
    if (nr_inp == 1) {
        if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        out_shape.ndim = op.shape.size();
        for (size_t i = 0; i < out_shape.ndim; ++i) {
            out_shape[i] = op.shape[i];
        }
        if (src.layout.ndim == 0) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
    } else {
        auto&& tshp = inputs[1];
        if (tshp.layout.ndim == 0 || tshp.value.empty()) {
            out_shape.ndim = 0;
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        mgb_assert(
                tshp.layout.ndim == 1,
                "target shape of Reshape expects ndim=1; got ndim=%lu actually",
                tshp.layout.ndim);
        if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
        size_t target_ndim = tshp.layout.shape[0];
        out_shape.ndim = target_ndim;
        auto* ptr = tshp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_shape[i] = ptr[i];
        }
        if (src.layout.ndim == 0) {
            return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
                    false};
        }
    }
    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(out_shape[op.axis] == -1);
        out_shape[op.axis] = 1;
        mgb_assert(
                src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
                "can not reshape from %s to %s", src.layout.to_string().c_str(),
                out_shape.to_string().c_str());
        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
    } else {
        mgb_assert(
                src.layout.total_nr_elems() == out_shape.total_nr_elems(),
                "can not reshape from %s to %s", src.layout.to_string().c_str(),
                out_shape.to_string().c_str());
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
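
// Eager execution. The -1 axis (if present) is resolved against the source
// element count, then the source layout is reshaped via try_reshape (asserted
// to succeed) and the input blob is forwarded without copying.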
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    auto&& src = inputs[0];
    auto slayout = src->layout();
    TensorShape tshp;
    if (nr_inp == 1) {
        tshp.ndim = op.shape.size();
        for (size_t i = 0; i < tshp.ndim; ++i) {
            tshp[i] = op.shape[i];
        }
    } else {
        auto&& tshp_nd = inputs[1];
        cg::copy_tensor_value_to_shape(
                tshp, tshp_nd->get_value().proxy_to_default_cpu());
    }
    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(tshp[op.axis] == -1);
        tshp[op.axis] = 1;
        tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
    }
    TensorLayout tlayout;
    mgb_assert(slayout.try_reshape(tlayout, tshp));
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
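
// Layout constraint for the first input: re-derive the target shape (resolving
// the -1 axis) and report via try_reshape whether the incoming layout can be
// reshaped to it in place.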
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& op = def.cast_final_safe<Reshape>();
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = [&](const TensorLayout& layout) {
        TensorShape tshp;
        TensorLayout ret;
        if (inputs.size() == 1) {
            tshp.ndim = op.shape.size();
            for (size_t i = 0; i < tshp.ndim; ++i) {
                tshp[i] = op.shape[i];
            }
        } else {
            cg::copy_tensor_value_to_shape(
                    tshp, inputs[1]->get_value().proxy_to_default_cpu());
        }
        if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
            mgb_assert(tshp[op.axis] == -1);
            tshp[op.axis] = 1;
            tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
        }
        if (layout.try_reshape(ret, tshp)) {
            return true;
        } else {
            return false;
        }
    };
    return layout_checker;
}

OP_TRAIT_REG(Reshape, Reshape)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace reshape

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}