
collective_comm.h

/**
 * \file src/opr-mm/include/megbrain/opr/collective_comm.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/graph.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/opr/group_manager.h"

#include "megray.h"

namespace mgb {
namespace opr {

//! collective communication between multiple CompNodes on localhost
MGB_DEFINE_OPR_CLASS(CollectiveComm, cg::OutshapePureByInshapeOpr<>) // {
public:
    class ModeTrait;

    using Param = megdnn::param::CollectiveComm;

    CollectiveComm(
            VarNodeArray inputs, ComputingGraph* const graph,
            const std::string& key, const size_t nr_devices, const bool is_root,
            const int rank, std::shared_ptr<GroupClient> group_client,
            const Param& param, const DType& dtype, const std::string& backend,
            const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
            const OperatorNodeConfig& config,
            const std::shared_ptr<DTypeScalar>& disable);

    static SymbolVarArray make(
            const SymbolVarArray& inputs, ComputingGraph* const graph,
            const std::string& key, const size_t nr_devices, const bool is_root,
            const int rank, std::shared_ptr<GroupClient> group_client,
            const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
            const Param& param, const DType& dtype = {},
            const std::string& backend = "nccl",
            const OperatorNodeConfig& config = {},
            const std::shared_ptr<DTypeScalar>& disable =
                    std::make_shared<DTypeScalar>(0));

    static SymbolVarArray make(const SymbolVarArray& inputs,
                               ComputingGraph* const graph,
                               const std::string& key, const size_t nr_devices,
                               const bool is_root, const int rank,
                               std::shared_ptr<GroupClient> group_client,
                               const Param& param, const DType& dtype = {},
                               const std::string& backend = "nccl",
                               const OperatorNodeConfig& config = {},
                               const std::shared_ptr<DTypeScalar>& disable =
                                       std::make_shared<DTypeScalar>(0));

    const Param& param() const { return m_param; }
    const DType& dtype() const { return m_dtype; }
    const std::string& backend() const { return m_backend; }

    //! total number of devices within the clique
    size_t nr_devices() const { return m_nr_devices; }

    //! output buffers
    const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffers() const {
        return m_dev_buffers;
    }

    int rank() const { return m_rank; }
    int root() const { return m_root; }
    bool is_root() const { return m_is_root; }

    //! The key that identifies an NCCL clique.
    //! Operators with the same key belong to the same clique.
    const std::string& key() const { return m_key; }

    std::shared_ptr<GroupClient> group_client() const {
        return m_group_client;
    }

    void set_pack_hash(uint64_t hash) { m_pack_hash = hash; }
    uint64_t pack_hash() const { return m_pack_hash; }

    std::shared_ptr<MegRay::Context> megray_ctx() const {
        return m_megray_ctx;
    }

    VarNodeArray grad(const VarNodeArray& out_grad) const;

private:
    Barrier m_exec_barrier;

    const Param m_param;
    const DType m_dtype;
    const std::string m_backend;

    void mem_plan_fwd_in2out_writable() override;
    void add_input_layout_constraint() override;
    void get_output_var_shape(const TensorShapeArray& inp_shape,
                              TensorShapeArray& out_shape) const override;
    void init_output_comp_node() override;
    void do_execute(ExecEnv& env) override;
    NodeProp* do_make_node_prop() const override;
    void on_output_comp_node_stream_changed() override;
    void init_output_dtype() override;
    void init_output_static_infer_desc() override;
    void init_output_mem_plan(bool dynamic) override;

    //! init nccl communicators
    void opr_register();

    std::shared_ptr<GroupClient> m_group_client;
    size_t m_nr_devices = 0;
    bool m_is_root;
    int m_rank;
    std::string m_key;
    //! XXHash generated from m_key
    size_t m_hash;
    //! root of BROADCAST and REDUCE operation
    int m_root;

    //! rank of root of BROADCAST and REDUCE operation
    Maybe<TensorShape> m_output_shape = None;

    // Whether shape infer is enabled.
    // This is only used by BROADCAST and SCATTER operation,
    // whose shape infer should be disabled *during* static infer phase.
    bool m_enable_shape_infer = false;

    //! set in PackAllReduceScanPass and used in PackAllReduceReplacePass
    uint64_t m_pack_hash = 0;

    std::shared_ptr<MegRay::Context> m_megray_ctx;
    std::shared_ptr<MegRay::Communicator> m_megray_comm;
    bool m_init = false;
    bool m_debug_mode = false;

    //! dev buffers for each output
    SmallVector<std::shared_ptr<DeviceTensorND>> m_dev_buffers;

    //! disable flag
    std::shared_ptr<DTypeScalar> m_disable;
};

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
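For orientation, here is a minimal usage sketch of the second make() overload declared above, assembling a sum all-reduce across two ranks. It is not part of the header: the variables `graph`, `input`, `rank`, and `client` (a concrete GroupClient implementation) are assumed to be set up elsewhere, and the mode name ALL_REDUCE_SUM is assumed to come from the megdnn CollectiveComm param rather than confirmed by this file.

    // Hypothetical sketch, not part of collective_comm.h.
    // Assumes: std::shared_ptr<ComputingGraph> graph, SymbolVar input,
    // int rank, and std::shared_ptr<GroupClient> client already exist.
    using Param = mgb::opr::CollectiveComm::Param;

    Param param;
    param.mode = Param::Mode::ALL_REDUCE_SUM;  // assumed mode name

    auto outputs = mgb::opr::CollectiveComm::make(
            {input},      // this rank's input var
            graph.get(),  // ComputingGraph that owns the new operator
            "grad_sum",   // key identifying the NCCL clique
            2,            // nr_devices in the clique
            rank == 0,    // is_root: let rank 0 act as root
            rank,         // rank of this process
            client,       // GroupClient used for group registration
            param);       // dtype, backend ("nccl") and config keep their defaults

    SymbolVar reduced = outputs[0];  // summed result visible on every rank

Every rank in the clique would run the same call with its own input and rank, sharing the key so the operators join one group.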
