
collective_comm.cpp 27 kB

/**
 * \file src/opr-mm/impl/collective_comm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/collective_comm.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/group_manager.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/version_symbol.h"

using namespace mgb;
using namespace opr;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);
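// FOREACH_MODE applies a callback macro to every supported collective mode; it
// is used below in get_param_name() and again in ModeTrait::from_mode() to
// generate the per-mode switch cases.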
#define FOREACH_MODE(cb)                                                   \
    cb(ALL_REDUCE_SUM) cb(ALL_REDUCE_MAX) cb(ALL_REDUCE_MIN) cb(BROADCAST) \
            cb(REDUCE_SUM) cb(ALL_GATHER) cb(REDUCE_SCATTER_SUM)

namespace {

const char* get_param_name(CollectiveComm::Param param) {
    using Mode = CollectiveComm::Param::Mode;
    switch (param.mode) {
#define C(_m)      \
    case Mode::_m: \
        return #_m;
        FOREACH_MODE(C)
#undef C
        default:
            mgb_throw(MegBrainError, "bad CollectiveComm mode");
    }
}
MegRay::DType get_megray_dtype(megdnn::DType dtype) {
    switch (dtype.enumv()) {
        case DTypeEnum::Int8:
            return MegRay::DType::MEGRAY_INT8;
        case DTypeEnum::Int32:
            return MegRay::DType::MEGRAY_INT32;
        case DTypeEnum::Float32:
            return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16
        case DTypeEnum::Float16:
            return MegRay::DType::MEGRAY_FLOAT16;
#endif
        default:
            mgb_throw(MegBrainError, "bad CollectiveComm dtype");
    }
}
MegRay::Backend get_megray_backend(const std::string& backend) {
    if (backend == "nccl") {
        return MegRay::MEGRAY_NCCL;
    } else if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    } else {
        mgb_throw(MegBrainError, "bad CollectiveComm backend");
    }
}
cudaStream_t get_stream(VarNode* var) {
    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}

}  // anonymous namespace
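// ModeTrait is a per-mode strategy object: each subclass below defines how the
// operator creates its output vars, infers output shapes, and launches the
// corresponding MegRay collective; from_mode() returns a singleton per mode.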
class CollectiveComm::ModeTrait {
    class BROADCAST;
    class REDUCE_SUM;
    class REDUCE_SCATTER_SUM;
    class ALL_GATHER;
    class ALL_REDUCE_SUM;
    class ALL_REDUCE_MAX;
    class ALL_REDUCE_MIN;
    class ReducedBasedTrait;
    class AllReduceBase;
    class ReduceBase;

protected:
    using Mode = Param::Mode;

    static void chk_shape_equal(const TensorShapeArray& shp) {
        for (size_t i = 1; i < shp.size(); ++i) {
            mgb_throw_if(!shp[0].eq_shape(shp[i]), GraphError,
                         "input shapes should be equal");
        }
    }

    static void add_output_var_all2all(CollectiveComm* opr) {
        mgb_assert(opr->nr_devices() >= 2);
        auto pname = get_param_name(opr->param());
        // sublinear would setup opr->config if inputs.size() is 1,
        // bypass this situation
        mgb_assert(
                !opr->config().has_comp_node_set() || opr->input().size() == 1,
                "comp node should not be set in %s mode", pname);
        for (auto i : opr->input()) {
            opr->add_output(ssprintf("%s:%s", pname, i->cname()))
                    ->comp_node(i->comp_node());
        }
    }

public:
    virtual ~ModeTrait() = default;

    //! add output var for the opr
    virtual void add_output_var(CollectiveComm* opr,
                                const CompNode::UnorderedSet& inp_cn) = 0;

    /*!
     * \brief the vars on whose comp node the computing should be performed;
     * if None, output vars would be used
     */
    virtual Maybe<VarNodeArray> comp_vars(CollectiveComm* opr) {
        return None;
    }

    virtual void get_output_var_shape(const CollectiveComm* opr,
                                      const TensorShapeArray& ishp,
                                      TensorShapeArray& oshp) = 0;

    virtual void exec(CollectiveComm* opr) = 0;

    //! gradient mode
    virtual Mode grad_mode() = 0;

    static ModeTrait& from_mode(Mode mode);
};
class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
    void add_output_var(CollectiveComm* opr,
                        const CompNode::UnorderedSet&) override {
        add_output_var_all2all(opr);
    }

    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        auto soshp = ishp[0];
        soshp[0] *= opr->nr_devices();
        for (auto& i : oshp)
            i = soshp;
    }

    void exec(CollectiveComm* opr) override {
        auto ivar = opr->input(0), ovar = opr->output(0);
        auto &&iv = ivar->dev_tensor(), &&ov = ovar->dev_tensor();
        mgb_assert(ivar->comp_node().mem_node() ==
                   ovar->comp_node().mem_node());

        auto status = opr->m_megray_comm->all_gather(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
                iv.shape().total_nr_elems(),
                get_megray_dtype(iv.dtype()),
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_gather failed");
    }

    Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; }
};
class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
    void add_output_var(CollectiveComm* opr,
                        const CompNode::UnorderedSet&) override {
        add_output_var_all2all(opr);
    }

    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        auto soshp = ishp[0];
        mgb_throw_if(soshp.shape[0] % opr->nr_devices(), GraphError,
                     "input size can not be divided equally: "
                     "size=%zu parts=%zu",
                     soshp[0], ishp.size());
        soshp[0] /= opr->nr_devices();
        for (auto& i : oshp)
            i = soshp;
    }

    void exec(CollectiveComm* opr) override {
        auto ivar = opr->input(0), ovar = opr->output(0);
        auto &&iv = ivar->dev_tensor(), &&ov = ovar->dev_tensor();
        mgb_assert(ivar->comp_node().mem_node() ==
                   ovar->comp_node().mem_node());
        size_t buff_len = ov.shape().total_nr_elems();  // * opr->m_nr_devices;

        auto status = opr->m_megray_comm->reduce_scatter(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(), buff_len,
                get_megray_dtype(ov.dtype()), MegRay::ReduceOp::MEGRAY_SUM,
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay reduce_scatter failed");
    }

    Mode grad_mode() override { return Mode::ALL_GATHER; }
};
/* ================= ModeTrait impls ================= */

class CollectiveComm::ModeTrait::ReducedBasedTrait {
protected:
    ~ReducedBasedTrait() = default;

    virtual MegRay::ReduceOp op() const = 0;
};

class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
                                                 public ModeTrait {
    void add_output_var(CollectiveComm* opr,
                        const CompNode::UnorderedSet&) override {
        add_output_var_all2all(opr);
    }

    void get_output_var_shape(const CollectiveComm*,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        oshp = ishp;
    }

    void exec(CollectiveComm* opr) override {
        auto ivar = opr->input(0), ovar = opr->output(0);
        auto &&iv = ivar->dev_tensor(), &&ov = ovar->dev_tensor();
        mgb_assert(ivar->comp_node().mem_node() ==
                   ovar->comp_node().mem_node());

        auto status = opr->m_megray_comm->all_reduce(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
                iv.shape().total_nr_elems(),
                get_megray_dtype(iv.dtype()), op(),
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_reduce failed");
    }

    Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; }
};

class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; }
};

class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; }
};

class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; }
};
class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
                                              public ModeTrait {
    void add_output_var(CollectiveComm* opr,
                        const CompNode::UnorderedSet& inp_cn) override {
        add_output_var_all2all(opr);
    }

    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        MGB_MARK_USED_VAR(opr);
        chk_shape_equal(ishp);
        if (opr->is_root()) {
            oshp[0] = ishp[0];
        } else {
            oshp[0] = TensorShape{1};
        }
    }

    void exec(CollectiveComm* opr) override {
        auto ovar = opr->output(0);
        auto&& iv = opr->input(0)->dev_tensor();
        void* recvbuf = nullptr;
        if (opr->is_root()) {
            recvbuf = ovar->dev_tensor().raw_ptr();
        }

        auto status = opr->m_megray_comm->reduce(
                (void*)iv.raw_ptr(), recvbuf,
                iv.shape().total_nr_elems(),
                get_megray_dtype(iv.dtype()), op(),
                opr->m_root, opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay reduce failed");
    }
};

class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; }

    Mode grad_mode() override { return Mode::BROADCAST; }
};
class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
    void add_output_var(CollectiveComm* opr,
                        const CompNode::UnorderedSet&) override {
        if (opr->input().size() > 0) {
            add_output_var_all2all(opr);
            return;
        }

        const auto& cns = opr->config().comp_node();
        mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu",
                   cns.size());
        auto pname = get_param_name(opr->param());
        opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))
                ->comp_node(cns[0]);
    }

    void get_output_var_shape(const CollectiveComm*,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        mgb_assert(false, "BROADCAST should not use get_output_var_shape");
    }

    void exec(CollectiveComm* opr) override {
        auto ovar = opr->output(0);
        auto&& ov = ovar->dev_tensor();
        mgb_assert(opr->input().size() < 2,
                   "input size of BROADCAST must be either 0 or 1");

        void* buff;
        DType datatype;
        size_t length;
        if (opr->is_root()) {
            auto ivar = opr->input(0);
            auto&& iv = ivar->dev_tensor();
            datatype = iv.dtype();
            buff = (void*)iv.raw_ptr();
            length = iv.shape().total_nr_elems();
        } else {
            buff = NULL;
            datatype = ov.dtype();
            length = ov.shape().total_nr_elems();
        }

        auto status = opr->m_megray_comm->broadcast(
                buff, (void*)ov.raw_ptr(), length,
                get_megray_dtype(datatype), opr->m_root,
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay broadcast failed");
    }

    Mode grad_mode() override { return Mode::REDUCE_SUM; }
};
CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
    switch (mode) {
#define c(_m)          \
    case Mode::_m: {   \
        static _m ins; \
        return ins;    \
    }
        FOREACH_MODE(c)
        default:
            mgb_assert(0);
#undef c
    }
}
/* ================= CollectiveComm ================= */

CollectiveComm::CollectiveComm(
        VarNodeArray inputs, ComputingGraph* const graph,
        const std::string& key, const size_t nr_devices, const uint32_t rank,
        const uint32_t root, std::shared_ptr<GroupClient> group_client,
        const Param& param, const DType& dtype, const std::string& backend,
        const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
        const OperatorNodeConfig& config,
        const std::shared_ptr<DTypeScalar>& disable)
        : Super{graph, config, get_param_name(param), inputs},
          m_param{param},
          m_dtype(dtype),
          m_backend(backend),
          m_group_client{std::move(group_client)},
          m_nr_devices(nr_devices),
          m_rank(rank),
          m_key(key),
          m_root(root),
          m_dev_buffers(dev_buffer_arr),
          m_disable{disable} {
    for (auto i : inputs) {
        mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA,
                   "CollectiveComm currently only supports CUDA");
    }
    for (auto i : config.comp_node()) {
        mgb_assert(i.device_type() == CompNode::DeviceType::CUDA,
                   "CollectiveComm currently only supports CUDA");
    }

    CompNode::UnorderedSet inp_cn;
    ThinHashSet<int> inp_dev;
    for (auto i : inputs) {
        add_input({i});
        inp_cn.insert(i->comp_node());
        inp_dev.insert(
                CompNodeEnv::from_comp_node(i->comp_node()).cuda_env().device);
    }
    mgb_assert(
            inp_dev.size() == inputs.size(),
            "CollectiveComm inputs should not contain duplicated input device");

    ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);
    m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));

    const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
    if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
        m_debug_mode = true;
    }

    add_equivalence_component<PODHash<Param>>(&m_param);
    add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
    m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();
    add_equivalence_component<PODHash<size_t>>(&m_hash);
}
SymbolVarArray CollectiveComm::make(
        const SymbolVarArray& inputs, ComputingGraph* const graph,
        const std::string& key, const size_t nr_devices, const uint32_t rank,
        const uint32_t root, std::shared_ptr<GroupClient> group_client,
        const Param& param, const DType& dtype, const std::string& backend,
        const OperatorNodeConfig& config,
        const std::shared_ptr<DTypeScalar>& disable) {
    SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
                                                                nullptr);
    return make(inputs, graph, key, nr_devices, rank, root, group_client,
                dev_buffer_arr, param, dtype, backend, config);
}

SymbolVarArray CollectiveComm::make(
        const SymbolVarArray& inputs, ComputingGraph* const graph,
        const std::string& key, const size_t nr_devices, const uint32_t rank,
        const uint32_t root, std::shared_ptr<GroupClient> group_client,
        const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
        const Param& param, const DType& dtype, const std::string& backend,
        const OperatorNodeConfig& config,
        const std::shared_ptr<DTypeScalar>& disable) {
    auto inpvars = cg::to_var_node_array(inputs);
    auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
            inpvars, graph, key, nr_devices, rank, root,
            std::move(group_client), param, dtype, backend, dev_buffer_arr,
            config, disable));
    mgb_assert(!opr->output().empty());
    return cg::to_symbol_var_array(opr->output());
}
void CollectiveComm::opr_register() {
    if (m_init)
        return;

    auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
                              .cuda_env();

    auto hash = m_group_client->opr_register(
            m_key, m_nr_devices, m_rank,
            reinterpret_cast<uintptr_t>(cuda_env.stream));

    auto megray_comm_builder =
            owner_graph()
                    ->options()
                    .user_data
                    .get_user_data_or_create<MegRayCommunicatorBuilder>();

    m_megray_comm = megray_comm_builder->get_megray_comm(
            hash, m_key, m_nr_devices, m_rank,
            get_megray_backend(m_backend), m_group_client);

    m_init = true;
}
void CollectiveComm::add_input_layout_constraint() {
    // Enable shape infer *after* static infer phase. This is only used by
    // BROADCAST operation.
    m_enable_shape_infer = true;
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}

void CollectiveComm::get_output_var_shape(const TensorShapeArray& inp_shape,
                                          TensorShapeArray& out_shape) const {
    ModeTrait::from_mode(m_param.mode)
            .get_output_var_shape(const_cast<CollectiveComm*>(this),
                                  inp_shape, out_shape);
}

void CollectiveComm::init_output_comp_node() {
    mgb_assert(output().size() == 1, "exactly one output expected, got %zu",
               output().size());
    owner_graph()->seq_comp_node_optimizer().register_stream_var(
            output()[0],
            {CompNode::Stream::NCCL,
             cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
}
void CollectiveComm::init_output_mem_plan(bool dynamic) {
    for (size_t i = 0; i < output().size(); i++) {
        if (m_dev_buffers[i]) {
            output(i)->init_mem_plan(m_dev_buffers[i].get());
        } else {
            if (is_static_var_storage(output(i)) == !dynamic &&
                !output(i)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC))
                output(i)->init_mem_plan();
        }
    }
}

void CollectiveComm::mem_plan_fwd_in2out_writable() {
    if (m_param.mode == Param::Mode::ALL_REDUCE_SUM) {
        for (size_t i = 0; i < output().size(); ++i) {
            output(i)->set_fwd_in2out_writable(input(i));
        }
    }
}

cg::OperatorNodeBase::NodeProp* CollectiveComm::do_make_node_prop() const {
    auto prop = OperatorNodeBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
    return prop;
}
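// do_execute() dispatches an asynchronous runner on the output comp node: the
// runner lazily registers this operator with the group server, then runs the
// mode-specific MegRay kernel between the BeforeKernel/AfterKernel graph events.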
void CollectiveComm::do_execute(ExecEnv& env) {
    auto&& trait = ModeTrait::from_mode(m_param.mode);
    mgb_assert(owner_graph()->options().async_exec_level,
               "collective comm must be used with async dispatch");
    mgb_assert(output().size() == 1,
               "collective comm only support exactly one output");

    auto disable = m_disable->get_cast<int>();
    if (disable == 1)
        return;
    mgb_assert(disable == 0,
               "disable flag on CollectiveComm can only be 0 or 1,"
               " got %d actually.",
               disable);

    auto cn = output(0)->comp_node();
    auto runner = [this, cn, &trait] {
        opr_register();
        cn.activate();

        if (m_debug_mode) {
            mgb_log_debug("collective comm: executing %s, rank = %d, key = %s",
                          cname(), rank(), key().c_str());
        }

        owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
        trait.exec(this);
        owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);

#if CUDART_VERSION < 9000
#pragma message "legacy CUDA; use sync to avoid blocking"
        // nccl hangs occasionally without this sync()
        cn.sync();
#endif
    };
    env.dispatch_on_comp_node(cn, runner);
}
void CollectiveComm::on_output_comp_node_stream_changed() {}

VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
    auto mode = ModeTrait::from_mode(m_param.mode).grad_mode();
    SymbolVarArray og_syms;

    if (m_param.mode == Param::Mode::REDUCE_SUM) {
        for (size_t i = 0; i < output().size(); i++) {
            if (out_grads[i])
                og_syms.push_back(out_grads[i]);
        }
        mgb_assert(og_syms.size() == 1);
    } else {
        for (size_t i = 0; i < output().size(); i++) {
            if (!out_grads[i]) {
                mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM,
                           "null out grad in CollectiveComm is currently "
                           "unsupported when the forward mode is "
                           "REDUCE_SCATTER_SUM.");
                DTypeScalar dval{output(i)->dtype()};
                dval.set_retain_dtype(0);
                auto zeros =
                        SymbolVar::make_scalar(dval, *output(i)->owner_graph(),
                                               output(i)->comp_node())
                                .broadcast(SymbolVar(output(i)).symshape());
                og_syms.push_back(zeros);
            } else {
                og_syms.push_back(out_grads[i]);
            }
        }
    }

    OperatorNodeConfig::CompNodeArray cn_arr;
    if (m_param.mode == Param::Mode::REDUCE_SUM) {
        for (auto i : input()) {
            cn_arr.push_back(i->comp_node());
        }
    } else if (m_param.mode == Param::Mode::BROADCAST) {
        if (!input().empty()) {
            cn_arr.push_back(input(0)->comp_node());
        }
    }

    auto gvar = CollectiveComm::make(
            og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_rank,
            m_root, m_group_client, mode, m_dtype, m_backend,
            OperatorNodeConfig{}.comp_node_arr(cn_arr));

    if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
        for (size_t i = 0; i < input().size(); ++i) {
            gvar[i] = Elemwise::make({output(i), input(i), gvar[i]},
                                     Elemwise::Mode::COND_LEQ_MOV);
        }
    } else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) {
        for (size_t i = 0; i < input().size(); ++i) {
            gvar[i] = Elemwise::make({input(i), output(i), gvar[i]},
                                     Elemwise::Mode::COND_LEQ_MOV);
        }
    } else if (m_param.mode == Param::Mode::BROADCAST) {
        if (!input().empty()) {
            CompNode&& master_out_cn = input(0)->comp_node();
            SymbolVarArray rst;
            for (auto i : gvar) {
                if (i.node()->comp_node() == master_out_cn) {
                    mgb_assert(rst.empty());
                    rst.push_back(i);
                }
            }
            gvar = rst;
        }
    }
    return cg::to_var_node_array(gvar);
}

MGB_IMPL_OPR_GRAD(CollectiveComm) {
    return opr.grad(out_grad);
}
void CollectiveComm::init_output_dtype() {
    if (m_dtype.valid()) {
        for (size_t i = 0; i < input().size(); ++i) {
            mgb_assert(m_dtype == input(i)->dtype(),
                       "any given input's dtype should be identical to that "
                       "specified from opr's argument");
        }
        for (auto i : output()) {
            if (!i->dtype().valid())
                i->dtype(m_dtype);
        }
    } else {
        Super::init_output_dtype();
    }
}
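// Static shape inference is special-cased below: for REDUCE_SUM, non-root
// ranks get a constant {1} output shape; for BROADCAST, non-root ranks fetch
// the output shape from the group server once shape infer has been enabled.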
void CollectiveComm::init_output_static_infer_desc() {
    if (m_param.mode == Param::Mode::REDUCE_SUM) {
        using namespace cg::static_infer;
        auto&& mgr = owner_graph()->static_infer_manager();

        auto infer_shape_from_input = [](TensorShape& dest,
                                         const InpVal& inp_val) {
            dest = inp_val.val[0].shape();
            return true;
        };

        auto infer_shape_constant = [](TensorShape& dest, const InpVal&) {
            dest = TensorShape{1};
            return true;
        };

        mgb_assert(input().size() == 1);
        mgb_assert(output().size() == 1);

        if (is_root()) {
            mgr.register_shape_infer(
                    output(0), {SourceType::DEP,
                                {{input(0), DepType::SHAPE}},
                                infer_shape_from_input});
        } else {
            mgr.register_shape_infer(
                    output(0),
                    {SourceType::CONSTANT, {}, infer_shape_constant});
        }
    } else if (m_param.mode == Param::Mode::BROADCAST) {
        using namespace cg::static_infer;
        auto&& mgr = owner_graph()->static_infer_manager();

        auto infer_shape_from_input = [this](TensorShape& dest,
                                             const InpVal& inp_val) {
            if (!m_broadcast_output_shape.valid()) {
                m_broadcast_output_shape = inp_val.val[0].shape();
                m_group_client->set_output_shape(
                        m_key, m_broadcast_output_shape.val());
            }
            dest = inp_val.val[0].shape();
            return true;
        };

        auto get_shape_from_server = [this](TensorShape& dest, const InpVal&) {
            if (!m_enable_shape_infer) {
                return false;
            }
            if (!m_broadcast_output_shape.valid()) {
                m_broadcast_output_shape =
                        m_group_client->get_output_shape(m_key);
            }
            dest = m_broadcast_output_shape.val();
            return true;
        };

        mgb_assert(output().size() == 1);

        if (is_root()) {
            mgb_assert(input().size() == 1);
            mgr.register_shape_infer(
                    output(0), {SourceType::DEP,
                                {{input(0), DepType::SHAPE}},
                                infer_shape_from_input});
        } else {
            mgr.register_shape_infer(
                    output(0),
                    {SourceType::MUTABLE, {}, get_shape_from_server});
        }
    } else {
        Super::init_output_static_infer_desc();
    }
}
/* ===================== shallow copy ===================== */

namespace mgb {
namespace opr {

cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
    return opr::CollectiveComm::make(to_symbol_var_array(inputs),
                                     ctx.owner_graph(opr_, inputs), opr.key(),
                                     opr.nr_devices(), opr.rank(), opr.root(),
                                     opr.group_client(), opr.dev_buffers(),
                                     opr.param(), opr.dtype(), opr.backend(),
                                     config)[0]
            .node()
            ->owner_opr();
}

MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm);

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
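
For reference, below is a minimal sketch of how this operator might be instantiated from graph-building code, following the CollectiveComm::make() signature in the file above. Everything concrete in it (grad_var, graph, client, the key string, the device count) is an illustrative assumption, not code taken from MegEngine, and the trailing config/disable arguments are assumed to take defaults declared in the header.

    // Hedged sketch: all-reduce a gradient tensor across a 2-GPU group.
    // Assumes `grad_var` is a SymbolVar living on this rank's GPU, `graph` is
    // the owning ComputingGraph*, and `client` is a std::shared_ptr<GroupClient>
    // already connected to the group server.
    using Param = opr::CollectiveComm::Param;
    auto reduced = opr::CollectiveComm::make(
            {grad_var},                          // inputs: one var on this rank
            graph,                               // ComputingGraph*
            "allreduce:grad0",                   // key shared by all ranks
            2,                                   // nr_devices in the group
            0,                                   // rank of this process
            0,                                   // root (ignored by all-reduce)
            client,                              // GroupClient for rendezvous
            Param{Param::Mode::ALL_REDUCE_SUM},  // collective mode
            dtype::Float32(),                    // expected dtype
            "nccl")[0];                          // backend: "nccl" or "ucx"

The same call shape applies to the other modes; for REDUCE_SUM and BROADCAST the root argument selects which rank holds the full-sized input or output, as implemented by the ReduceBase and BROADCAST traits above.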

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.