You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

layout_transform_context.cpp 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. /**
  2. * \file src/gopt/impl/layout_transform_context.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "./utils.h"
  13. #include "megbrain/gopt/global_layout_transform.h"
  14. #include "megbrain/opr/dnn/pooling.h"
  15. #include "megbrain/opr/imgproc.h"
  16. #include "megbrain/opr/nn_int.h"
  17. using namespace mgb;
  18. using namespace gopt;
  19. namespace {
  20. using OprFormat = LayoutTransformContext::OprFormat;
  21. using OprList = LayoutTransformContext::OprList;
  22. using Attribute = LayoutTransformContext::Attribute;
  23. using Target = LayoutTransformContext::Target;
  24. const char* target_to_string(Target target) {
  25. #define cb(_target) \
  26. case Target::_target: \
  27. return #_target
  28. switch (target) {
  29. cb(CUDA);
  30. cb(X86);
  31. cb(ARM);
  32. cb(UNSPEC);
  33. default:
  34. mgb_assert(false, "unsupported target (got:%u)",
  35. static_cast<uint32_t>(target));
  36. }
  37. #undef cb
  38. }
  39. std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
  40. OprFormat base_opr_format, TensorFormats base_tensor_format) {
  41. OprList opr_list = {
  42. opr::ConvBiasForward::typeinfo(),
  43. opr::ConvolutionForward::typeinfo(),
  44. opr::ConvolutionBackwardData::typeinfo(),
  45. opr::ElemwiseMultiType::typeinfo(),
  46. opr::Elemwise::typeinfo(),
  47. opr::TypeCvt::typeinfo(),
  48. opr::PoolingForward::typeinfo(),
  49. opr::WarpPerspectiveForward::typeinfo(),
  50. };
  51. SmallVector<TensorFormats> available_tensor_formats = {
  52. TensorFormats::NCHW, TensorFormats::NHWC,
  53. TensorFormats::NCHWc4, TensorFormats::NCHWc32,
  54. TensorFormats::NCHWc64, TensorFormats::CHWNc4};
  55. Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
  56. auto ctx = std::make_unique<LayoutTransformContext>(
  57. std::move(opr_list), std::move(available_tensor_formats),
  58. attribute);
  59. ctx->add_opr_config(
  60. opr::ConvBiasForward::typeinfo(),
  61. {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
  62. OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
  63. .add_opr_config(opr::ConvolutionForward::typeinfo(),
  64. {OprFormat::NCHW, OprFormat::NCHW4})
  65. .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
  66. {OprFormat::NCHW, OprFormat::NCHW4})
  67. .add_opr_config(
  68. opr::PoolingForward::typeinfo(),
  69. {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
  70. OprFormat::NCHW64, OprFormat::CHWN4})
  71. .add_opr_config(
  72. opr::WarpPerspectiveForward::typeinfo(),
  73. {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
  74. return ctx;
  75. }
  76. } // namespace
  77. /* ================= LayoutTransformContext ==================*/
  78. LayoutTransformContext& LayoutTransformContext::add_opr_config(
  79. Typeinfo* opr, OprFormat opr_format) {
  80. auto& dispatchers = m_opr_configs[opr];
  81. dispatchers[opr_format] =
  82. OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
  83. opr, opr_format);
  84. return *this;
  85. }
  86. LayoutTransformContext& LayoutTransformContext::add_opr_config(
  87. Typeinfo* opr, SmallVector<OprFormat> opr_formats) {
  88. auto& dispatchers = m_opr_configs[opr];
  89. for (auto opr_fmt : opr_formats) {
  90. dispatchers[opr_fmt] =
  91. OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
  92. opr, opr_fmt);
  93. }
  94. return *this;
  95. }
  96. std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
  97. Target target, OprFormat base_opr_format,
  98. TensorFormats base_tensor_format) {
  99. switch (target) {
  100. case Target::CUDA:
  101. return make_cuda_ctx(base_opr_format, base_tensor_format);
  102. default:
  103. mgb_assert(false, "unsupported target %s\n",
  104. target_to_string(target));
  105. }
  106. }
  107. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台