
trace.cpp

/**
 * \file imperative/src/impl/transformations/trace.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/imperative/transformations/trace.h"

#include <chrono>
#include <exception>

#include "megbrain/gopt/inference.h"
#include "megbrain/graph/helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/serializer.h"

#include "../event_pool.h"
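// Assertion used while replaying a trace: on failure, wrap the message in a
// TraceError, record it through set_exception() so pending result boxes
// observe the failure, then rethrow it to the caller.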
#define trace_assert(_cond, _msg...)                                        \
    do {                                                                    \
        if (mgb_unlikely(!(_cond))) {                                       \
            auto exc = std::make_exception_ptr(TraceError(ssprintf(_msg))); \
            set_exception(exc);                                             \
            std::rethrow_exception(exc);                                    \
        }                                                                   \
    } while (0)
namespace mgb {
namespace imperative {
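// Serialize the recorded trace into a ComputingGraph: each traced input
// becomes a Host2DeviceCopy node, each captured constant becomes an
// ImmutableTensor node, and the recorded op sequence is replayed on top of
// them. Returns the VarNodes corresponding to the requested outputs.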
VarNodeArray TraceResult::dump(
        ComputingGraph& graph,
        std::vector<std::tuple<size_t, std::string, TensorShape>> inputs,
        std::vector<std::pair<size_t, std::string>> outputs, bool prefer_input_names) {
    // var -> VarNode
    std::vector<VarNode*> nodes(vars.size(), nullptr);
    // make h2d node for each input
    for (auto&& [input, name, shape] : inputs) {
        auto& var = vars[input];
        auto& node = nodes[input];
        // TODO: cambricon CompNode
        auto host = std::make_shared<HostTensorND>(
                CompNode::load("xpux"), shape, var.dtype);
        OperatorNodeConfig config;
        // if prefer_input_names, prefer names from dump args,
        // else prefer names recorded during the trace procedure
        if (prefer_input_names && !name.empty()) {
            config.name(name);
        } else if (!var.name.empty()) {
            config.name(var.name);
        } else if (!name.empty()) {
            config.name(name);
        }
        node = opr::Host2DeviceCopy::make(graph, host, {}, config).node();
    }
    // make const node for each constant
    for (size_t i = 0; i < vars.size(); ++i) {
        auto& var = vars[i];
        auto& node = nodes[i];
        if (!node) {
            if (var.kind != VarKind::Internal) {
                if (!var.bound_data) {
                    continue;
                }
                if (!var.name.empty()) {
                    node = opr::ImmutableTensor::make(
                                   graph, var.bound_data.numpy()->as_nd(), {var.name})
                                   .node();
                } else {
                    node = opr::ImmutableTensor::make(
                                   graph, var.bound_data.numpy()->as_nd())
                                   .node();
                }
            }
        }
    }
    std::unordered_map<std::string, std::vector<cg::OperatorNodeBase*>> name2ops;
    // iterate over opr_seq
    for (auto&& item : seq) {
        auto&& [op, inputs, outputs] = item;
        VarNodeArray input_nodes;
        for (auto&& input : inputs) {
            auto& node = nodes[input];
            input_nodes.push_back(node);
        }
        VarNodeArray output_nodes;
        if (op) {
            if (auto* bn = op->try_cast_final<BatchNorm>()) {
                mgb_assert(
                        bn->fwd_mode == BatchNorm::FwdMode::INFERENCE,
                        "cannot dump BatchNorm in training mode, maybe you forgot "
                        "to call model.eval()?");
            }
            output_nodes = OpDef::apply_on_var_node(*op, input_nodes);
            name2ops[output_nodes[0]->owner_opr()->name()].push_back(
                    output_nodes[0]->owner_opr());
        } else {
            // no opr, just forward VarNode
            mgb_assert(
                    inputs.size() == outputs.size(),
                    "output size does not equal input size when forwarding");
            output_nodes = input_nodes;
        }
        mgb_assert(output_nodes.size() == outputs.size(), "output size mismatch");
        for (size_t i = 0; i < outputs.size(); ++i) {
            auto output = outputs[i];
            auto& var = vars[output];
            auto& node = nodes[output];
            mgb_assert(var.kind == VarKind::Internal, "output node should be internal");
            if (!node) {
                node = output_nodes[i];
            }
            if (!var.name.empty()) {
                node->name(var.name);
            }
        }
    }
    for (auto&& [name, ops] : name2ops) {
        if (ops.size() <= 1) {
            continue;
        }
        // ops.size() > 1, need dedup (rename op)
        for (size_t i = 0; i < ops.size(); ++i) {
            auto& op = ops[i];
            auto new_name = ssprintf("%s[%zu]", name.c_str(), i);
            for (auto&& output : op->output()) {
                auto output_name = output->name();
                auto pos = output_name.find(name);
                if (pos != std::string::npos) {
                    output_name.replace(pos, name.length(), new_name);
                }
                output->name(output_name);
            }
            op->name(new_name);
        }
    }
    VarNodeArray output_nodes;
    for (auto&& [output, name] : outputs) {
        mgb_assert(output < vars.size(), "invalid output id %zu", output);
        mgb_assert(nodes[output], "output node invalid");
        if (!name.empty()) {
            nodes[output]->name(name);
        }
        output_nodes.push_back(nodes[output]);
    }
    return output_nodes;
}
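// Recording pass: every tensor flowing through the program is wrapped in a
// TracingValue, and each applied op is appended to m_seq together with the
// ids of its input and output vars, so the computation can later be dumped
// or compiled.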
std::vector<ValueRef> TracingTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_value = op.as<ApplyOp>()) {
        SmallVector<ValueRef> unwrapped_inputs;
        SmallVector<TracingValue::ref_t> wrapped_inputs;
        SmallVector<size_t> input_ids;
        for (auto input : inputs) {
            auto tracing_value = input.as_ref<TracingValue>();
            if (!tracing_value) {
                tracing_value =
                        record_var(input, m_capture_as_const, VarKind::External);
            }
            unwrapped_inputs.push_back(tracing_value->value());
            wrapped_inputs.push_back(tracing_value);
            input_ids.push_back(tracing_value->id());
        }
        // TODO: remove OpDef::set_scope
        auto scopes = Transformation::scopes();
        std::string scopes_join;
        for (auto&& scope : scopes) {
            if (!scopes_join.empty()) {
                scopes_join.push_back('.');
            }
            scopes_join.append(scope);
        }
        const_cast<OpDef&>(op_value->op()).set_scope(scopes_join);
        auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs);
        std::vector<ValueRef> wrapped_outputs;
        SmallVector<size_t> output_ids;
        for (auto&& output : unwrapped_outputs) {
            auto wrapped_output = record_var(output, false, VarKind::Internal);
            wrapped_outputs.push_back(wrapped_output);
            output_ids.push_back(wrapped_output->id());
        }
        m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids});
        return wrapped_outputs;
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        auto outputs = imperative::apply(op, inputs);
        if (create_tensor->kind() == CreateTensor::NoTrace) {
            return outputs;
        }
        bool is_const = create_tensor->kind() == CreateTensor::Const;
        auto wrapped_input = record_var(
                outputs[0], is_const || m_capture_as_const,
                is_const ? VarKind::Constant : VarKind::External);
        auto wrapped_output = record_var(outputs[0], false, VarKind::Internal);
        auto input_id = wrapped_input->id();
        auto output_id = wrapped_output->id();
        m_seq.push_back({{}, {input_id}, {output_id}});
        return {wrapped_output};
    } else if (auto* get_attr = op.as<GetAttr>()) {
        auto unwrapped_input = unwrap_var(inputs[0]);
        auto outputs = imperative::apply(op, unwrapped_input);
        if (auto* tracing_value = inputs[0].as<TracingValue>()) {
            auto& var_info = m_vars[tracing_value->id()];
            switch (get_attr->attr()) {
                case GetAttr::Shape:
                    // TODO: reduce h2d when data or value is available
                    var_info.shape_required = true;
                    break;
                case GetAttr::Data:
                    var_info.data_required = true;
                    break;
                case GetAttr::Value:
                    var_info.value_required = true;
                    break;
                default:
                    break;
            }
        }
        return outputs;
    } else if (auto* trace_mark_var = op.as<TraceMarkVar>()) {
        mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input");
        auto input = inputs[0];
        auto tracing_var = input.as_ref<TracingValue>();
        if (!tracing_var) {
            bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" ||
                            trace_mark_var->mark().substr(0, 6) == "kwarg_";
            if (is_input) {
                tracing_var = record_var(input, false, VarKind::External);
            } else {
                tracing_var = record_var(input, m_capture_as_const, VarKind::External);
            }
        } else {
            input = tracing_var->value();
        }
        auto output = record_var(input, false, VarKind::Internal);
        m_vars[output->id()].mark = trace_mark_var->mark();
        m_seq.push_back({{}, {tracing_var->id()}, {output->id()}});
        return {output};
    } else if (auto* trace_name_var = op.as<RenameValue>()) {
        mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input");
        auto input = inputs[0];
        auto tracing_var = input.as_ref<TracingValue>();
        if (!tracing_var) {
            tracing_var = record_var(input, m_capture_as_const, VarKind::External);
        } else {
            input = tracing_var->value();
        }
        auto output = record_var(input, false, VarKind::Internal);
        m_vars[output->id()].name = trace_name_var->name();
        m_seq.push_back({{}, {tracing_var->id()}, {output->id()}});
        return {output};
    } else if (op.is<GetName>()) {
        mgb_assert(inputs.size() == 1, "GetName expects exactly one input");
        auto input = inputs[0];
        if (auto tracing_var = input.as_ref<TracingValue>()) {
            auto name = m_vars[tracing_var->id()].name;
            if (!name.empty()) {
                return {StringValue::make(name)};
            } else {
                return {ValueRef()};
            }
        }
        return imperative::apply(op, inputs);
    } else {
        // TODO: handle DTRCommand and ...
        return op.fallback(inputs);
    }
}
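// When tracing is torn down, any TracingValue still alive escapes the trace,
// so its data must be captured: mark it data_required and replace the wrapper
// with the underlying value.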
void TracingTransformation::on_unregister() noexcept {
    for (auto&& weak_var : m_weak_vars) {
        if (auto tracing_value = weak_var.lock()) {
            auto& var_info = m_vars[tracing_value->id()];
            var_info.data_required = true;
            tracing_value.reset(tracing_value->value());
        }
    }
    m_weak_vars.clear();
}
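// Lower the recorded sequence into a static ComputingGraph. External vars are
// fed through InputCallback, requested attributes are read back through
// OutputCallback, constants are baked in as ImmutableTensor nodes, and the
// resulting output spec is compiled into m_executable.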
void CompiledTransformation::compile() {
    // these ops require seq order, so we link them to an mm_io_link to ensure order
    static std::unordered_set<Typeinfo*> mm_io_ops = {
            CollectiveComm::typeinfo(), RemoteSend::typeinfo(), RemoteRecv::typeinfo()};
    mgb_assert(!m_executable, "already compiled");
    // FIXME: mm_io_link and io_links should be merged
    SymbolVarArray io_links;
    SymbolVar mm_io_link;
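    // make_input: create an InputCallback opr that pulls its tensor out of a
    // box at execution time; the returned accessor's data_setter fills that
    // box before each run.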
    auto make_input = [&](VarInfo* var_info) {
        mgb_assert(
                var_info->kind == VarKind::External, "input node should be external");
        VarAccessor accessor;
        auto box = make_box<DeviceTensorND>();
        // TODO: attach ref count, release early
        auto outputs = opr::InputCallback::make(
                *m_graph, [box] { return box->take_value(); }, var_info->device,
                var_info->dtype, var_info->shape, io_links, m_input_shape_static);
        // attach input_callback to io_links
        accessor.node = outputs[0].node();
        io_links = {outputs[1]};
        accessor.data_setter = [box](DeviceTensorND data) { box->try_set_value(data); };
        return accessor;
    };
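    // make_output: for every attribute the trace recorded as required
    // (shape/data/value), append an OutputCallback opr that copies the result
    // into a box, and expose a matching getter on the accessor.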
    auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) {
        VarAccessor accessor;
        accessor.node = node.node();
        if (var_info->shape_required) {
            // TODO: use static infer manager for some vars?
            auto box = make_box<TensorShape>();
            auto callback = [box](DeviceTensorND data) {
                box->try_set_value(data.shape());
            };
            SymbolVarArray inputs = io_links;
            inputs.insert(inputs.begin(), node);
            auto output = opr::OutputCallback::make({callback, true, false}, inputs);
            io_links = {output};
            accessor.shape_getter = [box]() -> TensorShape { return box->get_value(); };
        }
        if (var_info->data_required) {
            auto box = make_box<DeviceTensorND>();
            auto callback = [box](DeviceTensorND data) { box->try_set_value(data); };
            SymbolVarArray inputs = io_links;
            inputs.insert(inputs.begin(), node);
            auto output = opr::OutputCallback::make({callback, false, false}, inputs);
            io_links = {output};
            accessor.data_getter = [box]() -> DeviceTensorND {
                return box->get_value();
            };
        }
        if (var_info->value_required) {
            struct ValueWithEvent {
                HostTensorND value;
                CompNode::Event* event = nullptr;
            };
            auto box = make_box<ValueWithEvent>();
            auto event = EventPool::without_timer().alloc_shared(var_info->device);
            auto callback = [box, event](DeviceTensorND data) {
                HostTensorND host_val;
                host_val.copy_from(data);
                if (data.comp_node() != CompNode::default_cpu()) {
                    mgb_assert(data.comp_node() == event->comp_node());
                    event->record();
                    box->try_set_value({host_val, event.get()});
                } else {
                    box->try_set_value({host_val});
                }
            };
            SymbolVarArray inputs = io_links;
            inputs.insert(inputs.begin(), node);
            auto output = opr::OutputCallback::make({callback, false, true}, inputs);
            io_links = {output};
            accessor.value_getter = [box]() -> HostTensorND {
                auto&& [value, event] = box->get_value();
                if (event) {
                    event->host_wait();
                }
                return value;
            };
        }
        return accessor;
    };
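    // make_const: bake the data captured for a constant var into the graph as
    // an ImmutableTensor.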
    auto make_const = [&](TraceResult::VarInfo* var_info) {
        VarAccessor accessor;
        mgb_assert(
                var_info->kind == VarKind::Constant, "const node should be constant");
        HostTensorND host_val = var_info->bound_data.numpy()->as_nd();
        accessor.node = opr::ImmutableTensor::make(*m_graph, host_val).node();
        return accessor;
    };
    std::vector<VarAccessor> var_accessors(m_vars.size());
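    // Replay the recorded sequence on the graph. Input and constant vars are
    // materialized lazily the first time they are consumed; distributed io ops
    // are chained through mm_io_link to preserve their recorded order.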
    for (auto&& item : m_seq) {
        bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo());
        VarNodeArray input_vars;
        for (auto&& input : item.inputs) {
            auto& var = m_vars[input];
            if (!var_accessors[input].node) {
                switch (var.kind) {
                    case VarKind::External:
                        var_accessors[input] = make_input(&var);
                        break;
                    case VarKind::Constant:
                        var_accessors[input] = make_const(&var);
                        break;
                    default:
                        mgb_throw(
                                AssertionError,
                                "internal node should be valid when used as input");
                }
            }
            input_vars.push_back(var_accessors[input].node);
        }
        if (require_link && mm_io_link.node()) {
            mgb_assert(
                    !input_vars.empty(),
                    "io-mm operator should have at least one input");
            input_vars[0] =
                    opr::VirtualDep::make({SymbolVar(input_vars[0]), mm_io_link})
                            .node();
        }
        VarNodeArray output_vars;
        if (item.op) {
            output_vars = OpDef::apply_on_var_node(*item.op, input_vars);
        } else {
            // forward inputs to outputs
            mgb_assert(
                    item.inputs.size() == item.outputs.size(),
                    "output size does not equal input size when forwarding");
            for (auto&& input_var : input_vars) {
                output_vars.push_back(input_var);
            }
        }
        if (require_link) {
            mgb_assert(
                    !item.outputs.empty(),
                    "io-mm operator should have at least one output");
            mm_io_link = SymbolVar(output_vars[0]);
        }
        // init output accessors
        for (size_t i = 0; i < output_vars.size(); ++i) {
            auto output = item.outputs[i];
            auto& node = output_vars[i];
            auto& var = m_vars[output];
            var_accessors[output] = make_output(&var, node);
        }
    }
    ComputingGraph::OutputSpec output_specs;
    // prevent input/output callbacks from being optimized out
    for (auto&& io_link : io_links) {
        output_specs.push_back({io_link, {}});
    }
    // prevent remote io ops from being optimized out
    if (mm_io_link.node()) {
        output_specs.push_back({mm_io_link, {}});
    }
    {
        // set_priority_to_id
        // workaround for having mm_io_link and io_links separated
        auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
            if (opr->node_prop().attribute().priority == 0) {
                opr->node_prop().attribute().priority = opr->id();
            }
        };
        mgb::cg::DepOprIter dep_iter{on_opr};
        for (const auto& output_spec : output_specs) {
            dep_iter.add(output_spec.first);
        }
    }
    m_executable = m_graph->compile(output_specs);
    m_var_accessors = var_accessors;
    m_output_spec = output_specs;
}
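// Rebuild the executable from the cached output spec; used after an exception
// has poisoned the previous executable (see wait()).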
void CompiledTransformation::recompile() {
    mgb_assert(m_executable);
    m_executable = m_graph->compile(m_output_spec);
}
void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) {
    trace_assert(m_value_comparator(lhs, rhs), "tensors not equal");
}
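// Check that the id-th runtime input matches what was recorded: external vars
// must agree on dtype/comp_node (or on the captured data), constants must
// compare equal, and internal vars must be the TracedValue produced at the
// same position in the sequence. External data is fed into the graph through
// the accessor's data_setter.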
void CompiledTransformation::trace_input(size_t id, ValueRef value) {
    try {
        auto& var = m_vars[id];
        auto& var_accessor = m_var_accessors[id];
        switch (var.kind) {
            case VarKind::External: {
                trace_assert(
                        !value.is<TracedValue>(), "expect external node, got internal");
                if (var.bound_data) {
                    assert_tensor_equal(var.bound_data, value);
                } else {
                    DType dtype = *value.dtype();
                    CompNode device = *value.device();
                    trace_assert(
                            var.dtype == dtype, "dtype mismatch: %s vs %s",
                            var.dtype.name(), dtype.name());
                    trace_assert(
                            var.device == device, "comp_node mismatch: %s vs %s",
                            var.device.to_string().c_str(), device.to_string().c_str());
                }
                var_accessor.data_setter(value.dev_tensor()->as_nd());
                break;
            }
            case VarKind::Constant: {
                mgb_assert(var.bound_data, "const var without data bound");
                assert_tensor_equal(var.bound_data, value);
                break;
            }
            case VarKind::Internal: {
                trace_assert(
                        value.is<TracedValue>(), "expect internal node, got external");
                auto& traced_value = value.cast<TracedValue>();
                trace_assert(traced_value.id() == id, "input id mismatch");
                break;
            }
        }
    } catch (TraceError&) {
        throw;
    } catch (...) {
        mgb_assert(false, "unexpected error");
    }
}
TracedValue::ref_t CompiledTransformation::trace_output(size_t id) {
    auto traced_value = TracedValue::make(id);
    m_weak_values.push_back(traced_value);
    return traced_value;
}
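// Program counter over the recorded sequence; running past the end means the
// replayed function issued more ops than were recorded.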
TraceResult::SeqItem& CompiledTransformation::next_instruction() {
    trace_assert(m_pc < m_seq.size(), "too many instructions");
    return m_seq[m_pc++];
}
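// Replay pass: each incoming operator must match the next recorded
// instruction; inputs are validated and fed via trace_input(), and outputs
// are returned as TracedValue placeholders whose attributes are served by the
// compiled graph.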
std::vector<ValueRef> CompiledTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_value = op.as<ApplyOp>()) {
        auto& item = next_instruction();
        SmallVector<ValueRef> unwrapped_inputs;
        SmallVector<ValueRef> wrapped_inputs;
        trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
        trace_assert(op_value->op().is_same(*item.op), "operator mismatch");
        for (size_t i = 0; i < inputs.size(); ++i) {
            trace_input(item.inputs[i], inputs[i]);
        }
        std::vector<ValueRef> outputs;
        for (auto&& output_id : item.outputs) {
            outputs.push_back(trace_output(output_id));
        }
        return outputs;
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->kind() == CreateTensor::NoTrace) {
            return imperative::apply(op, inputs);
        }
        auto& item = next_instruction();
        trace_assert(item.op == nullptr, "operator mismatch");
        auto input_id = item.inputs[0];
        auto output_id = item.outputs[0];
        auto tensor = imperative::apply(op, inputs)[0];
        trace_input(input_id, tensor);
        return {trace_output(output_id)};
    } else if (auto* get_attr = op.as<GetAttr>()) {
        if (auto* traced_value = inputs[0].as<TracedValue>()) {
            ValueRef output;
            auto& var = m_vars[traced_value->id()];
            auto& var_accessor = m_var_accessors[traced_value->id()];
            switch (get_attr->attr()) {
                case GetAttr::Shape:
                    trace_assert(var_accessor.shape_getter, "shape unreadable");
                    output = ShapeValue::make(
                            ValueShape::from(var_accessor.shape_getter()));
                    break;
                case GetAttr::Data:
                    trace_assert(var_accessor.data_getter, "data unreadable");
                    output = DeviceValue::make(var_accessor.data_getter());
                    break;
                case GetAttr::Value:
                    trace_assert(var_accessor.value_getter, "value unreadable");
                    output = HostValue::make(var_accessor.value_getter());
                    break;
                case GetAttr::DType:
                    output = DTypeValue::make(var.dtype);
                    break;
                case GetAttr::Device:
                    output = CompNodeValue::make(var.device);
                    break;
                default:
                    break;
            }
            return {output};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (auto* trace_mark_var = op.as<TraceMarkVar>()) {
        auto& item = next_instruction();
        trace_assert(item.op == nullptr, "operator mismatch");
        trace_assert(item.inputs.size() == 1, "inputs size mismatch");
        trace_assert(item.outputs.size() == 1, "outputs size mismatch");
        trace_input(item.inputs[0], inputs[0]);
        trace_assert(
                trace_mark_var->mark() == m_vars[item.outputs[0]].mark,
                "mark mismatch");
        return {trace_output(item.outputs[0])};
    } else if (auto* trace_name_var = op.as<RenameValue>()) {
        auto& item = next_instruction();
        trace_assert(item.op == nullptr, "operator mismatch");
        trace_assert(item.inputs.size() == 1, "inputs size mismatch");
        trace_assert(item.outputs.size() == 1, "outputs size mismatch");
        trace_input(item.inputs[0], inputs[0]);
        trace_assert(
                trace_name_var->name() == m_vars[item.outputs[0]].name,
                "name mismatch");
        return {trace_output(item.outputs[0])};
    } else {
        return op.fallback(inputs);
    }
}
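// When the compiled trace is torn down, TracedValue placeholders still
// referenced outside the trace are resolved to concrete tensors by reading
// their device data back from the graph; failures are recorded and surfaced
// as an ErrorValue.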
void CompiledTransformation::on_unregister() noexcept {
    // resolve pending values
    for (auto&& weak_value : m_weak_values) {
        if (auto traced_value = weak_value.lock()) {
            auto& var_accessor = m_var_accessors[traced_value->id()];
            auto value = ([&]() -> ValueRef {
                try {
                    trace_assert(var_accessor.data_getter, "data unreadable");
                    auto dev_value = DeviceValue::make(var_accessor.data_getter());
                    return imperative::apply(
                            CreateTensor(
                                    CreateTensor::Common, dev_value->device(),
                                    dev_value->dtype(), dev_value->shape()),
                            DeviceStorage::make(dev_value->storage()))[0];
                } catch (...) {
                    set_exception(std::current_exception());
                    return ErrorValue::make("trace exit failed");
                }
            })();
            traced_value.reset(value);
        }
    }
    m_weak_values.clear();
}
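// Launch the compiled executable on a worker thread so the caller thread can
// keep feeding inputs and reading boxed results; any exception is captured
// and forwarded through set_exception().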
void CompiledTransformation::execute() {
    mgb_assert(m_executable != nullptr);
    m_graph_executor = std::thread([&] {
        try {
            m_executable->execute();
            m_executable->wait();
        } catch (...) {
            auto exc = std::current_exception();
            set_exception(exc);
        }
    });
}
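// Join the executor thread and reset per-run state (boxes and the program
// counter). A graph that raised an exception cannot be reused, so it is
// recompiled before the exception is rethrown.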
void CompiledTransformation::wait() {
    try {
        trace_assert(m_pc == m_seq.size(), "premature end");
    } catch (...) {
    }
    mgb_assert(m_executable != nullptr);
    m_graph_executor.join();
    m_graph_executor = {};
    for (auto&& box : m_boxes) {
        box->reset();
    }
    m_pc = 0;
    std::exception_ptr graph_exc;
    std::swap(m_graph_exc, graph_exc);
    if (graph_exc) {
        // graph with exception cannot be reused
        recompile();
        std::rethrow_exception(graph_exc);
    }
}
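// Record the first exception and propagate it into every box so that getters
// blocked on unfilled boxes wake up; later exceptions are dropped in favor of
// the first.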
std::exception_ptr CompiledTransformation::set_exception(
        std::exception_ptr exc) noexcept {
    MGB_LOCK_GUARD(m_mutex);
    if (m_graph_exc) {
        return m_graph_exc;
    }
    for (auto&& box : m_boxes) {
        box->try_set_exception(exc);
    }
    m_graph_exc = exc;
    return m_graph_exc;
}
}  // namespace imperative
}  // namespace mgb