You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.cpp 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /**
  2. * \file imperative/src/impl/ops/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/ops/autogen.h"
  12. #include "megbrain/opr/basic_arith.h"
  13. #include "../op_trait.h"
  14. namespace mgb {
  15. namespace imperative {
  16. namespace {
  17. std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
  18. auto* node = &node_->cast_final_safe<opr::Elemwise>();
  19. return Elemwise::make(node->param().mode);
  20. }
  21. cg::OperatorNodeBase* apply_on_var_node(
  22. const OpDef& def,
  23. const VarNodeArray& inputs) {
  24. auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
  25. return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
  26. }
  27. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  28. const OpDef& def,
  29. const SmallVector<LogicalTensorDesc>& inputs) {
  30. auto&& op_def = def.cast_final_safe<Elemwise>();
  31. auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
  32. mgb_assert(inputs.size() == trait.arity,
  33. "%s expects %u inputs; got %zu actually", trait.name,
  34. trait.arity, inputs.size());
  35. TensorShapeArray inp_shapes;
  36. DType out_dt;
  37. CompNode out_cn;
  38. for (size_t i = 0; i < inputs.size(); ++ i) {
  39. auto &&t = inputs[i];
  40. if (!i) {
  41. out_cn = t.comp_node;
  42. out_dt = t.layout.dtype;
  43. } else {
  44. mgb_assert(t.comp_node == out_cn);
  45. mgb_assert(t.layout.dtype == out_dt);
  46. }
  47. if (t.layout.ndim > 0) {
  48. inp_shapes.push_back(t.layout);
  49. } else {
  50. TensorLayout out_layout;
  51. out_layout.ndim = 0;
  52. out_layout.dtype = out_dt;
  53. return {{{out_layout, out_cn}}, true};
  54. }
  55. }
  56. auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
  57. return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
  58. }
  59. OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
  60. .make_from_op_node(make_from_op_node)
  61. .apply_on_var_node(apply_on_var_node)
  62. .infer_output_attrs_fallible(infer_output_attrs_fallible)
  63. .fallback();
  64. } // anonymous namespace
  65. } // namespace imperative
  66. } // namespace mgb
  67. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台