You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

broadcast.cpp 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /**
  2. * \file imperative/src/impl/ops/broadcast.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/ops/autogen.h"
  12. #include "megbrain/opr/tensor_manip.h"
  13. #include "../op_trait.h"
  14. namespace mgb {
  15. namespace imperative {
  16. namespace {
  17. std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
  18. node_->cast_final_safe<opr::Broadcast>();
  19. return Broadcast::make();
  20. }
  21. cg::OperatorNodeBase* apply_on_var_node(
  22. const OpDef& def,
  23. const VarNodeArray& inputs) {
  24. def.cast_final_safe<Broadcast>();
  25. size_t nr_inp = inputs.size();
  26. mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
  27. return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
  28. }
  29. bool valid_broadcast(const TensorShape& src_shape,
  30. const TensorShape& tar_shape) {
  31. size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
  32. if (src_ndim > tar_ndim) {
  33. return false;
  34. }
  35. size_t min_ndim = src_ndim < tar_ndim ? src_ndim : tar_ndim;
  36. for (size_t i = 0; i < min_ndim; ++i) {
  37. if (src_shape[src_ndim - i - 1] != 1 &&
  38. src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
  39. return false;
  40. }
  41. }
  42. return true;
  43. }
  44. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  45. const OpDef& def,
  46. const SmallVector<LogicalTensorDesc>& inputs) {
  47. def.cast_final_safe<Broadcast>();
  48. size_t nr_inp = inputs.size();
  49. mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
  50. auto&& src = inputs[0];
  51. auto&& tshp = inputs[1];
  52. TensorLayout out_layout = src.layout;
  53. if (tshp.layout.ndim == 0 || tshp.value.empty()) {
  54. out_layout.ndim = 0;
  55. return {{{out_layout, src.comp_node}}, true};
  56. }
  57. mgb_assert(
  58. tshp.layout.ndim == 1,
  59. "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
  60. tshp.layout.ndim);
  61. size_t target_ndim = tshp.layout.shape[0];
  62. out_layout.ndim = target_ndim;
  63. auto* ptr = tshp.value.ptr<dt_int32>();
  64. for(size_t i=0; i<target_ndim; ++i) {
  65. out_layout.shape[i] = ptr[i];
  66. }
  67. mgb_assert(valid_broadcast(src.layout, out_layout),
  68. "the input shape %s can not be broadcasted to target shape %s",
  69. src.layout.TensorShape::to_string().c_str(),
  70. out_layout.TensorShape::to_string().c_str());
  71. return {{{out_layout, src.comp_node}}, true};
  72. }
// Register the trait implementations above for the Broadcast op.
// Hooks not listed here (e.g. apply_on_physical_tensor) fall back to the
// default implementations via .fallback().
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();
  78. } // anonymous namespace
  79. } // namespace imperative
  80. } // namespace mgb
  81. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台