You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_shallow_copy.cpp 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. /**
  2. * \file src/serialization/impl/opr_shallow_copy.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  7. *
  8. */
  9. #include "megbrain/serialization/opr_shallow_copy.h"
  10. #include "megbrain/gopt/basic_arith.h"
  11. #include "megbrain/serialization/opr_load_dump.h"
  12. #include "megbrain/serialization/opr_registry.h"
  13. #include "megbrain/utils/big_key_hashmap.h"
  14. using namespace mgb;
  15. using namespace serialization;
  16. namespace {
  17. //! dump single opr to memory for shallow copy
  18. class OprDumpContextMemory final : public OprDumpContextRawPOD {
  19. std::vector<uint8_t> m_buf;
  20. void write_raw(const void* data, size_t size) override {
  21. auto pos = m_buf.size();
  22. auto end = pos + size;
  23. if (end > m_buf.capacity())
  24. m_buf.reserve(end * 2);
  25. m_buf.resize(end);
  26. memcpy(m_buf.data() + pos, data, size);
  27. }
  28. void dump_tensor(
  29. const std::string&, const HostTensorND&, TensorWriteMethod) override {
  30. mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
  31. }
  32. const GraphDumpConfig& config() const override {
  33. mgb_throw(GraphError, "OprDumpContextMemory has no associated config");
  34. }
  35. public:
  36. OprDumpContextMemory() : OprDumpContextRawPOD(false) {}
  37. auto&& buf() const { return m_buf; }
  38. };
  39. //! load single opr from memory for shallow copy
  40. class OprLoadContextMemory final : public OprLoadContextRawPOD {
  41. const uint8_t* m_ptr;
  42. size_t m_size, m_pos = 0;
  43. ComputingGraph* m_graph;
  44. void read_raw(void* dest, size_t size) override {
  45. auto end = m_pos + size;
  46. mgb_assert(end <= m_size);
  47. memcpy(dest, m_ptr + m_pos, size);
  48. m_pos = end;
  49. }
  50. ComputingGraph& graph() override { return *m_graph; }
  51. std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }
  52. std::shared_ptr<DeviceTensorND> load_tensor_shared(
  53. bool copy_immediatly = false) override {
  54. (void)copy_immediatly;
  55. mgb_assert(0);
  56. }
  57. const GraphLoadConfig& config() const override {
  58. mgb_throw(GraphError, "OprLoadContextMemory has no associated config");
  59. }
  60. public:
  61. OprLoadContextMemory(ComputingGraph* graph, const OprDumpContextMemory& dumper)
  62. : OprLoadContextRawPOD(false),
  63. m_ptr{dumper.buf().data()},
  64. m_size{dumper.buf().size()},
  65. m_graph{graph} {}
  66. ~OprLoadContextMemory() { mgb_assert(m_pos == m_size); }
  67. };
//! per-graph cache of shallow copies, keyed by (source opr, inputs, config);
//! attached to the graph options' user_data and looked up in copy_opr_shallow
class ShallowCopyCacheContainer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    //! hash/equality policies for the heterogeneous cache key parts
    struct HashEq {
        //! generic equality for pointer-like key parts
        template <typename T>
        static bool eq(const T& x, const T& y) {
            return x == y;
        }
        //! configs compare by semantic equivalence, not object identity
        static bool eq(const OperatorNodeConfig& x, const OperatorNodeConfig& y) {
            return x.is_same(y);
        }
        static size_t hash(const void* ptr) { return std::hash<const void*>{}(ptr); }
        //! hash the raw VarNode* array contents
        static size_t hash(const VarNodeArray& inputs) {
            return PODHash<VarNode*>::perform(inputs.data(), inputs.size());
        }
        static size_t hash(const OperatorNodeConfig& config) { return config.hash(); }
    };

public:
    // source opr is stored by value (Copy); inputs and config are stored by
    // reference (Ref) to avoid copying them into the key
    big_key_hash_map::BigKeyHashMap<
            cg::OperatorNodeBase*, HashEq,
            big_key_hash_map::Copy<const cg::OperatorNodeBase*>,
            big_key_hash_map::Ref<VarNodeArray>,
            big_key_hash_map::Ref<OperatorNodeConfig>>
            cache;
};
MGB_TYPEINFO_OBJ_IMPL(ShallowCopyCacheContainer);
  93. } // anonymous namespace
  94. ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
  95. const cg::OperatorNodeBase& opr, const VarNodeArray& inputs) const {
  96. if (!m_owner_graph) {
  97. if (inputs.empty())
  98. return opr.owner_graph();
  99. return inputs[0]->owner_graph();
  100. }
  101. if (!inputs.empty())
  102. mgb_assert(m_owner_graph == inputs[0]->owner_graph());
  103. return m_owner_graph;
  104. }
