You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

proxy_graph.h 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. /**
  2. * \file imperative/src/impl/proxy_graph.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megbrain/imperative.h"
  13. #include "megbrain/graph/cg.h"
  14. #include "megbrain/graph/grad_impl.h"
  15. #include "megbrain/comp_node.h"
  16. #include "megbrain/imperative/ops/backward_graph.h"
  17. namespace mgb {
  18. namespace imperative {
/*!
 * \brief Bridge between the imperative runtime and the megbrain computing
 * graph: lets an imperative OpDef be shape-inferred and executed through a
 * proxy graph operator.
 *
 * Non-copyable. Obtain an instance via get_default_graph().
 * NOTE(review): the only static state is thread_local (tm_async_error), which
 * suggests one default graph per thread — confirm in the .cpp.
 */
class ProxyGraph : public NonCopyableObj {
public:
    //! Returns the default ProxyGraph instance (defined in the .cpp).
    static ProxyGraph* get_default_graph();

    //! Moves out the pending asynchronous error for the calling thread.
    //! The thread-local slot is left empty (moved-from unique_ptr is null),
    //! so each recorded error is reported at most once.
    static std::unique_ptr<MegBrainError> get_async_error() {
        return std::move(tm_async_error);
    }

    /********************** Physical Tensor API **********************/

    //! Infers the output tensor descriptors produced by running \p opdef
    //! on the given concrete input tensors.
    SmallVector<LogicalTensorDesc> infer_output_attrs(
            const OpDef& opdef,
            const SmallVector<Tensor*>& inputs);

    //! Executes \p opdef on \p inputs, writing results into the
    //! pre-allocated \p outputs and \p workspace tensors in place.
    void invoke_op(
            const OpDef& opdef,
            const SmallVector<Tensor*>& inputs,
            const SmallVector<Tensor*>& outputs,
            const SmallVector<Tensor*>& workspace);

    //! Builds the backward graph for \p opdef.
    //! \param input_requires_grad per-input flag: gradient wanted for it.
    //! \param output_has_grad per-output flag: an incoming gradient exists.
    BackwardGraphResult make_backward_graph(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& input_descs,
            const SmallVector<bool>& input_requires_grad,
            const SmallVector<bool>& output_has_grad);

    //! Infers memory descriptors for \p def given input tensors and their
    //! memory descriptors; returns a pair of descriptor lists (presumably
    //! {outputs, workspaces} — confirm against the .cpp).
    std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
            const OpDef& def,
            const SmallVector<Tensor*>& inputs_tensors,
            const SmallVector<MemoryDesc>& inputs_mems);

    /********************** Logical Tensor API **********************/

    //! Number of outputs the proxy operator for \p opdef would have, given
    //! only logical (shape/dtype/device) input descriptors.
    size_t get_opr_output_size(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);

    //! Fallible output inference: the returned bool indicates whether the
    //! inferred descriptors are complete/validated (see callers for the
    //! exact contract — not determinable from this header).
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);

private:
    ProxyGraph();

    // Internal implementation pieces; forward-declared here, defined in
    // the corresponding .cpp.
    class ProxyGraphImpl;
    class ExecEnv;
    class StaticInferManager;
    class SeqCompNodeOptimizer;
    class InputPlaceholder;
    struct ProxyGraphInst;
    struct GradGraph;
    class CurOprGuard;

    //! Resets internal graph state.
    void reset();

    /********************** Physical Tensor Helper **********************/

    //! Clears per-invocation state (counterpart of the setup helpers below).
    void cleanup();

    //! Binds the pre-allocated output/workspace tensors to the current
    //! proxy operator before execution.
    void init_output_tensor(
            const SmallVector<Tensor*>& outputs,
            const SmallVector<Tensor*>& workspace);

    //! Obtains (builds or reuses) the proxy graph operator for \p opdef
    //! from concrete input tensors.
    cg::OperatorNodeBase* get_proxy_opr(
            const OpDef& opdef,
            const SmallVector<Tensor*>& inputs);

    /********************** Logical Tensor Helper **********************/

    //! Overload of the above taking only logical input descriptors; inputs
    //! are represented by placeholder vars (see make_input_place_holders).
    cg::OperatorNodeBase* get_proxy_opr(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& inputs);

    //! Creates placeholder input vars (InputPlaceholder) matching the given
    //! logical descriptors.
    cg::VarNodeArray make_input_place_holders(
            const SmallVector<LogicalTensorDesc>& inputs);

    /********************** Common Helper **********************/

    //! Runs static shape inference on the current operator (m_cur_opr);
    //! \p sync_value presumably also forces value inference — confirm.
    bool do_shape_infer(bool sync_value);

    //! Wraps \p opr's output as a Tensor. \p share selects sharing the
    //! underlying storage (true, default) versus taking a copy.
    TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true);

    //! Operator currently being set up / executed; null between operations.
    cg::OperatorNodeBase* m_cur_opr = nullptr;
    std::unique_ptr<ProxyGraphImpl> m_graph;
    //! NOTE(review): likely an op-count threshold after which the graph is
    //! rebuilt/reset — confirm usage in the .cpp.
    size_t m_max_op_cnt = 100;
    std::unique_ptr<ExecEnv> m_env;
    std::unique_ptr<StaticInferManager> m_static_infer_manager;
    std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;

    //! Per-thread slot holding the latest asynchronous execution error;
    //! consumed (moved out) by get_async_error().
    static thread_local std::unique_ptr<MegBrainError> tm_async_error;
};
  86. } // namespace imperative
  87. } // namespace mgb
  88. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台