
io_remote.cpp 8.2 kB

/**
 * \file src/opr-mm/impl/io_remote.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/io_remote.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/serialization/sereg.h"

using namespace mgb;
using namespace opr;
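//! Helper: fetch the raw CUDA stream bound to a var's comp node; it is used
//! below to build the MegRay context so communication runs on that stream.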
cudaStream_t get_stream(VarNode* var) {
    return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
/* ===================== RemoteSend ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
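// RemoteSend is the sender (rank 0) half of a keyed point-to-point pair.
// When it does not carry a gradient, its single output is a placeholder
// that may stay empty and whose content is never read.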
RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                       std::shared_ptr<GroupClient> group_client,
                       bool is_grad, const OperatorNodeConfig& config) :
        Super(var->owner_graph(), config, "remote_send", {var}),
        m_is_grad(is_grad) {
    m_key = key;
    m_group_client = group_client;
    add_input({var});
    auto ovar = add_output(None);
    if (!m_is_grad) {
        ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
    }
    add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                           std::shared_ptr<GroupClient> group_client,
                           bool is_grad, const OperatorNodeConfig& config) {
    return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
                                                    is_grad, config);
}
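// First execution registers the opr with the group server and builds the
// MegRay communicator (a 2-device NCCL group with this opr as rank 0);
// every execution then streams the input tensor's raw bytes to rank 1.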
void RemoteSend::scn_do_execute() {
    if (!m_init) {
        auto&& comp_node = output(0)->comp_node();
        // rank 0 for RemoteSend
        auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
                                                     comp_node.get_uid());
        m_megray_comm = MegRayCommBuilder::get_megray_comm(
                reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
        m_init = true;
    }
    mgb_assert(m_init);

    size_t data_size = 1;
    auto&& tensor = input(0)->dev_tensor();
    auto&& ishp = tensor.shape();
    for (size_t i = 0; i < ishp.ndim; i++) {
        data_size *= ishp[i];
    }
    data_size *= tensor.dtype().size();
    auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");

    if (m_is_grad) {
        auto&& dest = output(0)->dev_tensor();
        if (m_output_val.empty()) {
            m_output_val.comp_node(dest.comp_node())
                    .dtype(dest.dtype())
                    .resize({1});
            memset(m_output_val.raw_ptr(), 0, m_output_val.dtype().size());
        }
        dest.copy_from_fixlayout(m_output_val);
    }
}
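// The output carries no real data: infer a 1-element shape for the grad
// placeholder and an empty (0-element) shape otherwise.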
void RemoteSend::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto do_infer = [this](TensorShape& dest, const InpVal&) {
        if (m_is_grad) {
            dest = {1};
        } else {
            dest = {0};
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, do_infer});
}
cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
    auto prop = RemoteIOBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return prop;
}
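// The gradient of a send is a receive: the peer computes the grad and
// sends it back under the derived key "<key>:grad".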
MGB_IMPL_OPR_GRAD(RemoteSend) {
    mgb_assert(opr.is_grad());
    return RemoteRecv::make(opr.key() + ":grad",
                            *opr.owner_graph(), opr.group_client(),
                            OperatorNodeConfig{opr.comp_node()}.name(
                                    opr.name() + ":grad_recv"),
                            opr.input(0)->shape(), opr.input(0)->dtype())
            .node();
}
/* ===================== RemoteRecv ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
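// RemoteRecv is the receiver (rank 1) half of the pair. It has no data
// input, so the expected shape and dtype must be given at graph build time;
// its output memory must not be reclaimed or dynamically reallocated,
// since MegRay writes into that buffer directly.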
RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                       std::shared_ptr<GroupClient> group_client,
                       const OperatorNodeConfig& config,
                       const TensorShape& shape, DType dtype) :
        Super(&graph, config, "remote_recv", {}),
        m_shape(shape), m_dtype(dtype) {
    m_key = key;
    m_group_client = group_client;
    add_output(None)
            ->dtype(dtype)
            .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
    add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                           std::shared_ptr<GroupClient> group_client,
                           const OperatorNodeConfig& config,
                           const TensorShape& shape, DType dtype) {
    auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
            key, graph, group_client, config, shape, dtype));
    return opr->output(0);
}
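// Mirror of RemoteSend::scn_do_execute: lazily set up the communicator as
// rank 1 of the same 2-device group, then receive the bytes from rank 0
// straight into the output tensor.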
void RemoteRecv::scn_do_execute() {
    if (!m_init) {
        auto&& comp_node = output(0)->comp_node();
        // rank 1 for RemoteRecv
        auto reg_info = m_group_client->opr_register(
                m_key, 2, false, 1,
                comp_node.get_uid());
        m_megray_comm = MegRayCommBuilder::get_megray_comm(
                reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
        m_init = true;
    }
    mgb_assert(m_init);

    size_t data_size = 1;
    auto&& tensor = output(0)->dev_tensor();
    auto&& ishp = tensor.shape();
    for (size_t i = 0; i < ishp.ndim; i++) {
        data_size *= ishp[i];
    }
    data_size *= tensor.dtype().size();
    auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
}
void RemoteRecv::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto do_infer = [this](TensorShape& dest, const InpVal&) {
        dest = m_shape;
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, do_infer});
}
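// Receiving has side effects outside the graph, so the opr is impure; an
// optional single input only enforces execution order (DEV_COMP_ORDER)
// and is never read as data.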
cg::OperatorNodeBase::NodeProp* RemoteRecv::do_make_node_prop() const {
    auto prop = RemoteIOBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
    if (input().size() == 1)
        prop->reset_dep_type(input(), {NodeProp::DepType::DEV_COMP_ORDER});
    return prop;
}
/* ===================== shallow copy ===================== */

namespace mgb {
namespace opr {
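// Shallow-copy hooks let graph transformations and serialization rebuild an
// equivalent opr from an existing one with new inputs and config.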
cg::OperatorNodeBase* opr_shallow_copy_remote_send(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.size() == 1);
    auto&& opr = opr_.cast_final_safe<RemoteSend>();
    return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
                            opr.is_grad(), config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(RemoteSend, opr_shallow_copy_remote_send);

cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<RemoteRecv>();
    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                            opr.group_client(), config, inputs[0]->shape(),
                            inputs[0]->dtype())
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);

} // namespace opr
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
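For context, a minimal sketch of how a matching pair might be wired up, one half per process. Only RemoteSend::make and RemoteRecv::make come from this file; make_group_client(), x, recv_cn, batch and dim are hypothetical stand-ins for a concrete GroupClient implementation, the sender-side SymbolVar, the receiver's comp node, and the transferred shape.

    // Rank 0 (sender process): forward tensor `x` under the shared key.
    // The result is a dummy var that keeps the opr alive in the graph.
    auto client = make_group_client();  // hypothetical concrete GroupClient
    auto sent = opr::RemoteSend::make("p2p:x", x, client,
                                      /*is_grad=*/false, {});

    // Rank 1 (receiver process): shape and dtype must match the sender's.
    auto graph = cg::ComputingGraph::make();
    auto received = opr::RemoteRecv::make(
            "p2p:x", *graph, client, OperatorNodeConfig{recv_cn},
            TensorShape{batch, dim}, dtype::Float32());

Each process then compiles and runs its own graph; the transfer takes place when both ranks execute their oprs registered under the same key.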
