
opr_attr.cpp

/**
 * \file imperative/src/impl/ops/opr_attr.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/serialization/opr_load_dump.h"

#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"

namespace mgb {
namespace imperative {
namespace {
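// Replays the raw parameter bytes stored in an OprAttr through the
// serialization OprLoadContext interface so a registered operator loader can
// be reused; tensor loading and config queries are unsupported and assert.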
class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
    const OprAttr::Param& m_param;
    size_t m_pos = 0;
    ComputingGraph* m_graph;

    void read_raw(void* dest, size_t size) override final {
        mgb_assert(m_pos + size <= m_param.size(), "too many bytes requested");
        memcpy(dest, m_param.data() + m_pos, size);
        m_pos += size;
    }

    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }

    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }

    const serialization::GraphLoadConfig& config() const override { mgb_assert(0); }

public:
    OprParamsLoadContext(const OprAttr::Param& param, ComputingGraph* graph)
            : serialization::OprLoadContextRawPOD(false),
              m_param(param),
              m_graph(graph) {}

    ~OprParamsLoadContext() {
        mgb_assert(m_pos == m_param.size(), "param not fully consumed");
    }

    ComputingGraph& graph() override { return *m_graph; }
};
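// Captures the raw bytes written by a registered operator dumper into
// m_param; tensor dumping and config queries are unsupported and assert.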
class OprParamsDumpContext final : public serialization::OprDumpContextRawPOD {
public:
    OprAttr::Param m_param;

    OprParamsDumpContext() : serialization::OprDumpContextRawPOD(false) {}

    void write_raw(const void* data, size_t size) {
        const char* src = static_cast<const char*>(data);
        m_param.insert(m_param.end(), src, src + size);
    }

    void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
            TensorWriteMethod method) {
        mgb_assert(0);
    }

    const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
};
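// Recreates the concrete graph operator from an OprAttr: look up the
// operator's serialization loader by type name and replay the stored
// parameter bytes as its input stream.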
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& attr = def.cast_final_safe<OprAttr>();
    auto config = attr.config;
    config.name(attr.make_name());
    mgb_assert(!inputs.empty());
    auto registry = serialization::OprRegistry::find_by_name(attr.type);
    mgb_assert(registry, "operator %s not found", attr.type.c_str());
    OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
    return registry->loader(ctx, inputs, config);
}
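// Converts an existing graph operator back into an OprAttr by running its
// registered dumper and capturing the produced parameter bytes.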
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
    OprParamsDumpContext ctx;
    auto registry = serialization::OprRegistry::find_by_type(opr->dyn_typeinfo());
    mgb_assert(registry, "operator %s not found", opr->dyn_typeinfo()->name);
    mgb_assert(registry->dumper, "operator %s cannot be serialized",
               opr->dyn_typeinfo()->name);
    registry->dumper(ctx, *opr);
    return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config());
}
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    return {};
}

std::string make_name(const OpDef& def) {
    return "OprAttr";
}

OP_TRAIT_REG(OprAttr, OprAttr)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .props(props)
        .make_name(make_name)
        .fallback();

}  // anonymous namespace
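// Two OprAttr instances are considered identical when the type name, the raw
// parameter bytes, and the comp_node / output dtype of the config all match.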
bool OprAttr::is_same_st(const Hashable& rhs_) const {
    auto&& rhs = static_cast<const OprAttr&>(rhs_);
    return type == rhs.type && param == rhs.param &&
           config.comp_node() == rhs.config.comp_node() &&
           config.output_dtype() == rhs.config.output_dtype();
}

size_t OprAttr::hash() const {
    return hash_pair_combine(
            hash_pair_combine(
                    mgb::hash(type),
                    mgb::hash(static_cast<std::vector<char>>(param))),
            config.hash());
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package already bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is properly installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
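For a quick local sanity check that the bundled CUDA runtime actually sees a GPU, here is a minimal Python sketch; it assumes the MegEngine device helpers is_cuda_available, get_device_count and set_default_device, so verify the names against your installed version:

import megengine as mge

# Probe for a usable GPU before running GPU code; fall back to CPU otherwise.
# (is_cuda_available / get_device_count / set_default_device are assumed to be
# the MegEngine device helpers; check them against your installed version.)
if mge.is_cuda_available():
    print("GPU detected:", mge.get_device_count("gpu"), "device(s)")
    mge.set_default_device("gpu0")  # place subsequent tensors/ops on the first GPU
else:
    print("No usable GPU found, running on CPU")
    mge.set_default_device("cpu0")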