You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

halide_executable.cpp 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. /**
  2. * \file src/jit/impl/halide/halide_executable.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#include "./halide_executable.h"
#if MGB_JIT_HALIDE
#include <vector>
#include "megbrain/jit/utils.h"
  14. using namespace mgb;
  15. using namespace jit;
  16. using namespace Halide;
  17. HalideExecutable::FunctionHandle::~FunctionHandle() {
  18. if (device_release && uctx_map) {
  19. for (auto&& i : uctx_map->cn2uctx) {
  20. device_release(i.second);
  21. }
  22. }
  23. delete uctx_map;
  24. if (dl_handle) {
  25. ExecutableHelper::get().unload_lib(dl_handle);
  26. }
  27. }
  28. HalideExecutable::TargetTraitUserData* HalideExecutable::TargetTrait::user_data(
  29. const HalideExecutable& hl_exec,
  30. thin_function<std::unique_ptr<TargetTraitUserData>()> maker) {
  31. MGB_LOCK_GUARD(hl_exec.m_target_trait_user_data_mtx);
  32. if (!hl_exec.m_target_trait_user_data) {
  33. hl_exec.m_target_trait_user_data = maker();
  34. }
  35. return hl_exec.m_target_trait_user_data.get();
  36. }
  37. /* =================== HalideExecutable ==================== */
