
io_remote.h 3.2 kB

/**
 * \file src/opr-mm/include/megbrain/opr/io_remote.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/graph.h"
#include "megbrain/opr/internal/mixin_base.h"
#include "megbrain/opr/group_manager.h"

#include "megray.h"

namespace mgb {
namespace opr {

/*!
 * \brief base class for remote I/O nodes
 */
MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
public:
    const std::string& key() const { return m_key; }

    std::shared_ptr<GroupClient> group_client() const {
        return m_group_client;
    }

protected:
    std::string m_key;
    std::shared_ptr<GroupClient> m_group_client;
    std::shared_ptr<MegRay::Communicator> m_megray_comm;
    std::shared_ptr<MegRay::Context> m_megray_ctx;
    bool m_init = false;
    using Super::Super;
};

/*!
 * \brief send a variable to a remote address; a virtual output is produced
 *      to express the dependency
 */
MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
public:
    RemoteSend(const std::string& key, VarNode* var,
               std::shared_ptr<GroupClient> group_client,
               bool is_grad, const OperatorNodeConfig& config);

    static SymbolVar make(
            const std::string& key, SymbolVar var,
            std::shared_ptr<GroupClient> group_client,
            bool is_grad, const OperatorNodeConfig& config = {});

    bool is_grad() const { return m_is_grad; }

private:
    HostTensorND m_output_val;
    bool m_is_grad;

    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
};

/*!
 * \brief receive a variable from a remote address; the target computing node
 *      of the var must be specified in config
 */
MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
public:
    RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
               std::shared_ptr<GroupClient> group_client,
               const OperatorNodeConfig& config, const TensorShape& shape,
               DType dtype);

    static SymbolVar make(
            const std::string& key, cg::ComputingGraph& graph,
            std::shared_ptr<GroupClient> group_client,
            const OperatorNodeConfig& config, const TensorShape& shape,
            DType dtype);

private:
    const TensorShape m_shape;
    const DType m_dtype;
    const CompNode m_comp_node;
    DeviceTensorND m_dev_buffer;

    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
};

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
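Below is a minimal usage sketch (not part of io_remote.h) showing how the two factory functions declared above could be wired together. It assumes an existing cg::ComputingGraph, an input SymbolVar on the sender side, and a concrete GroupClient implementation; the key "grad:0", the function names send_tensor/recv_tensor, and the OperatorNodeConfig::comp_node() setter used to pin the receiver's computing node are illustrative assumptions, not taken from this file.

#include "megbrain/opr/io_remote.h"

using namespace mgb;

// Sender side: ship `x` to whichever peer registered the same key.
// The returned SymbolVar is the virtual output mentioned in the class
// comment; adding it to the output spec keeps the operator alive.
SymbolVar send_tensor(SymbolVar x, std::shared_ptr<opr::GroupClient> client) {
    return opr::RemoteSend::make(
            "grad:0",   // key shared by sender and receiver (placeholder)
            x, client,
            false);     // is_grad flag from the factory signature
}

// Receiver side: materialize the tensor sent under the same key.
// shape/dtype must match what the sender transmits, and the target
// computing node is passed through the OperatorNodeConfig.
SymbolVar recv_tensor(cg::ComputingGraph& graph,
                      std::shared_ptr<opr::GroupClient> client,
                      CompNode cn, const TensorShape& shape, DType dtype) {
    OperatorNodeConfig config;
    config.comp_node(cn);   // assumed setter; places the received var on `cn`
    return opr::RemoteRecv::make("grad:0", graph, client, config, shape, dtype);
}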

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.