
opr_shallow_copy.cpp

/**
 * \file src/serialization/impl/opr_shallow_copy.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 */
#include "megbrain/serialization/opr_shallow_copy.h"

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/utils/big_key_hashmap.h"

#include <cstring>  // memcpy

using namespace mgb;
using namespace serialization;

namespace {
//! dump single opr to memory for shallow copy
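// The dump context below writes an operator's serialized form into a growable
// in-memory byte buffer rather than a file; tensor payloads are rejected
// because a shallow copy only needs the operator's POD parameters.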
class OprDumpContextMemory final : public OprDumpContextRawPOD {
    std::vector<uint8_t> m_buf;

    void write_raw(const void* data, size_t size) override {
        auto pos = m_buf.size();
        auto end = pos + size;
        if (end > m_buf.capacity())
            m_buf.reserve(end * 2);  // grow geometrically to amortize reallocations
        m_buf.resize(end);
        memcpy(m_buf.data() + pos, data, size);
    }

    void dump_tensor(
            const std::string&, const HostTensorND&, TensorWriteMethod,
            TensorFormat format = {}) override {
        MGB_MARK_USED_VAR(format);
        mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
    }

    const GraphDumpConfig& config() const override {
        mgb_throw(GraphError, "OprDumpContextMemory has no associated config");
    }

public:
    OprDumpContextMemory() : OprDumpContextRawPOD(false) {}

    auto&& buf() const { return m_buf; }
};
//! load single opr from memory for shallow copy
class OprLoadContextMemory final : public OprLoadContextRawPOD {
    const uint8_t* m_ptr;
    size_t m_size, m_pos = 0;
    ComputingGraph* m_graph;

    void read_raw(void* dest, size_t size) override {
        auto end = m_pos + size;
        mgb_assert(end <= m_size);
        memcpy(dest, m_ptr + m_pos, size);
        m_pos = end;
    }

    ComputingGraph& graph() override { return *m_graph; }

    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }

    std::shared_ptr<DeviceTensorND> load_tensor_shared(
            bool copy_immediatly = false) override {
        // note: the spelling of copy_immediatly follows the base-class declaration
        (void)copy_immediatly;
        mgb_assert(0);
    }

    const GraphLoadConfig& config() const override {
        mgb_throw(GraphError, "OprLoadContextMemory has no associated config");
    }

public:
    OprLoadContextMemory(ComputingGraph* graph, const OprDumpContextMemory& dumper)
            : OprLoadContextRawPOD(false),
              m_ptr{dumper.buf().data()},
              m_size{dumper.buf().size()},
              m_graph{graph} {}

    // ensure the loader consumed exactly what the dumper produced
    ~OprLoadContextMemory() { mgb_assert(m_pos == m_size); }
};
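// Together, OprDumpContextMemory and OprLoadContextMemory implement the
// default shallow copy: the source operator's parameters are serialized to
// memory and immediately deserialized to construct the copy on new inputs.

// Per-graph cache for shallow copies, stored as graph user data and keyed by
// (source operator pointer, input vars, config); see copy_opr_shallow below.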
class ShallowCopyCacheContainer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    struct HashEq {
        template <typename T>
        static bool eq(const T& x, const T& y) {
            return x == y;
        }
        static bool eq(const OperatorNodeConfig& x, const OperatorNodeConfig& y) {
            return x.is_same(y);
        }
        static size_t hash(const void* ptr) { return std::hash<const void*>{}(ptr); }
        static size_t hash(const VarNodeArray& inputs) {
            return PODHash<VarNode*>::perform(inputs.data(), inputs.size());
        }
        static size_t hash(const OperatorNodeConfig& config) { return config.hash(); }
    };

public:
    big_key_hash_map::BigKeyHashMap<
            cg::OperatorNodeBase*, HashEq,
            big_key_hash_map::Copy<const cg::OperatorNodeBase*>,
            big_key_hash_map::Ref<VarNodeArray>,
            big_key_hash_map::Ref<OperatorNodeConfig>>
            cache;
};
MGB_TYPEINFO_OBJ_IMPL(ShallowCopyCacheContainer);
}  // anonymous namespace
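// Resolve the destination graph for a copy: an explicitly set graph wins (and
// must agree with the inputs' owner graph); otherwise the graph is inferred
// from the first input var, falling back to the source operator's own graph.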
ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs) const {
    if (!m_owner_graph) {
        if (inputs.empty())
            return opr.owner_graph();
        return inputs[0]->owner_graph();
    }
    if (!inputs.empty())
        mgb_assert(m_owner_graph == inputs[0]->owner_graph());
    return m_owner_graph;
}
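// Copy an operator onto (possibly different) inputs. Dispatch order: a
// shallow-copy functor registered for the operator type wins; otherwise the
// generic dump-and-reload implementation at the end of this file is used.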
cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    OprShallowCopy shallow_copy = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        shallow_copy = registry->shallow_copy;
    } else {
        shallow_copy = intl::copy_opr_shallow_default_impl;
    }
    mgb_assert(inputs.size() == opr.input().size());

    auto dst_og = ctx.owner_graph(opr, inputs);
    auto do_copy = [&]() {
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = shallow_copy(ctx, opr, inputs, config);
        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
            // the copy lives in another graph, or a new opr was inserted:
            // fix up source-opr tracking and priority on the result
            auto&& attr = ret->node_prop().attribute();
            if (!attr.src_opr) {
                auto src = cg::get_opr_root_source_opr(
                        const_cast<cg::OperatorNodeBase*>(&opr));
                if (ret != src)
                    attr.src_opr = src;
            }
            if (!attr.priority) {
                // priority may have been changed by OprInserted event handlers
                // (like in the Python case)
                attr.priority = opr.node_prop().attribute().priority;
            }
        }
        return ret;
    };
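    // For same-graph copies, consult the per-graph cache: cache.get() returns
    // an (inserted, value_ptr) pair, so a newly inserted entry triggers an
    // actual copy, while a hit reuses the cached opr after refreshing its
    // output var shapes.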
    cg::OperatorNodeBase* ret;
    if (dst_og == opr.owner_graph()) {
        // use cache for copies within the same graph
        auto&& cache =
                dst_og->options()
                        .user_data.get_user_data_or_create<ShallowCopyCacheContainer>()
                        ->cache;
        auto ins = cache.get(&opr, inputs, config);
        if (ins.first) {
            *ins.second = do_copy();
        } else {
            cg::update_output_var_shapes(*ins.second);
        }
        ret = *ins.second;
    } else {
        ret = do_copy();
    }
    mgb_assert(
            gopt::has_inplace_basic_arith_opt(opr) ||
                    (( // outputs match
                             opr.usable_output().size() ==
                             ret->usable_output().size()) &&
                     ( // a new opr is returned, or the same opr with unchanged inputs
                             (&opr != ret) || opr.input() == inputs)),
            "bad opr copy: src=%s{%s} dst=%s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name, ret->cname(), ret->dyn_typeinfo()->name);
    return ret;
}
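// A minimal usage sketch (hypothetical `opr`, `x`, `y`; assumes the defaulted
// trailing arguments declared in opr_shallow_copy.h): copy an existing
// operator onto new input vars in the same graph, where the per-graph cache
// above makes repeated identical copies cheap:
//
//     VarNodeArray new_inputs{x, y};  // x, y: VarNode* already in the graph
//     cg::OperatorNodeBase* copy = serialization::copy_opr_shallow(
//             *opr, new_inputs, opr->config());
//     // copy computes the same operation as opr, but reads from x and y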
cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    OprDumper opr_dumper = nullptr;
    OprLoaderWrapper opr_loader = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        opr_loader = registry->loader;
        opr_dumper = registry->dumper;
    } else {
        auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo(
                opr.dyn_typeinfo(), CURRENT_VERSION);
        // guard against a missing V2 registry entry before dereferencing
        mgb_assert(
                registryv2, "no serialization registry found for operator %s{%s}",
                opr.cname(), opr.dyn_typeinfo()->name);
        opr_loader = registryv2->loader;
        opr_dumper = registryv2->dumper;
    }
    mgb_assert(
            opr_dumper && opr_loader,
            "can not shallow_copy operator %s{%s}: "
            "no dumper/loader registered",
            opr.cname(), opr.dyn_typeinfo()->name);

    // round-trip the operator through an in-memory buffer: dump its POD
    // params, then reload them to build an equivalent opr on the new inputs
    OprDumpContextMemory memory_dumper;
    opr_dumper(memory_dumper, opr);
    OprLoadContextMemory loader{opr.owner_graph(), memory_dumper};
    return opr_loader(loader, inputs, config).opr();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}