// dtor defaulted out of line — presumably so members whose types are
// incomplete in the header (e.g. the TargetTraitUserData unique_ptr) can be
// destroyed here; confirm against the header before moving it inline
HalideExecutable::~HalideExecutable() = default;
HalideExecutable::HalideExecutable(std::shared_ptr<TargetTrait> target_trait,
                                   const InternalGraph& graph,
                                   const JITExecutor::Args& args)
        : m_target_trait{std::move(target_trait)} {
    // map each JITPlaceholder's output var to the runtime input it stands for
    ThinHashMap<VarNode*, const JITExecutor::Args::Data*> placeholders_to_inps;
    for (auto&& inp : args.inputs) {
        VarNode* placeholder = graph.placeholders().at(inp.idx)->output(0);
        placeholders_to_inps[placeholder] = &inp;
    }
    using AstNodePtr = ast_hl::AstNodePtr;
    ThinHashMap<VarNode*, AstNodePtr> mgb2halide;
    // translate one mgb operator into a halide AST node
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        auto var = opr->output(0);
        AstNodePtr ptr;
        if (opr->same_type<JITPlaceholder>()) {
            auto data = placeholders_to_inps.at(var);
            auto&& ph = opr->cast_final_safe<JITPlaceholder>();
            if (ph.is_host_value_shape_input()) {
                // shape-only input: record the layout, no device buffer needed
                ptr = std::make_shared<ast_hl::InputHostValueShapeOp>();
                ptr->m_layout = data->layout;
            } else {
                // device-value input: wrap the var in a halide buffer and
                // remember its index within the fusion opr's input list
                ptr = mgb_var_to_halide_buffer(data->from);
                m_value_inputs.emplace_back(static_cast<size_t>(data->idx),
                                            ptr);
            }
        } else {
            // ordinary opr: its inputs have already been visited, since
            // DepOprIter traverses in topological order
            ptr = ast_hl::make_from_opr(opr);
            for (auto inp : opr->input()) {
                ptr->m_inputs.push_back(mgb2halide.at(inp));
            }
            ptr->init(opr);
        }
        mgb2halide[var] = std::move(ptr);
    };
    cg::DepOprIter{on_opr}.add(graph.output());
    // sort by input index so invoke() can bind arguments positionally
    std::sort(m_value_inputs.begin(), m_value_inputs.end());
    m_halide_output = mgb2halide.at(graph.output());
}
void HalideExecutable::execute(JITExecutor* fusion_opr) {
    // load func_ptr for current comp node; the atomic slot is created under
    // m_mtx but read lock-free afterwards
    auto comp_node = fusion_opr->comp_node();
    std::atomic<FunctionHandle*>* func_ptr_ref;
    {
        MGB_LOCK_GUARD(m_mtx);
        func_ptr_ref = &m_cn2func[comp_node];
    }
    auto func_ptr = func_ptr_ref->load();
    if (!func_ptr) {
        // compiled handles are keyed (and shared) by target feature set;
        // m_feature_set2func owns them together with a per-entry mutex
        std::pair<std::mutex, FunctionHandle>* func_maker;
        {
            MGB_LOCK_GUARD(m_mtx);
            func_maker =
                    &m_feature_set2func[m_target_trait->features(comp_node)];
        }
        // compile the function; double-checked locking: re-test the atomic
        // after acquiring the per-feature-set mutex so only one thread compiles
        MGB_LOCK_GUARD(func_maker->first);
        if (!(func_ptr = func_ptr_ref->load())) {
            if (!func_maker->second.execute) {
                func_maker->second = compile_and_load(comp_node);
                mgb_assert(func_maker->second.execute);
            }
            func_ptr = &func_maker->second;
            func_ptr_ref->store(func_ptr);
        }
    }
    // lazily create the per-comp-node user context when the target needs one
    void* user_context = nullptr;
    if (func_ptr->uctx_map) {
        MGB_LOCK_GUARD(func_ptr->uctx_map->mtx);
        auto&& ptr = func_ptr->uctx_map->cn2uctx[comp_node];
        if (!ptr) {
            ptr = m_target_trait->get_user_context(comp_node);
        }
        user_context = ptr;
    }
    invoke(user_context, *func_ptr, fusion_opr->input(), fusion_opr->output(0));
}
  115. std::vector<Halide::Argument> HalideExecutable::halide_inputs() const {
  116. std::vector<Argument> args;
  117. for (auto&& i : m_value_inputs) {
  118. auto&& input_buffer =
  119. i.second->cast_final_safe<ast_hl::InputDevValueOp>();
  120. args.emplace_back(input_buffer.m_buffer);
  121. }
  122. return args;
  123. }
  124. HalideExecutable::FunctionHandle HalideExecutable::compile_and_load(
  125. CompNode comp_node) const {
  126. Target target = get_host_target();
  127. auto req_features = m_target_trait->features(comp_node);
  128. target.set_feature(Target::UserContext);
  129. if (MGB_GETENV("MGB_HALIDE_DEBUG")) {
  130. target.set_feature(Target::Debug);
  131. }
  132. for (size_t i = 0; i < req_features.size(); ++i) {
  133. if (req_features.test(i)) {
  134. target.set_feature(static_cast<Target::Feature>(i));
  135. }
  136. }
  137. return m_target_trait->compile_and_load(comp_node, target, *this);
  138. }
  139. void HalideExecutable::invoke(void* user_context, const FunctionHandle& handle,
  140. const VarNodeArray& inputs, VarNode* output) {
  141. mgb_assert(handle.execute && handle.get_device_interface);
  142. halide_device_interface_t* device_interface = handle.get_device_interface();
  143. size_t nr_inputs = m_value_inputs.size(), argv_idx = 0;
  144. void* argv[nr_inputs + 2];
  145. halide_buffer_t image_args[nr_inputs + 1];
  146. size_t nr_dims = (nr_inputs + 1) * TensorLayout::MAX_NDIM;
  147. halide_dimension_t image_dims_buf[nr_dims];
  148. memset(image_dims_buf, 0, sizeof(halide_dimension_t) * nr_dims);
  149. size_t image_arg_idx = 0;
  150. halide_dimension_t* image_dims_ptr = image_dims_buf;
  151. auto add_tensor_arg = [&](const DeviceTensorND& tensor) {
  152. int ndim = tensor.layout().ndim;
  153. for (int i = ndim - 1; i >= 0; i--) {
  154. image_dims_ptr->extent = tensor.layout()[i];
  155. image_dims_ptr->stride = tensor.layout().stride[i];
  156. image_dims_ptr++;
  157. }
  158. auto dtype = tensor.dtype();
  159. halide_type_t type = dtype_mgb2halide(dtype);
  160. image_args[image_arg_idx] = {
  161. reinterpret_cast<uint64_t>(tensor.raw_ptr()),
  162. device_interface,
  163. nullptr,
  164. 0,
  165. type,
  166. ndim,
  167. image_dims_ptr - ndim,
  168. nullptr};
  169. argv[argv_idx++] = &image_args[image_arg_idx++];
  170. };
  171. argv[argv_idx++] = &user_context;
  172. for (auto&& i : m_value_inputs) {
  173. add_tensor_arg(inputs.at(i.first)->dev_tensor());
  174. }
  175. add_tensor_arg(output->dev_tensor());
  176. mgb_assert(argv_idx == nr_inputs + 2);
  177. mgb_assert(image_dims_ptr <= image_dims_buf + nr_dims);
  178. auto err = handle.execute(argv);
  179. mgb_throw_if(err, SystemError, "failed to execute halide function: err=%d",
  180. err);
  181. }
  182. halide_type_t HalideExecutable::dtype_mgb2halide(DType dtype) {
  183. if (dtype == dtype::Float32()) {
  184. return halide_type_of<float>();
  185. } else if (dtype == dtype::Float16()) {
  186. return halide_type_of<float16_t>();
  187. } else if (dtype == dtype::Int32()) {
  188. return halide_type_of<int>();
  189. } else {
  190. mgb_throw(InternalError,
  191. "dtype(%s) is not any of [Float16, Float32, Int32]",
  192. dtype.name());
  193. }
  194. }
  195. ast_hl::AstNodePtr HalideExecutable::mgb_var_to_halide_buffer(VarNode* var) {
  196. auto res = std::make_shared<ast_hl::InputDevValueOp>();
  197. res->m_layout = var->layout();
  198. int ndim = var->layout().ndim;
  199. halide_dimension_t halide_dim[ndim];
  200. memset(halide_dim, 0, sizeof(halide_dimension_t) * ndim);
  201. for (int i = ndim - 1; i >= 0; i--) {
  202. halide_dim[ndim - 1 - i].extent = res->m_layout[i];
  203. halide_dim[ndim - 1 - i].stride = res->m_layout.stride[i];
  204. }
  205. halide_buffer_t buf{
  206. 0, nullptr, nullptr, 0, dtype_mgb2halide(var->dtype()),
  207. ndim, halide_dim, nullptr};
  208. res->m_buffer = Buffer<>{buf};
  209. res->init(nullptr);
  210. return res;
  211. }
  212. #endif // MGB_JIT_HALIDE
  213. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)