
collective_comm.cpp 3.9 kB

/**
 * \file imperative/src/impl/ops/collective_comm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain_build_config.h"

#if MGB_ENABLE_OPR_MM
#include "../op_trait.h"
#include "../proxy_graph_detail.h"
#include "megbrain/opr/mm_handler.h"
#include "megbrain/utils/hash.h"
#endif  // MGB_ENABLE_OPR_MM

#include "megbrain/imperative/ops/collective_comm.h"

namespace mgb {
namespace imperative {

#if MGB_ENABLE_OPR_MM

namespace {
// Translate an imperative CollectiveComm OpDef into a graph operator node.
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& comm = def.cast_final_safe<CollectiveComm>();
    auto group_client = std::make_shared<GroupClientProxy>(
            ssprintf("%s:%d", comm.addr.data(), comm.port));
    SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr);
    auto disable = std::make_shared<DTypeScalar>();
    disable->set(0);
    cg::OperatorNodeConfig config;
    if (comm.comp_node.size() > 0) {
        config.comp_node(CompNode::load(comm.comp_node));
    }
    mgb_assert(inputs.size() == 1, "exactly one input expected");
    auto&& graph = inputs[0]->owner_graph();
    return graph->insert_opr(std::make_unique<opr::CollectiveComm>(
            inputs, graph, comm.key, comm.nr_devices, comm.is_root, comm.rank,
            comm.local_grad, group_client, comm.mode, comm.dtype, comm.backend,
            dev_buffer_arr, config, disable));
}

// Split a "host:port" server address into its host and port parts.
std::tuple<std::string, std::string> split_address(const std::string& address_and_port) {
    auto index = address_and_port.find_last_of(':');
    mgb_assert(index != std::string::npos, "missing ':' in server address");
    return {address_and_port.substr(0, index), address_and_port.substr(index + 1)};
}

// Reconstruct an imperative CollectiveComm OpDef from an existing graph operator.
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) {
    auto&& comm = node->cast_final_safe<opr::CollectiveComm>();
    auto&& group_client = comm.group_client();
    auto [addr, port] = split_address(group_client->get_addr());
    auto comp_node = node->config().get_single_comp_node().to_string_logical();
    return std::make_shared<CollectiveComm>(
            comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(),
            comm.local_grad(), addr, std::stoi(port), comm.param().mode,
            comm.dtype(), comm.backend(), comp_node);
}

OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
        .apply_on_var_node(apply_on_var_node)
        .make_from_op_node(make_from_op_node)
        .fallback();

}  // anonymous namespace
bool CollectiveComm::is_same_st(const Hashable& another) const {
    auto* comm_opr = another.try_cast_final<CollectiveComm>();
    if (!comm_opr) {
        return false;
    }
    return as_tuple() == comm_opr->as_tuple();
}

size_t CollectiveComm::hash() const {
    XXHash xxhash{};
    auto append = [&xxhash](auto field) {
        auto hash_val = HashTrait<decltype(field)>::eval(field);
        xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
    };
    append(key);
    append(nr_devices);
    append(rank);
    append(is_root);
    append(local_grad);
    append(addr);
    append(port);
    append(mode);
    append(backend);
    append(comp_node);
    return xxhash.digest();
}
#else

bool CollectiveComm::is_same_st(const Hashable& another) const {
    return OpDef::is_same_st(another);
}

size_t CollectiveComm::hash() const {
    return OpDef::hash();
}

#endif  // MGB_ENABLE_OPR_MM

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and a working driver installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
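Since the same package serves both CPU-only and GPU machines, a quick way to confirm that the GPU path is actually usable is to query CUDA availability from Python. The following is a minimal sketch, not part of this file; it assumes the installed MegEngine version exposes megengine.is_cuda_available():

    # check_gpu.py -- minimal sketch; assumes megengine.is_cuda_available() is present
    import megengine

    if megengine.is_cuda_available():
        print("CUDA device detected; GPU code paths can be used.")
    else:
        print("No usable CUDA device; programs will run on CPU.")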