/*!
 * \brief make a shallow copy of \p opr with new \p inputs and \p config
 *
 * Uses the type-specific shallow_copy callback from OprRegistry when one is
 * registered, otherwise the default dump/load based implementation. Copies
 * within the same graph are cached per (opr, inputs, config).
 */
cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    OprShallowCopy shallow_copy = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        shallow_copy = registry->shallow_copy;
    } else {
        // no registered callback: fall back to dump+load round trip
        shallow_copy = intl::copy_opr_shallow_default_impl;
    }
    mgb_assert(inputs.size() == opr.input().size());
    auto dst_og = ctx.owner_graph(opr, inputs);
    // perform the actual copy and fix up node attributes when a new opr node
    // was actually created (copy into another graph, or graph grew)
    auto do_copy = [&]() {
        // snapshot opr count to detect whether shallow_copy inserted a new
        // node rather than returning an existing/deduplicated one
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = shallow_copy(ctx, opr, inputs, config);
        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
            auto&& attr = ret->node_prop().attribute();
            if (!attr.src_opr) {
                // record the root source opr so copy chains stay traceable;
                // skip the self-reference case
                auto src = cg::get_opr_root_source_opr(
                        const_cast<cg::OperatorNodeBase*>(&opr));
                if (ret != src)
                    attr.src_opr = src;
            }
            if (!attr.priority) {
                // priority may have been changed by OprInserted event handlers
                // (like in python case)
                attr.priority = opr.node_prop().attribute().priority;
            }
        }
        return ret;
    };
    cg::OperatorNodeBase* ret;
    if (dst_og == opr.owner_graph()) {
        // use cache for copy in same graph
        auto&& cache =
                dst_og->options()
                        .user_data.get_user_data_or_create<ShallowCopyCacheContainer>()
                        ->cache;
        auto ins = cache.get(&opr, inputs, config);
        // NOTE(review): ins.first appears to indicate a newly inserted cache
        // entry and ins.second the cached value slot — confirm against
        // BigKeyHashMap::get
        if (ins.first) {
            *ins.second = do_copy();
        } else {
            // cache hit: refresh output shapes, which may have changed since
            // the cached copy was made
            cg::update_output_var_shapes(*ins.second);
        }
        ret = *ins.second;
    } else {
        ret = do_copy();
    }
    // sanity-check the copy; inplace arith optimization may legitimately
    // change the output structure, so it is exempted
    mgb_assert(
            gopt::has_inplace_basic_arith_opt(opr) ||
                    (( // outputs match
                             opr.usable_output().size() ==
                             ret->usable_output().size()) &&
                     ( // new opr is returned
                             (&opr != ret) || opr.input() == inputs)),
            "bad opr copy: src=%s{%s} dst=%s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name, ret->cname(), ret->dyn_typeinfo()->name);
    return ret;
}
  164. cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
  165. const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
  166. const VarNodeArray& inputs, const OperatorNodeConfig& config) {
  167. MGB_MARK_USED_VAR(ctx);
  168. OprDumper opr_dumper = nullptr;
  169. OprLoaderWrapper opr_loader = nullptr;
  170. if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
  171. opr_loader = registry->loader;
  172. opr_dumper = registry->dumper;
  173. } else {
  174. auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo(
  175. opr.dyn_typeinfo(), CURRENT_VERSION);
  176. opr_loader = registryv2->loader;
  177. opr_dumper = registryv2->dumper;
  178. }
  179. mgb_assert(
  180. opr_dumper && opr_loader,
  181. "can not shallow_copy operator %s{%s}: "
  182. "no dumper/loader registered",
  183. opr.cname(), opr.dyn_typeinfo()->name);
  184. OprDumpContextMemory memory_dumper;
  185. opr_dumper(memory_dumper, opr);
  186. OprLoadContextMemory loader{opr.owner_graph(), memory_dumper};
  187. return opr_loader(loader, inputs, config).opr();
  188. }
  189. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}