You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_trait.h 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. /**
  2. * \file imperative/src/impl/op_trait.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  #include <utility>

  #include "megbrain/imperative/op_def.h"
  13. namespace mgb {
  14. namespace imperative {
  15. namespace detail {
  16. template<typename Signature>
  17. struct OpMeth;
  18. template<typename RType, typename ...Args>
  19. struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
  20. using Base = thin_function<RType(Args...)>;
  21. using Base::Base;
  22. RType operator()(Args... args) const {
  23. if (!this->Base::operator bool()) {
  24. mgb_throw(MegBrainError, "Not Implemented");
  25. }
  26. return this->Base::operator ()(args...);
  27. }
  28. };
  29. } // detail
  30. using OpDefMaker = detail::OpMeth<
  31. decltype(OpDef::make_from_op_node)>;
  32. using ApplyOnPhysicalTensor = detail::OpMeth<
  33. decltype(OpDef::apply_on_physical_tensor)>;
  34. using ApplyOnVarNode = detail::OpMeth<
  35. decltype(OpDef::apply_on_var_node)>;
  36. using InferOutputAttrsFallible = detail::OpMeth<
  37. decltype(OpDef::infer_output_attrs_fallible)>;
  38. using GradMaker = detail::OpMeth<
  39. decltype(OpDef::make_backward_graph)>;
  40. struct OpTrait {
  41. const char* name;
  42. OpDefMaker make_from_op_node;
  43. ApplyOnPhysicalTensor apply_on_physical_tensor;
  44. ApplyOnVarNode apply_on_var_node;
  45. InferOutputAttrsFallible infer_output_attrs_fallible;
  46. GradMaker make_backward_graph;
  47. OpTrait(const char* name);
  48. static OpTrait* find_by_name(const char* name);
  49. static OpTrait* find_by_typeinfo(Typeinfo* type);
  50. static void for_each_trait(thin_function<void(OpTrait&)> visitor);
  51. };
  /*!
   * X-macro: expands \p cb once for each operator method name, in the order
   * the corresponding slots are declared in OpTrait.
   */
  #define FOR_EACH_OP_METH(cb)        \
      cb(make_from_op_node)           \
      cb(apply_on_physical_tensor)    \
      cb(apply_on_var_node)           \
      cb(infer_output_attrs_fallible) \
      cb(make_backward_graph)
  58. struct OpTraitRegistry {
  59. OpTrait* trait;
  60. #define DECL(meth) \
  61. OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
  62. mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
  63. trait->meth = f; \
  64. return *this; \
  65. }
  66. FOR_EACH_OP_METH(DECL)
  67. #undef DECL
  68. OpTraitRegistry& fallback();
  69. template<typename T>
  70. void insert() {
  71. do_insert(T::typeinfo());
  72. }
  73. template<typename T0, typename T1, typename ...Ts>
  74. void insert() {
  75. insert<T0>();
  76. insert<T1, Ts...>();
  77. }
  78. template<typename ...Args>
  79. static OpTraitRegistry insert(const char* name) {
  80. auto&& ret = do_insert(name);
  81. ret.insert<Args...>();
  82. return ret;
  83. }
  84. void do_insert(Typeinfo* type);
  85. static OpTraitRegistry do_insert(const char* name);
  86. };
  87. } // namespace imperative
  88. } // namespace mgb
  /*!
   * Defines a static OpTraitRegistry for the trait called \p name and
   * registers the op types given as the remaining arguments with it. The
   * expansion ends in an expression, so further setter calls can be chained
   * after the macro invocation before the terminating semicolon.
   *
   * The generated variable name avoids double underscores (such identifiers
   * are reserved for the implementation), and the type is fully qualified so
   * the macro can be used from any namespace.
   */
  #define OP_TRAIT_REG(name, ...)                                       \
      static ::mgb::imperative::OpTraitRegistry                         \
              name##_op_trait_global_registry =                         \
                      ::mgb::imperative::OpTraitRegistry::insert<__VA_ARGS__>(#name)
  92. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台