
utility.cpp 28 kB

/**
 * \file src/opr/impl/utility.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/utils/debug.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/comp_node_env.h"

#include <thread>

using namespace mgb;
using namespace opr;

#if !MGB_BUILD_SLIM_SERVING

namespace {
OperatorNodeConfig setup_config_cn(const OperatorNodeConfig& config_,
                                   const CompNode& cn) {
    auto prev_cn = config_.get_single_comp_node();
    mgb_assert(!prev_cn.valid() || cn == prev_cn);
    auto config = config_;
    config.comp_node(cn);
    return config;
}
}  // namespace

/* ===================== Sleep ===================== */
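//! Sleep: forwards its input to its output after blocking for the given
//! number of seconds, on the device queue (via a megdnn Sleep opr), the host
//! thread, or both, depending on Type.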
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Sleep);

void Sleep::scn_do_execute() {
#if MGB_HAVE_THREAD
    auto in = input(0), out = output(0);
    if (m_type.device) {
        if (!m_opr || m_opr.comp_node() != comp_node()) {
            m_opr = intl::create_megdnn_opr<megdnn::Sleep>(comp_node());
        }
        m_opr->param().time = m_seconds;
        m_opr->exec();
    }
    if (m_type.host) {
        std::this_thread::sleep_for(std::chrono::microseconds(
                static_cast<uint64_t>(m_seconds * 1e6)));
    }
    out->dev_tensor().copy_from_fixlayout(in->dev_tensor());
#else
    mgb_throw(MegBrainError, "sleep is unavailable when threading is disabled");
#endif
}
void Sleep::record_execute_deps(ExecDependencyArray& deps) {
    if (m_opr) {
        mixin::MegDNNOprHolder::record_megdnn_opr(std::move(m_opr), deps);
    }
}

void Sleep::sleep(const CompNode& node, double seconds) {
    node.activate();
    auto opr = intl::get_megdnn_handle(node)->create_operator<megdnn::Sleep>();
    opr->param().time = seconds;
    opr->exec();
}

Sleep::Sleep(VarNode* node, double seconds, Type type,
             const OperatorNodeConfig& config)
        : Super(node->owner_graph(), config, "sleep", {node}),
          m_seconds{seconds}, m_type{type} {
    mgb_assert(seconds > 0);
    add_input({node});
    add_output(None);
    add_equivalence_component<PODHash<double>>(&m_seconds);
    add_equivalence_component<PODHash<Type>>(&m_type);
}

SymbolVar Sleep::make(SymbolVar node, double seconds, Type type,
                      const OperatorNodeConfig& config) {
    mgb_assert(seconds >= 0);
    if (!seconds)
        return node;
    return node.insert_single_output_opr<Sleep>(node.node(),
                                                seconds, type, config);
}

MGB_IMPL_OPR_GRAD(Sleep) {
    return out_grad.at(0);
}

/* ===================== Timestamp ===================== */
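//! Timestamp: records a device event when executed; after the computing
//! sequence finishes, the elapsed time from the first Timestamp event on the
//! same comp node is written to dest[dest_off] (a float32 host tensor).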
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Timestamp);

class Timestamp::GraphStorage final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    //! whether oprs and event info should be cleared upon next register call
    bool m_should_clear = false;
    SyncEventConnecter::ReceiverHandler m_recv_handler_wait,
            m_recv_handler_compile;
    std::vector<Timestamp*> m_oprs;
    CompNode::UnorderedMap<CompNode::Event*> m_first_event;

public:
    GraphStorage(ComputingGraph* cg) {
        auto on_compile = [this](const cg::event::CompSeqOrderDetermined&) {
            m_should_clear = true;
        };
        auto on_wait = [this](const cg::event::CompSeqExecFinished& event) {
            for (auto i : m_oprs) {
                i->update();
            }
            mgb_assert(event.device_actually_finished,
                       "Timestamp in subgraph is not supported");
        };
        m_recv_handler_compile =
                cg->event()
                        .register_receiver<cg::event::CompSeqOrderDetermined>(
                                on_compile);
        m_recv_handler_wait =
                cg->event().register_receiver<cg::event::CompSeqExecFinished>(
                        on_wait);
    }

    //! return the first event on this comp seq
    CompNode::Event* register_opr(Timestamp* opr) {
        if (m_should_clear) {
            m_oprs.clear();
            m_first_event.clear();
            // reset the flag so later registrations in the same compilation
            // are not wiped out again
            m_should_clear = false;
        }
        m_oprs.push_back(opr);
        auto ins = m_first_event.insert({opr->comp_node(), opr->m_event.get()});
        return ins.first->second;
    }
};
MGB_TYPEINFO_OBJ_IMPL(Timestamp::GraphStorage);
void Timestamp::add_input_layout_constraint() {
    if (!m_event) {
        m_event = comp_node().create_event(CompNode::Event::Flags::NEED_TIMER);
    }
    auto make = [this]() {
        return std::make_shared<GraphStorage>(owner_graph());
    };
    auto storage =
            owner_graph()
                    ->options()
                    .user_data.get_user_data_or_create<GraphStorage>(make);
    m_first_event = storage->register_opr(this);
    Super::add_input_layout_constraint();
}

void Timestamp::scn_do_execute_finish(const DeviceTensorND&) {
    m_event->record();
}

void Timestamp::on_output_comp_node_stream_changed() {
    m_event.reset();
    Super::on_output_comp_node_stream_changed();
}

void Timestamp::update() {
    mgb_assert(m_dest_off < m_dest->shape(0));
    m_dest->ptr<float>()[m_dest_off] =
            m_first_event->elapsed_time_until(*m_event);
}

Timestamp::Timestamp(VarNode* node, std::shared_ptr<HostTensorND> dest,
                     size_t dest_off, const OperatorNodeConfig& config)
        : Super(node->owner_graph(), config, "timestamp", {node}),
          m_dest{std::move(dest)},
          m_dest_off{dest_off} {
    mgb_assert(m_dest, "empty dest tensor");
    mgb_assert(m_dest->dtype() == dtype::Float32{} &&
                       m_dest->shape().ndim == 1 &&
                       dest_off < m_dest->shape()[0] &&
                       m_dest->layout().stride[0] == 1,
               "dest tensor must be 1-dimensional float32; got %s (%s)",
               m_dest->layout().to_string().c_str(), m_dest->dtype().name());
    add_input({node});
    add_output(None);
    add_equivalence_component<ScalarHash<void*>>(m_dest.get());
    add_equivalence_component<ScalarHash<size_t>>(m_dest_off);
}

SymbolVar Timestamp::make(SymbolVar node, std::shared_ptr<HostTensorND> dest,
                          size_t dest_off, const OperatorNodeConfig& config) {
    return node.insert_single_output_opr<Timestamp>(
            node.node(), std::move(dest), dest_off, config);
}

/* ========================== VirtualDep ============================ */
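//! VirtualDep: forwards the device value of its first input while adding
//! device-computation-order dependencies on the remaining inputs, which may
//! live on other comp nodes.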
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep);

VirtualDep::VirtualDep(const VarNodeArray& inputs,
                       const OperatorNodeConfig& config)
        : Super(inputs[0]->owner_graph(),
                setup_config_cn(config, inputs[0]->comp_node()), "virtual_dep",
                inputs) {
    for (auto inp : inputs) {
        add_input({inp});
    }
    mgb_assert(inputs[0]->dtype().valid());
    add_output(None)->dtype(inputs[0]->dtype());
}

cg::OperatorNodeBase::NodeProp* VirtualDep::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    if (input().size() > 1) {
        SmallVector<NodeProp::DepType> dep_types{NodeProp::DepType::DEV_VALUE};
        for (size_t i = 1; i < input().size(); ++i) {
            dep_types.push_back(NodeProp::DepType::DEV_COMP_ORDER);
        }
        prop->reset_dep_type(input(), dep_types);
    }
    prop->add_flag(
            cg::OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return prop;
}

SymbolVar VirtualDep::make(const SymbolVarArray& inputs,
                           const OperatorNodeConfig& config) {
    mgb_assert(!inputs.empty());
    auto nodes = to_var_node_array(inputs);
    return inputs[0].insert_single_output_opr<VirtualDep>(nodes, config);
}

MGB_IMPL_OPR_GRAD(VirtualDep) {
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return nullptr;
}

#endif  // MGB_BUILD_SLIM_SERVING

/* ===================== MarkDynamicVar ===================== */
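//! MarkDynamicVar: copies its input into a dynamically allocated output, so
//! the var is excluded from static memory planning; empty values are allowed.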
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkDynamicVar);

void MarkDynamicVar::scn_do_execute() {
    auto i = input(0), o = output(0);
    o->shape_alloc(i->shape());
    o->dev_tensor().copy_from_fixlayout(i->dev_tensor());
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkDynamicVar) {
    return MarkDynamicVar::make(out_grad.at(0)).node();
}
#endif

MarkDynamicVar::MarkDynamicVar(VarNode* node, const OperatorNodeConfig& config)
        : Super{node->owner_graph(), config, "mark_dyn", {node}} {
    add_input({node});
    add_output(None)
            ->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}

SymbolVar MarkDynamicVar::make(
        SymbolVar node, const OperatorNodeConfig& config) {
    return node.insert_single_output_opr<MarkDynamicVar>(node.node(), config);
}

MarkDynamicVar::NodeProp* MarkDynamicVar::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0),
                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}

/* ===================== CallbackInjector ===================== */
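//! CallbackInjector: forwards its first input to its output, invoking a
//! user-supplied callback with the device tensors of all inputs each time the
//! operator finishes execution (and optionally during static value inference).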
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CallbackInjector);

CallbackInjector::CallbackInjector(
        VarNode* inp, const Param& param, const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "callback", {inp}},
          m_param{param} {
    add_input({inp});
    add_output(None);
    if (m_param.ignore_side_effect) {
        set_ignore_side_effect();
    }
    // so this opr would not get deduped
    add_equivalence_component<ScalarHash<void*>>(this);
}

CallbackInjector::CallbackInjector(
        VarNodeArray& inps, const Param& param,
        const OperatorNodeConfig& config)
        : Super{inps[0]->owner_graph(), config, "callback", inps},
          m_param{param} {
    for (auto inp : inps) {
        add_input({inp});
    }
    add_output(None);
    if (m_param.ignore_side_effect) {
        set_ignore_side_effect();
    }
    // so this opr would not get deduped
    add_equivalence_component<ScalarHash<void*>>(this);
}

SymbolVar CallbackInjector::make(mgb::cg::SymbolVarArray inp, const Param& param,
                                 const OperatorNodeConfig& config) {
    auto nodes = to_var_node_array(inp);
    return inp[0].insert_single_output_opr<CallbackInjector>(nodes, param,
                                                             config);
}

void CallbackInjector::scn_do_execute_finish(const DeviceTensorND&) {
    SmallVector<DeviceTensorND> input_list;
    for (size_t i = 0; i < input().size(); ++i) {
        input_list.push_back(input(i)->dev_tensor());
    }
    m_param.callback(input_list);
}

cg::OperatorNodeBase::NodeProp* CallbackInjector::do_make_node_prop() const {
    auto prop = ForwardInputToOutput::do_make_node_prop();
    if (!m_param.allow_auto_dup) {
        prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
    }
    return prop;
}

cg::static_infer::ValueInferDesc
CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase& opr) {
    using namespace cg::static_infer;
    auto infer_val = [this](DeviceTensorND& dst, const InpVal& iv) -> bool {
        dst = iv.val[0].value();
        if (!m_param.invoke_for_static_infer) {
            return true;
        }
        if (m_warn_printed < 10) {
            mgb_log_warn(
                    "[warn %d/10] CallbackInjector %s is called during static "
                    "value inference. The warning can be safely ignored if "
                    "CallbackInjector does nothing other than inspecting the "
                    "tensor value; otherwise it may introduce unexpected "
                    "behavior.",
                    ++m_warn_printed, cname());
        }
        SmallVector<DeviceTensorND> callback_list;
        for (size_t i = 0; i < iv.val.size(); ++i) {
            if (m_append_one_more_shape && i + 1 == iv.val.size()) {
                continue;
            }
            callback_list.push_back(iv.val[i].value());
        }
        m_param.callback(callback_list);
        return true;
    };
    DepVal dep_val_list;
    for (size_t i = 0; i < input().size(); ++i) {
        dep_val_list.push_back({opr.input(i), DepType::VALUE});
    }
    if (m_param.invoke_for_static_infer) {
        // the callback reads all input values, so they must all be deps
        return {SourceType::DEP, dep_val_list, infer_val};
    } else {
        // only the forwarded value of input(0) is needed
        return {SourceType::DEP, {{opr.input(0), DepType::VALUE}}, infer_val};
    }
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CallbackInjector) {
    MGB_MARK_USED_VAR(wrt_idx);
    return out_grad.at(0);
}
#endif

/* ===================== MarkNoBroadcastElemwise ===================== */
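//! MarkNoBroadcastElemwise: identity-like operator marking that its input is
//! used in elementwise computation without broadcasting; the mark itself is
//! presumably consumed by graph transformations outside this file.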
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise);

MarkNoBroadcastElemwise::MarkNoBroadcastElemwise(
        VarNode* input, const OperatorNodeConfig& config)
        : Super(input->owner_graph(), config, "no_brdcst", {input}) {
    add_input({input});
    add_output(None);
    set_ignore_side_effect();
}

SymbolVar MarkNoBroadcastElemwise::make(
        SymbolVar input, const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<MarkNoBroadcastElemwise>(
            input.node(), config);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) {
    return out_grad.at(0);
}
#endif

/* ===================== Identity ===================== */
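//! Identity: forwards its input unchanged; make() collapses consecutive
//! Identity oprs.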
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity);

Identity::Identity(VarNode* input, const OperatorNodeConfig& config)
        : Super(input->owner_graph(), config, "identity", {input}) {
    add_input({input});
    add_output(None);
    set_ignore_side_effect();
}

SymbolVar Identity::make(
        SymbolVar input, const OperatorNodeConfig& config) {
    if (input.node()->owner_opr()->same_type<Identity>()) {
        // collapse consecutive Identity oprs
        // this is also necessary for megskull GradWrt in loop to work
        return input;
    }
    return input.insert_single_output_opr<Identity>(input.node(), config);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Identity) {
    return out_grad.at(0);
}
#endif

/* ===================== AssertEqual ===================== */
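//! AssertEqual: asserts that two vars are equal within Param::maxerr. The
//! two-input make() builds the scalar error var as
//!     err = max_i(|expect_i - get_i| / max(min(|expect_i|, |get_i|), 1)),
//! i.e. an elementwise relative error whose denominator is clamped to at
//! least 1 (so tiny values degrade to absolute error), reduced by max.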
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AssertEqual);

AssertEqual::AssertEqual(
        VarNode* expect, VarNode* get, VarNode* err,
        const Param& param, const OperatorNodeConfig& config)
        : Super(err->owner_graph(), config, "assert_eq", {expect, get}),
          m_param{param} {
    add_input({expect, get, err});
    add_output(None);
    add_equivalence_component<PODHash<Param>>(&m_param);
}

SymbolVar AssertEqual::make(SymbolVar expect, SymbolVar get,
                            const Param& param,
                            const OperatorNodeConfig& config) {
    auto err = opr::reduce_max(
            opr::abs(expect - get) /
                    opr::max(
                            opr::min(opr::abs(expect), opr::abs(get)),
                            expect.make_scalar_dt(1)),
            expect.make_scalar(1));
    return make(expect, get, err, param, config);
}

SymbolVar AssertEqual::make(
        SymbolVar expect, SymbolVar get, SymbolVar err,
        const Param& param, const OperatorNodeConfig& config) {
    return expect.insert_single_output_opr<AssertEqual>(
            expect.node(), get.node(), err.node(), param, config);
}

void AssertEqual::scn_do_execute_finish(const DeviceTensorND&) {
    if (owner_graph()->options().comp_node_seq_record_level >= 2) {
        mgb_log_error("AssertEqual %s disabled due to seq rec", cname());
        return;
    }
    m_hv.copy_from(input(2)->dev_tensor()).sync();
    mgb_assert(m_hv.shape().is_scalar());
    auto err = DTypeScalar::make_from_raw(
            m_hv.dtype(), m_hv.raw_ptr()).get_cast<float>();
    if (m_param.verbose) {
        //! FIXME: stderr is slow when building on Windows with VS clang-cl
        //! (tested in a VM), but the root cause has not been found; fix this
        //! when it is figured out
        fprintf(stdout,
                "AssertEqual: err=%g (name=%s id=%zu)\n", err, cname(), id());
    }
    if (!(err >= 0 && err <= m_param.maxerr)) {
        HostTensorND expect, get;
        expect.copy_from(input(0)->dev_tensor());
        get.copy_from(input(1)->dev_tensor()).sync();
        auto msg = debug::compare_tensor_value(
                expect, cg::dump_var_info({input(0)}).c_str(),
                get, cg::dump_var_info({input(1)}).c_str(),
                m_param.maxerr);
        mgb_assert(msg.valid());
        if (m_throw_on_error) {
            owner_graph()->record_async_error(
                    cg::OperatorNodeExcExtraInfo::ExcMaker{
                            input(1)->owner_opr()}.make_unique<UnequalError>(
                            msg.val()));
        } else {
            mgb_log_error("%s", msg->c_str());
        }
    }
}

#if MGB_ENABLE_GRAD

/* ===================== SetGrad ===================== */
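//! SetGrad: identity in the forward pass; during grad computation, the
//! user-supplied GradGetter produces the gradient (defaulting to zero_grad
//! when no callback is given).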
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SetGrad);

SetGrad::SetGrad(
        VarNode* input, const GradGetter& grad_getter,
        const OperatorNodeConfig& config)
        : Super(input->owner_graph(), config, "set_grad", {input}),
          m_grad_getter{grad_getter} {
    add_input({input});
    add_output(None);
    set_ignore_side_effect();
    if (grad_getter) {
        // dedup not allowed
        add_equivalence_component<ScalarHash<void*>>(this);
    } else {
        // force to be zero_grad if no callback, and we can safely enable dedup
        m_grad_getter = zero_grad;
    }
}

SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter,
                        const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<SetGrad>(
            input.node(), grad_getter, config);
}

MGB_IMPL_OPR_GRAD(SetGrad) {
    MGB_MARK_USED_VAR(wrt_idx);
    MGB_MARK_USED_VAR(out_grad);
    auto grad = opr.grad_getter()(opr);
    mgb_assert(!grad.node() || grad.node()->owner_graph() == opr.owner_graph(),
               "var returned by grad_getter belongs to a different comp graph");
    return grad.node();
}

/* ===================== InvalidGrad ===================== */
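//! InvalidGrad: placeholder var created when a gradient is requested w.r.t.
//! an input that has no grad; it throws GraphError as soon as the graph is
//! compiled with this var in use.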
MGB_DYN_TYPE_OBJ_FINAL_IMPL(InvalidGrad);

void InvalidGrad::scn_do_execute() {
    mgb_assert(0);
}

InvalidGrad::InvalidGrad(VarNode* vinp, const OperatorNodeBase* grad_opr,
                         size_t inp_idx)
        : Super{vinp->owner_graph(), {}, "invalid_grad", {vinp}},
          m_grad_opr(grad_opr),
          m_inp_idx(inp_idx) {
    add_input({vinp});
    add_output(None);
}

void InvalidGrad::add_input_layout_constraint() {
    MGB_MARK_USED_VAR(m_grad_opr);
    mgb_throw(GraphError,
              "invalid grad: can not take grad with respect to the %zu'th "
              "input var of operator {id:%zu, name:%s, type:%s}; "
              "(w.r.t. var: %s)",
              m_inp_idx, m_grad_opr->id(), m_grad_opr->cname(),
              m_grad_opr->dyn_typeinfo()->name,
              cg::dump_var_info(input()).c_str());
}

VarNode* InvalidGrad::make(const OperatorNodeBase& grad_opr, size_t inp_idx) {
    return SymbolVar(grad_opr.input(inp_idx))
            .insert_single_output_opr<InvalidGrad>(grad_opr.input(inp_idx),
                                                   &grad_opr, inp_idx)
            .node();
}

/* ===================== VirtualGrad ===================== */
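//! VirtualGrad: symbolic placeholder for the gradient of target w.r.t. wrt;
//! it must be expanded by gopt::ExpandVirtualGradPass and throws if executed.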
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualGrad);

VirtualGrad::VirtualGrad(VarNode* target, VarNode* wrt,
                         const OperatorNodeConfig& config)
        : Super(target->owner_graph(), config, "grad", {target, wrt}) {
    add_input({target, wrt});
    add_output(None)->dtype(wrt->dtype());
}

SymbolVar VirtualGrad::make(SymbolVar target, SymbolVar wrt,
                            Param, const OperatorNodeConfig& config) {
    return target.insert_single_output_opr<VirtualGrad>(
            target.node(), wrt.node(), config);
}

void VirtualGrad::do_execute(ExecEnv&) {
    mgb_throw(MegBrainError, "VirtualGrad opr must be removed by "
                             "gopt::ExpandVirtualGradPass");
}

void VirtualGrad::init_output_comp_node() {
    output(0)->comp_node(input(1)->comp_node());
}

void VirtualGrad::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto ovar = output(0), ivar = input(1);
    mgr.register_shape_infer(ovar, ShapeInferDesc::make_identity(ivar));
}

void VirtualGrad::on_output_comp_node_stream_changed() {}

VirtualGrad::NodeProp* VirtualGrad::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return ret;
}

/* ===================== VirtualLoss ===================== */
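//! VirtualLoss: bundles (y, y_grad) pairs into a single scalar var so that
//! gradients flowing from multiple outputs can be taken through one virtual
//! loss; taking grad w.r.t. y_i simply yields the paired y_grad_i.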
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualLoss);

VirtualLoss::VirtualLoss(const VarNodeArray& inputs,
                         const OperatorNodeConfig& config)
        : Super(inputs.at(0)->owner_graph(), config, "internal_grad",
                {inputs.at(0)}) {
    mgb_assert(inputs.size() % 2 == 0);
    for (size_t i = 0, it = inputs.size() / 2; i < it; ++i) {
        auto yi = inputs[i], gradi = inputs[i + it];
        mgb_assert(yi && gradi);
        auto&& shp0 = yi->shape();
        auto&& shp1 = gradi->shape();
        mgb_assert((!shp0.ndim && !shp1.ndim) || shp0.eq_shape(shp1),
                   "grad shape mismatch: %s vs %s", shp0.to_string().c_str(),
                   shp1.to_string().c_str());
        mgb_assert(yi->comp_node() == gradi->comp_node());
        add_input({yi});
    }
    for (size_t i = inputs.size() / 2; i < inputs.size(); ++i) {
        add_input({inputs[i]});
    }
    add_output(None)->dtype(dtype::Float32{});
}

SymbolVar VirtualLoss::make(const SymbolVarArray& ys,
                            const SymbolVarArray& y_grads, Param,
                            const OperatorNodeConfig& config) {
    mgb_assert(ys.size() == y_grads.size() && !ys.empty());
    VarNodeArray inputs = to_var_node_array(ys);
    // sort for better dedup
    auto cmp = [](VarNode* a, VarNode* b) { return a->id() < b->id(); };
    std::sort(inputs.begin(), inputs.end(), cmp);
    ThinHashMap<VarNode*, VarNode*> var2grad;
    for (size_t i = 0; i < inputs.size(); ++i) {
        var2grad[ys[i].node()] = y_grads[i].node();
    }
    inputs.resize(inputs.size() * 2);
    for (size_t i = 0, it = inputs.size() / 2; i < it; ++i) {
        inputs[i + it] = var2grad.at(inputs[i]);
    }
    return ys[0].insert_single_output_opr<VirtualLoss>(inputs, config);
}

void VirtualLoss::do_execute(ExecEnv&) {
    mgb_throw_if(
#if MGB_BUILD_SLIM_SERVING
            true,
#else
            !owner_graph()->options().eager_evaluation,
#endif
            MegBrainError, "InternalGradLoss should never be executed");
}

void VirtualLoss::init_output_comp_node() {
    output(0)->comp_node(input(0)->comp_node());
}

void VirtualLoss::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_const({1}));
}

void VirtualLoss::on_output_comp_node_stream_changed() {}

VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return ret;
}

MGB_IMPL_OPR_GRAD(VirtualLoss) {
    mgb_assert(out_grad.size() == 1);
    auto mid = opr.input().size() / 2;
    if (wrt_idx < mid) {
        return opr.input(wrt_idx + mid);
    }
    return nullptr;
}

#else

VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) {
    mgb_throw(MegBrainError, "grad disabled at compile time");
}

#endif  // MGB_ENABLE_GRAD

/* ================== PersistentOutputStorage =================== */
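//! PersistentOutputStorage: gives its output a storage buffer that persists
//! across graph recompilations; outputs with the same non-negative
//! Param::share_key on the same comp node share one buffer, while
//! share_key == -1 keeps a buffer private to this operator.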
class PersistentOutputStorage::StorageHolder final
        : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    using Key = std::pair<CompNode, int>;
    struct KeyHash {
        size_t operator()(const Key& key) const {
            return hash_pair_combine(HashTrait<CompNode>::eval(key.first),
                                     key.second);
        }
    };

    std::mutex m_mtx;
    std::unordered_map<Key, DeviceTensorStorage, KeyHash> m_storage;

public:
    void set_tensor(DeviceTensorND& dst, int key, CompNode comp_node,
                    const TensorLayout& layout) {
        MGB_LOCK_GUARD(m_mtx);
        DeviceTensorStorage* storage;
        Maybe<DeviceTensorStorage> local_storage;
        if (key == -1) {
            storage = &local_storage.emplace(dst.storage());
        } else {
            storage = &m_storage[{comp_node, key}];
        }
        if (!storage->comp_node_valid()) {
            storage->comp_node(comp_node);
        }
        auto s = layout.span().dist_byte();
        if (s > storage->size()) {
            if (storage->size()) {
                // exponential growth if size gets increased
                s = s * 3 / 2;
            }
            storage->ensure_size(s);
        }
        dst.reset(*storage, layout);
    }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PersistentOutputStorage);
MGB_TYPEINFO_OBJ_IMPL(PersistentOutputStorage::StorageHolder);

class PersistentOutputStorage::DevValueExecDep final : public ExecDependency {
    DeviceTensorStorage m_val;

public:
    explicit DevValueExecDep(DeviceTensorStorage val) : m_val{std::move(val)} {}
};

PersistentOutputStorage::PersistentOutputStorage(
        VarNode* inp, const Param& param, const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "persist", {}}, m_param{param} {
    add_input({inp});
    add_output(None)
            ->add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
}

SymbolVar PersistentOutputStorage::make(SymbolVar inp, const Param& param,
                                        const OperatorNodeConfig& config) {
    return inp.insert_single_output_opr<PersistentOutputStorage>(inp.node(),
                                                                 param, config);
}

void PersistentOutputStorage::record_execute_deps(ExecDependencyArray& deps) {
    mgb_assert(!m_dev_tensor.empty());
    deps.emplace_back(
            std::make_unique<DevValueExecDep>(m_dev_tensor.storage()));
}

void PersistentOutputStorage::scn_do_execute() {
    auto &&od = output(0)->dev_tensor(), &&id = input(0)->dev_tensor();
    mgb_assert(od.raw_ptr() == m_dev_tensor.raw_ptr());
    od.copy_from_fixlayout(id);
}

void PersistentOutputStorage::init_output_mem_plan(bool dynamic) {
    mgb_throw_if(
            dynamic, GraphError,
            "PersistentOutputStorage can not be used in dynamic storage case");
    auto cn = comp_node();
    auto ovar = output(0);
    mgb_assert(cg::is_static_var_storage(ovar));
    // note that this method is called after static shape infer, so it is safe
    // to access var shapes here
    auto&& shape = ovar->shape();
    if (!m_dev_tensor.shape().eq_shape(shape) ||
        m_dev_tensor.comp_node() != cn) {
        TensorLayout layout{shape, ovar->dtype(), ovar->format()};
        auto holder =
                owner_graph()
                        ->options()
                        .user_data.get_user_data_or_create<StorageHolder>();
        holder->set_tensor(m_dev_tensor, m_param.share_key, cn, layout);
    }
    ovar->init_mem_plan(&m_dev_tensor);
}

/* ================ RequireInputDynamicStorage ================== */
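//! RequireInputDynamicStorage: forwards its input while forcing the input var
//! to use dynamic storage via VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC.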
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RequireInputDynamicStorage);

RequireInputDynamicStorage::RequireInputDynamicStorage(
        VarNode* input, const OperatorNodeConfig& config)
        : Super{input->owner_graph(),
                config,
                "require_input_dynamic_storage",
                {input}} {
    input->add_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC);
    add_input({input});
    add_output(None);
}

SymbolVar RequireInputDynamicStorage::make(const SymbolVar input,
                                           const OperatorNodeConfig& config) {
    return input.insert_single_output_opr<RequireInputDynamicStorage>(
            input.node(), config);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit MegStudio.