
broadcast.cpp 8.7 kB

/**
 * \file imperative/src/impl/ops/broadcast.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/graph/helper.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace broadcast {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    node_->cast_final_safe<opr::Broadcast>();
    return Broadcast::make();
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    OperatorNodeConfig config{op.make_name()};
    return opr::Broadcast::make(inputs[0], inputs[1], config);
}

bool valid_broadcast(const TensorShape& src_shape,
                     const TensorShape& tar_shape) {
    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
    if (src_ndim > tar_ndim) {
        return false;
    }
    size_t min_ndim = src_ndim;
    for (size_t i = 0; i < min_ndim; ++i) {
        if (src_shape[src_ndim - i - 1] != 1 &&
            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
            return false;
        }
    }
    return true;
}
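
// Illustrative trace (added for exposition, not part of the original source):
// valid_broadcast compares the shapes right-aligned, as in NumPy-style
// broadcasting. For src_shape = (3, 1) and tar_shape = (2, 3, 4) it checks
// src[1] = 1 against tar[2] = 4 (ok, a size-1 dim broadcasts) and
// src[0] = 3 against tar[1] = 3 (ok, equal); the extra leading target dim 2
// is unconstrained, so the call returns true. A src_shape of (3, 2) would
// fail on the trailing pair (2 vs 4). A hypothetical call would look like:
//     TensorShape src{3, 1}, tar{2, 3, 4};
//     mgb_assert(valid_broadcast(src, tar));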
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];
    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);
    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }
    mgb_assert(valid_broadcast(src.layout, out_shape),
               "the input shape %s can not be broadcasted to target shape %s",
               src.layout.to_string().c_str(),
               out_shape.to_string().c_str());
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
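
// Note (added for exposition): the bool in the returned tuple marks whether
// the output descriptor could actually be inferred here. When the target
// shape is not yet known as a constant (ndim == 0 or an empty value), an
// empty layout is returned together with false and validation is deferred;
// otherwise the shape is read out, checked with valid_broadcast, and
// returned with true.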
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto& input = inputs_tensors[0];
    TensorShape target_shape;
    cg::copy_tensor_value_to_shape(
            target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu());
    // TODO: memory forward
    // if (input->shape().eq_shape(target_shape)) {
    //     return {{{input->layout(), 0, input->comp_node(),
    //               StorageIdentifier::make(&inputs_mems[0])}}, {}};
    // }
    return {{{{target_shape, input->dtype()}, 0, input->comp_node(),
              StorageIdentifier::make(0)}}, {}};
}

void execute(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    if (outputs[0]->layout().is_empty()) {
        return;
    }
    if (inputs[0]->shape().eq_shape(outputs[0]->shape())) {
        mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout()));
        // TODO: memory forward
        // mgb_assert(inputs[0]->offset() == outputs[0]->offset());
        // mgb_assert(inputs[0]->blob() == outputs[0]->blob());
        outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor());
    } else {
        TensorLayout input_layout =
                inputs[0]->layout().broadcast(outputs[0]->shape());
        outputs[0]->dev_tensor().copy_from_fixlayout(
                inputs[0]->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout)));
    }
}
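
// Note (added for exposition): when the shapes already match, the copy above
// is a plain layout-preserving copy; otherwise layout().broadcast(...)
// produces a view of the input over the target shape (broadcast axes get
// stride 0, so no data is duplicated), and copy_from_fixlayout materializes
// that view into the output buffer.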
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .infer_output_mem_desc(infer_output_mem_desc)
        .execute(execute)
        .fallback();

} // namespace broadcast

namespace reshape {

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Reshape&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];
    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of Reshape expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);
    if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(out_shape[op.axis] == -1);
        out_shape[op.axis] = 1;
        mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
                   "can not reshape from %s to %s",
                   src.layout.to_string().c_str(),
                   out_shape.to_string().c_str());
        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
    } else {
        mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
                   "can not reshape from %s to %s",
                   src.layout.to_string().c_str(),
                   out_shape.to_string().c_str());
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
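
// Illustrative trace (added for exposition, not part of the original source):
// with src.layout = (2, 6) (12 elements), a target shape value of {3, -1} and
// op.axis = 1, the branch above first asserts out_shape[1] == -1, temporarily
// sets it to 1 (so out_shape totals 3 elements), checks 12 % 3 == 0, and then
// fills in out_shape[1] = 12 / 3 = 4, yielding the inferred shape (3, 4).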
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& op_def = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp_nd = inputs[1];
    auto slayout = src->layout();
    TensorShape tshp;
    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(tshp[op_def.axis] == -1);
        tshp[op_def.axis] = 1;
        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
    }
    TensorLayout tlayout = slayout.reshape(tshp);
    // memory forward
    return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}};
}

void execute(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    mgb_assert(inputs[0]->offset() == outputs[0]->offset());
    mgb_assert(inputs[0]->blob() == outputs[0]->blob());
}
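
// Note (added for exposition): Reshape here moves no data. infer_output_mem_desc
// aliases the input storage (StorageIdentifier::make(&inputs_mems[0]) with the
// reshaped layout), so execute only asserts that the input and output tensors
// indeed share the same blob and offset.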
OP_TRAIT_REG(Reshape, Reshape)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .infer_output_mem_desc(infer_output_mem_desc)
        .execute(execute)
        .fallback();

} // namespace reshape

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.