
proxy_graph.cpp
/**
 * \file imperative/src/impl/proxy_graph.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./blob_manager_impl.h"
#include "./proxy_graph.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/backward_graph.h"

#if __cplusplus >= 201703L
#include <optional>
#endif

namespace mgb {
namespace imperative {

using cg::OperatorNodeBase;

template <bool p, typename T, typename F>
constexpr auto&& select(T&& t, F&& f) {
    if constexpr (p) {
        return std::forward<T>(t);
    } else {
        return std::forward<F>(f);
    }
}

MGB_DEFINE_OPR_CLASS(
        ProxyGraph::InputPlaceholder,
        cg::OperatorNodeBase) // {
    void on_output_comp_node_stream_changed() override {
        mgb_assert(0);
    }
    // TODO: consider implementing the following initialization methods,
    // so InputPlaceholder can be initialized correctly during
    // operator insertion
    void init_output_comp_node() override {
    }
    void init_output_format() override {
    }
    void init_output_dtype() override {
    }
    void init_output_static_infer_desc() override {
    }
    void init_output_mem_plan(bool dynamic) override {
        MGB_MARK_USED_VAR(dynamic);
        mgb_assert(0);
    }
    void do_execute(ExecEnv& env) override {
        mgb_assert(0);
    }

public:
    Tensor* m_tensor;

    InputPlaceholder(ComputingGraph& graph, Tensor* tensor = nullptr,
                     const DeviceTensorND& static_infer_value = {})
            : Super(&graph, {}, "device_value", {}), m_tensor(tensor),
              m_static_infer_value(static_infer_value) {
        mgb_assert(m_static_infer_value.empty() ||
                   m_static_infer_value.comp_node() == CompNode::default_cpu());
        add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
        // never dedup
        add_equivalence_component<ScalarHash<void*>>(this);
    }

    static SymbolVar make(ComputingGraph& graph, Tensor& tensor) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, &tensor));
        auto var = opr->output(0);
        auto&& dev_tensor = tensor.dev_tensor();
        var->m_comp_node = dev_tensor.comp_node();
        var->m_shape = dev_tensor.shape();
        if (dev_tensor.empty()) {
            auto layout = dev_tensor.layout();
            layout.init_contiguous_stride();
            dev_tensor.reset(dev_tensor.storage(), layout);
        }
        var->m_dev_tensor = dev_tensor;
        var->m_mem_plan.reset_from_owner_var().chunk()
                .mem_alloc_status.set_from_owner_var();
        return var;
    }

    static SymbolVar make(ComputingGraph& graph, const LogicalTensorDesc& desc) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, nullptr, desc.value));
        auto var = opr->output(0);
        var->m_comp_node = desc.comp_node;
        var->m_shape = desc.layout;
        var->m_dev_tensor.reset({}, TensorLayout(desc.layout.dtype));
        return var;
    }

    const DeviceTensorND* get_static_infer_value(bool may_sync) {
        if (!m_static_infer_value.empty()) {
            return &m_static_infer_value;
        }
        if (m_tensor && (may_sync || m_tensor->try_get_value())) {
            auto&& hv = m_tensor->get_value();
            mgb_assert(!hv.empty());
            m_static_infer_value = hv.proxy_to_default_cpu();
            // steal ownership from shared_ptr
            using SP = std::shared_ptr<dt_byte>;
            auto& sp = const_cast<SP&>(m_static_infer_value.storage().raw_storage());
            static auto dummy = std::make_shared<dt_byte>();
            sp = SP(dummy, sp.get());
            return &m_static_infer_value;
        }
        return nullptr;
    }

private:
    DeviceTensorND m_static_infer_value;
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(
        ProxyGraph::InputPlaceholder);
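
// ExecEnv for the proxy graph: every dispatched task is executed immediately
// on the calling thread, and ExecutionMask (conditional execution) is rejected.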
class ProxyGraph::ExecEnv final : public cg::GraphExecutable::ExecEnv {
public:
    void dispatch_on_comp_node(CompNode, Task&& task) override {
        task();
    }
    void dispatch_on_comp_node_with_mask(CompNode, Task&& task,
                                         cg::ExecutionMask* mask) override {
        mgb_throw_if(mask, GraphError,
                     "ExecutionMask not supported in imperative mode");
        task();
    }
    void pause_exec() override {}
    void resume_exec() override {}
};
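
// StaticInferManager specialized for the proxy graph: it only tracks the
// operator currently being handled (owner->m_cur_opr) and performs shape/value
// inference on demand, reading input shapes and values directly from
// InputPlaceholder vars.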
class ProxyGraph::StaticInferManager : public cg::static_infer::StaticInferManager {
public:
    using Tag = cg::static_infer::Tag;
    using ShapeInferDesc = cg::static_infer::ShapeInferDesc;
    using ValueInferDesc = cg::static_infer::ValueInferDesc;
    using InferType = cg::static_infer::InferType;
    using DepVal = cg::static_infer::DepVal;
    using DepElement = cg::static_infer::DepElement;
    using DepType = cg::static_infer::DepType;
    using InpElement = cg::static_infer::InpElement;

    struct Result {
        TensorShape shape;
        DeviceTensorND value;
    };

    ProxyGraph* owner;
    cg::OperatorNodeBase* cur_opr = nullptr;
    std::vector<std::optional<ShapeInferDesc>> shape_descs;
    std::vector<std::optional<ValueInferDesc>> value_descs;
    std::vector<Result> inferred_outputs;

    StaticInferManager(ProxyGraph* owner_) : owner(owner_) {}

    size_t locate_output(VarNode* var) {
        mgb_assert(cur_opr);
        auto&& output_vars = cur_opr->output();
        mgb_assert(shape_descs.size() == output_vars.size());
        auto&& it = std::find(output_vars.begin(), output_vars.end(), var);
        mgb_assert(it != output_vars.end());
        return it - output_vars.begin();
    }

    void register_shape_infer(Tag dest, const ShapeInferDesc& desc) override {
        auto i = locate_output(dest);
        mgb_assert(!shape_descs[i]);
        shape_descs[i].emplace(desc);
    }

    void register_value_infer(Tag dest, const ValueInferDesc& desc) override {
        auto i = locate_output(dest);
        mgb_assert(!value_descs[i]);
        value_descs[i].emplace(desc);
    }

    InferType get_infer_type(Tag var) override {
        // may be called during get_proxy_opr or make_backward_graph;
        // don't let the opr apply any immediate optimization
        return {InferType::MISSING_INP, InferType::MISSING_INP};
        if (auto opr = var->owner_opr()->try_cast_final<InputPlaceholder>()) {
            return {var->shape().ndim ? InferType::CONST : InferType::MISSING_INP,
                    opr->m_tensor ? InferType::CONST : InferType::MISSING_INP};
        }
        if (cur_opr) {
            auto&& outputs = cur_opr->output();
            auto&& it = std::find(outputs.begin(), outputs.end(), var);
            if (it != outputs.end()) {
                return {infer_shape_fallible(var) ? InferType::CONST : InferType::MISSING_INP,
                        // value inference could be expensive
                        InferType::MISSING_INP};
            }
        }
        return {InferType::MISSING_INP, InferType::MISSING_INP};
    }

    void update() {
        if (cur_opr != owner->m_cur_opr) {
            clear();
            cur_opr = owner->m_cur_opr;
            if (cur_opr) {
                auto nout = cur_opr->output().size();
                shape_descs.resize(nout);
                value_descs.resize(nout);
                inferred_outputs.resize(nout);
                cur_opr->init_output_static_infer_desc();
            }
        }
    }

    void clear() {
        cur_opr = nullptr;
        shape_descs.clear();
        value_descs.clear();
        inferred_outputs.clear();
    }

    template <bool is_shape>
    auto do_infer(Tag dest, bool may_sync)
            -> const std::conditional_t<is_shape, TensorShape, DeviceTensorND>* {
        // Some infer_funcs do not use the InpVal passed to them but call
        // infer_* on their inputs instead, so dest could be an input.
        // It is also possible that an opr calls infer_* on its inputs before
        // it is inserted.
        if (auto opr = dest->owner_opr()->try_cast_final<InputPlaceholder>()) {
            if constexpr (is_shape) {
                auto* shp = &dest->shape();
                return shp->ndim ? shp : nullptr;
            } else {
                return opr->get_static_infer_value(may_sync);
            }
        }

        mgb_assert(cur_opr);
        mgb_assert(cur_opr->output().size() == shape_descs.size());

        // dest must be an output now
        auto i = locate_output(dest);
        auto& result = inferred_outputs[i];
        auto& desc = select<is_shape>(shape_descs[i], value_descs[i]);

        // return if there is no need to call infer_func
        if constexpr (is_shape) {
            if (result.shape.ndim != 0) {
                return &result.shape;
            }
        } else {
            if (!result.value.empty()) {
                return &result.value;
            }
        }
        if (!desc) {
            return nullptr;
        }

        // fill args for infer_func
        cg::static_infer::InpVal args{1};
        args.val.reserve(desc->deps.size());
        auto push_shape = [&args](const TensorShape* shape) {
            args.val.emplace_back();
            args.val.back().m_shape = shape;
        };
        auto push_value = [&args](const DeviceTensorND* value) {
            args.val.emplace_back();
            args.val.back().m_value = value;
        };

        for (auto&& dep : desc->deps) {
            if (auto opr = dep.dest->owner_opr()->template try_cast_final<InputPlaceholder>()) {
                if (dep.type == DepType::SHAPE) {
                    if (dep.dest->shape().ndim) {
                        push_shape(&dep.dest->shape());
                    } else {
                        return nullptr;
                    }
                } else {
                    if (auto* p = opr->get_static_infer_value(may_sync)) {
                        push_value(p);
                    } else {
                        return nullptr;
                    }
                }
                continue;
            }
            // dep must be an output
            if (dep.type == DepType::SHAPE) {
                if (auto* p = do_infer<true>(dep.dest, may_sync)) {
                    push_shape(p);
                } else {
                    return nullptr;
                }
            } else {
                if (auto* p = do_infer<false>(dep.dest, may_sync)) {
                    push_value(p);
                } else {
                    return nullptr;
                }
            }
        }

        // call infer_func
        if constexpr (is_shape) {
            if (!desc->infer_func(result.shape, args)) {
                mgb_log_warn("something is missing for shape inference of %s",
                             cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.shape;
        } else {
            if (!desc->infer_func(result.value, args)) {
                mgb_log_warn("something is missing for value inference of %s",
                             cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.value;
        }
    }

    const TensorShape& infer_shape(Tag var) override {
        auto* p = do_infer<true>(var, true);
        mgb_assert(p, "failed to infer shape for %s", var->name().c_str());
        return *p;
    }
    const TensorShape* infer_shape_fallible(Tag var) override {
        return do_infer<true>(var, false);
    }
    const DeviceTensorND& infer_value(Tag var) override {
        auto* p = do_infer<false>(var, true);
        mgb_assert(p, "failed to infer value for %s", var->name().c_str());
        return *p;
    }
    const DeviceTensorND* infer_value_fallible(Tag var) override {
        return do_infer<false>(var, false);
    }

    DepVal get_rt_static_source_deps(const DepElement&) override { mgb_assert(0); }
};

class ProxyGraph::SeqCompNodeOptimizer : public cg::SeqCompNodeOptimizer {
    void register_stream_var(VarNode*, StreamPropType) override {}
    void register_propagate_function(VarNode*, PropFunction) override {}
    StreamPropType stream_prop_type(VarNode*) override { mgb_assert(0); }
};
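
// A stripped-down ComputingGraph used only for operator insertion and static
// inference; compilation, memory planning and other sequence-level entry
// points are unreachable and simply assert.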
class ProxyGraph::ProxyGraphImpl : public cg::ComputingGraph {
    static std::atomic<size_t> m_node_id;
    ProxyGraph* m_owner;
    MemPool<VarNode> m_var_node_pool;
    std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;
    std::mutex m_opr_refkeeper_mtx;
    CompNode::UnorderedSet m_used_comp_node;
    VarReceiverInfo m_var_receiver_info;

public:
    ~ProxyGraphImpl() {
        mgb_assert(!m_owner->m_cur_opr);
        if (is_finalized()) return;
        for (auto&& i : m_used_comp_node) {
            if (i.device_type() == CompNode::DeviceType::CUDA) continue;
            if (i.device_type() == CompNode::DeviceType::ROCM) continue;
            i.sync();
        }
    }

    ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) {
        options().imperative_proxy_graph = true;
        options().no_force_inplace = true;
        options().log_level = 0;
        m_var_receiver_info.dev_value = 1;
        m_var_receiver_info.allow_empty_value = 1;
    }

    static std::unique_ptr<ProxyGraphImpl> make(ProxyGraph* owner) {
        return std::make_unique<ProxyGraphImpl>(owner);
    }

    void add_used_comp_node(CompNode cn) {
        m_used_comp_node.insert(cn);
    }

    bool invalid() const {
        return is_finalized() || nr_oprs_in_graph() > m_owner->m_max_op_cnt;
    }

    size_t next_node_id() override {
        return m_node_id.fetch_add(1);
    }

    void* alloc_varnode_storage() override {
        return m_var_node_pool.alloc_raw();
    }

    void free_varnode_storage(void* ptr) override {
        m_var_node_pool.free_raw(ptr);
    }

    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
        mgb_assert(!is_finalized());
        auto opr = opr_uniqp.get();
        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
        }
        return opr;
    }

    cg::static_infer::StaticInferManager& static_infer_manager() override {
        return *m_owner->m_static_infer_manager;
    }

    cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return *m_owner->m_seq_comp_node_optimizer;
    }

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_opr_refkeeper_mtx);
        mgb_assert(!m_owner->m_cur_opr);
        // finalize would do sync first
        m_opr_refkeeper.clear();
        return {};
    }

    const VarReceiverInfo& var_receiver_in_current_comp_seq(
            const VarNode* var) const override {
        return m_var_receiver_info;
    }

    size_t nr_oprs_in_graph() const override { return m_opr_refkeeper.size(); }

    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
        if (!ProxyGraph::tm_async_error) {
            std::swap(async_exc, tm_async_error);
        }
    }

    std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec& out_spec) override { mgb_assert(0); }
    SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override { mgb_assert(0); }
    cg::AsyncExecutable* current_comp_seq() override { mgb_assert(0); }
    std::string get_mem_allocation_info() const override { mgb_assert(0); }
    VarNode* find_var_by_id(size_t id) const override { mgb_assert(0); }
    void share_device_memory_with(ComputingGraph& other) override { mgb_assert(0); }
    void set_device_memory_allocator(
            std::shared_ptr<cg::DeviceMemoryAllocator> allocator) override { mgb_assert(0); }
    size_t get_device_memory_size(CompNode cn) override { mgb_assert(0); }
    size_t clear_device_memory() override { mgb_assert(0); }
    void set_as_subgraph(ComputingGraph& par_graph) override { mgb_assert(0); }
};

std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;

ProxyGraph::ProxyGraph() :
        m_graph(ProxyGraphImpl::make(this)),
        m_env{new ExecEnv},
        m_static_infer_manager(new StaticInferManager(this)),
        m_seq_comp_node_optimizer(new SeqCompNodeOptimizer()) {
}

void ProxyGraph::reset() {
    mgb_assert(!m_cur_opr);
    m_graph = ProxyGraphImpl::make(this);
}

ProxyGraph* ProxyGraph::get_default_graph() {
    static thread_local ProxyGraph inst;
    if (inst.m_graph->invalid()) {
        inst.reset();
    }
    return &inst;
}

class ProxyGraph::CurOprGuard {
public:
    CurOprGuard(ProxyGraph* owner, OperatorNodeBase* opr) : m_owner(owner) {
        mgb_assert(!owner->m_cur_opr);
        owner->m_cur_opr = opr;
    }
    CurOprGuard(const CurOprGuard&) = delete;
    ~CurOprGuard() {
        m_owner->cleanup();
    }
private:
    ProxyGraph* m_owner;
};

#define CUR_OPR_GUARD(opr) CurOprGuard MGB_TOKENPASTE2(__cur_opr_guard_, __LINE__)(this, opr)

/*********************** Physical Tensor Impl ***********************/

SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
        const OpDef& opdef,
        const SmallVector<Tensor*>& inputs) {
    SmallVector<LogicalTensorDesc> ret;
    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
    do_shape_infer(true);
    for (auto&& i : m_cur_opr->usable_output()) {
        mgb_assert(i->dtype().valid() && i->comp_node().valid());
        mgb_assert(i->shape().ndim || i->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    return ret;
}

void ProxyGraph::invoke_op(const OpDef& opdef,
                           const SmallVector<Tensor*>& inputs,
                           const SmallVector<Tensor*>& outputs,
                           const SmallVector<Tensor*>& workspaces) {
    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
    init_output_tensor(outputs, workspaces);
    for (auto oup : m_cur_opr->output()) {
        m_graph->add_used_comp_node(oup->comp_node());
    }
    m_cur_opr->execute(*m_env);
}

void ProxyGraph::cleanup() {
    if (m_cur_opr) {
        for (auto&& i : m_cur_opr->input()) {
            i->m_dev_tensor.storage({});
        }
        for (auto&& i : m_cur_opr->output()) {
            i->m_dev_tensor.storage({});
        }
        m_static_infer_manager->clear();
    }
    m_cur_opr = nullptr;
}

void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs,
                                    const SmallVector<Tensor*>& workspaces) {
    // get proxy opr
    auto proxy = m_cur_opr;
    do_shape_infer(true);
    size_t j = 0;
    size_t k = 0;
    for (auto&& var : proxy->output()) {
        auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
        if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            // workspace
            if (workspaces.size()) {
                mgb_assert(k < workspaces.size());
                auto&& layout = workspaces[k]->layout();
                mgb_assert(var->comp_node() == workspaces[k]->comp_node() &&
                           var->shape().eq_shape(layout) &&
                           var->dtype() == layout.dtype);
                var->m_dev_tensor = workspaces[k]->dev_tensor();
                ++k;
            } else {
                TensorLayout layout{var->shape(), var->dtype(), var->format()};
                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
                        var->comp_node(), layout);
            }
        } else {
            mgb_assert(j < outputs.size());
            auto&& tensor = outputs[j];
            auto&& layout = tensor->layout();
            mgb_assert(var->comp_node() == tensor->comp_node() &&
                       var->shape().eq_shape(layout) &&
                       var->dtype() == layout.dtype);
            var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
            ++j;
        }
        chk.mem_alloc_status.set_from_owner_var();
    }
    mgb_assert(j == outputs.size());
    mgb_assert(k == workspaces.size());

    // Memory forwarding is bypassed in megbrain when the graph option
    // imperative_proxy_graph is on; here we call mem_plan_fwd_in2out_readonly
    // to initialize some oprs' (e.g. Subtensor's) internal state.
    // TODO: implement memory forwarding
    proxy->mem_plan_fwd_in2out_readonly();
    {
        // some oprs (e.g. Reduce) rely on on_mem_status_changed to set
        // input/output tensors correctly; since we bypass var_node_mem_mgr,
        // on_mem_status_changed should be called here
        auto&& cb = proxy->get_opr_event_callback().on_mem_status_changed;
        if (cb.valid()) {
            cb.val()();
        }
    }
}

cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
        const OpDef& opdef,
        const SmallVector<Tensor*>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
    }
    auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
    mgb_assert(!opr->same_type<InputPlaceholder>());
    for (auto&& i : opr->input()) {
        mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
    }
    return opr;
}

/*********************** Logical Tensor Impl ***********************/

size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
                                       const SmallVector<LogicalTensorDesc>& inputs) {
    return get_proxy_opr(opdef, inputs)->usable_output().size();
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto opr = get_proxy_opr(opdef, inputs);
    CUR_OPR_GUARD(opr);
    SmallVector<LogicalTensorDesc> outputs;
    bool validated = do_shape_infer(false);
    for (auto&& i : opr->usable_output()) {
        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    bool need_check = opr->same_type<opr::Reshape>();
    return {outputs, validated && !need_check};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<Tensor*>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto opr = get_proxy_opr(def, inputs_tensors);
    CUR_OPR_GUARD(opr);
    do_shape_infer(true);
    SmallVector<MemoryDesc> outputs;
    SmallVector<MemoryDesc> workspaces;
    size_t cur_id = 0;
    for (auto&& i : opr->output()) {
        if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            workspaces.push_back({{i->shape(), i->dtype(), i->format()}, 0, i->comp_node(),
                                  StorageIdentifier::make(++cur_id)});
        } else {
            outputs.push_back({{i->shape(), i->dtype()}, 0, i->comp_node(),
                               StorageIdentifier::make(++cur_id)});
        }
    }
    return {outputs, workspaces};
}

struct ProxyGraph::GradGraph {
    cg::VarNodeArray inputs;
    cg::VarNodeArray outputs;
    cg::VarNodeArray output_grads;
    cg::VarNode* grad;
};
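
// Builds the backward graph as an EncodedSubgraph: its inputs are drawn from
// [forward inputs, forward outputs, output grads] (input_mask records which of
// them are actually used), and its outputs are the grads of the forward inputs
// that require grad (recorded in output_mask).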
EncodedSubgraph ProxyGraph::make_backward_graph(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& input_descs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    ThinHashMap<VarNode*, size_t> var2idx;
    auto push = [&var2idx, cnt = 1](VarNode* var) mutable {  // cnt is always greater than zero
        auto&& ret = var2idx.emplace(var, cnt++);
        mgb_assert(ret.second, "var %s has already been inserted", var->cname());
        return ret.first->second;
    };
    auto inputs = make_input_place_holders(input_descs);
    auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
    auto&& outputs = fwd->usable_output();
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& i : outputs) {
        output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
    }
    auto output_grads = make_input_place_holders(output_descs);
    mgb_assert(output_grads.size() == output_has_grad.size(), "%zu vs %zu",
               output_grads.size(), output_has_grad.size());
    bool any_input_has_grad = false;
    for (size_t i = 0; i < output_grads.size(); ++i) {
        if (!output_has_grad[i]) {
            output_grads[i] = nullptr;
        } else {
            any_input_has_grad = true;
        }
    }
    if (!any_input_has_grad) {
        return {};
    }
    auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

    EncodedSubgraph result;
    auto&& igraph = result.graph;

    size_t nr_backward_graph_inputs = 0;
    auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
                     &nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
        if (auto t = as_tensor(op)) {
            mgb_assert(op->output().size() == 1);
            igraph.constants.emplace_back(push(op->output(0)), std::move(t));
        } else if (op->same_type<InputPlaceholder>()) {
            ++nr_backward_graph_inputs;
            push(op->output(0));
        } else {
            SmallVector<size_t> inputs, outputs;
            for (auto&& i : op->input()) {
                if (i->owner_opr() == fwd) {
                    if (var2idx.find(i) == var2idx.end()) {
                        ++nr_backward_graph_inputs;
                        push(i);
                    }
                }
                inputs.push_back(var2idx.at(i));
            }
            for (auto&& i : op->usable_output()) {
                outputs.push_back(push(i));
            }
            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
        }
    };

    // set backward graph outputs
    cg::DepOprIter iter{gen_expr};
    iter.set_visited(fwd);
    result.output_mask.resize(inputs.size());

    VarNodeArray output_grads_with_unused_var;
    {
        auto iter = output_grads.begin();
        for (auto&& i : fwd->output()) {
            if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                // a var node with VOLATILE_CONTENT (e.g. a workspace
                // or an empty var) is not considered a normal
                // output, so its grad is always NULL
                output_grads_with_unused_var.push_back(nullptr);
            } else {
                output_grads_with_unused_var.push_back(*iter);
                ++iter;
            }
        }
        mgb_assert(iter == output_grads.end());
    }

    Maybe<VarNodeArray> grad_results;
    for (size_t i = 0; i < inputs.size(); ++i) {
        VarNode* grad;
        if (grad_results.valid()) {
            grad = grad_results.val()[i];
        } else {
            mgb_assert(gfunc, "could not find grad function");
            auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
            if (res.from_single()) {
                grad = res.single();
            } else {
                grad_results.emplace(res.all(fwd));
                grad = grad_results.val()[i];
            }
        }
        if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()
                && input_requires_grad[i]) {
            mgb_assert(!grad->owner_opr()->same_type<opr::InvalidGrad>(),
                       "gradient of operator %s w.r.t. input #%zu is "
                       "either not well defined or not implemented",
                       fwd->dyn_typeinfo()->name, i);
            iter.add(grad);
            igraph.outputs.push_back(var2idx.at(grad));
            result.output_mask[i] = true;
        } else {
            result.output_mask[i] = false;
        }
    }
    if (igraph.outputs.empty()) {
        return {};
    }

    // set backward graph inputs
    igraph.inputs.reserve(nr_backward_graph_inputs);
    result.input_mask.reserve(nr_backward_graph_inputs);
    auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
        for (auto&& i : vars) {
            auto&& iter = var2idx.find(i);
            if (iter != var2idx.end()) {
                igraph.inputs.push_back(iter->second);
                result.input_mask.push_back(true);
            } else {
                result.input_mask.push_back(false);
            }
        }
    };
    write_inputs(inputs);
    write_inputs(outputs);
    write_inputs(output_grads);
    mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
    return result;
}

cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef,
                                                const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(!m_cur_opr);
    auto vinputs = make_input_place_holders(inputs);
    return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
}

VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, inputs[i]).node();
    }
    return vinputs;
}

/*********************** Common Impl ***********************/

bool ProxyGraph::do_shape_infer(bool sync_value) {
    m_static_infer_manager->update();
    bool validated = true;
    for (auto* var : m_cur_opr->output()) {
        if (sync_value) {
            var->shape(m_static_infer_manager->infer_shape(var));
        } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
            var->shape(*shape);
        } else {
            validated = false;
        }
    }
    return validated;
}

TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
    // TODO: maybe some tensors should copy the value from the origin opr rather
    // than share the RawStorage
    mgb_assert(share, "can't share memory with opr %s", opr->cname());
    if (opr->same_type<opr::ImmutableTensor>()) {
        auto&& dv = opr->cast_final_safe<opr::ImmutableTensor>().value();
        HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
        const DeviceTensorND* cpu_value;
        // get host value
        if (opr->owner_graph() == m_graph.get()) {
            CUR_OPR_GUARD(opr);
            m_static_infer_manager->update();
            cpu_value = m_static_infer_manager->infer_value_fallible(opr->output(0));
        } else {
            cpu_value = opr->owner_graph()->static_infer_manager().infer_value_fallible(opr->output(0));
        }
        mgb_assert(cpu_value);
        mgb_assert(cpu_value->comp_node() == CompNode::default_cpu());
        // default_cpu is synchronous with respect to the caller
        hv.proxy_to_default_cpu().copy_from_fixlayout(*cpu_value);
        return Tensor::make(dv, hv);
    } else if (opr->same_type<opr::SharedDeviceTensor>()) {
        return Tensor::make(
                opr->cast_final_safe<opr::SharedDeviceTensor>().get_dev_tensor());
    } else {
        return {};
    }
}

thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine actually has a GPU and that its driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit MegStudio.
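
For example, a quick way to confirm from Python that the GPU path is usable is a check along the lines of the sketch below. This is a minimal, illustrative snippet, assuming a recent MegEngine release where is_cuda_available and set_default_device are exposed at the package top level:

    import megengine as mge

    # Pick a device: fall back to CPU when no usable GPU (or driver) is present.
    if mge.is_cuda_available():
        mge.set_default_device("gpu0")   # first GPU
    else:
        mge.set_default_device("cpu0")   # CPU-only machine or missing driver

    x = mge.tensor([1.0, 2.0, 3.0])
    print(x.device, x * 2)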