You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

misc.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. /**
  2. * \file src/opr/impl/misc.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./internal/megdnn_opr_wrapper.inl"
  12. #include "megbrain/graph/grad_impl.h"
  13. #include "megbrain/opr/basic_arith_wrapper.h"
  14. #include "megbrain/opr/indexing.h"
  15. #include "megbrain/opr/misc.h"
  16. #include "megbrain/opr/tensor_manip.h"
  17. #include "megbrain/opr/utility.h"
  18. using namespace mgb;
  19. using namespace opr;
namespace mgb {
namespace opr {
namespace intl {

//! Post-constructor hooks: fix up output dtypes that the generic MegDNN
//! operator wrapper cannot deduce on its own.

//! Argmax outputs indices, so its single output is always Int32.
template<>
struct MegDNNOprInitPostCtor<Argmax> {
    static void apply(cg::OperatorNodeBase &opr) {
        opr.output(0)->dtype(dtype::Int32());
    }
};

//! Argmin behaves the same as Argmax: index output, Int32.
template<>
struct MegDNNOprInitPostCtor<Argmin>: public MegDNNOprInitPostCtor<Argmax> {
};

//! Argsort has two outputs: the sorted values (same dtype as the input)
//! and the permutation indices (Int32).
template<>
struct MegDNNOprInitPostCtor<ArgsortForward> {
    static void apply(cg::OperatorNodeBase &opr) {
        opr.output(0)->dtype(opr.input(0)->dtype());
        opr.output(1)->dtype(dtype::Int32());
    }
};

}  // namespace intl
}  // namespace opr
}  // namespace mgb
/* ================= Argmxx ================= */

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmax) {
    // Argmax produces integer indices, which are not differentiable
    // w.r.t. the input values, so the gradient is null.
    MGB_MARK_USED_VAR(out_grad);
    MGB_MARK_USED_VAR(opr);
    mgb_assert(!wrt_idx);  // single-input operator
    return nullptr;
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
MEGDNN_OPR_INIT1(Argmax, "argmax")

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmin) {
    // Same as Argmax: index output is non-differentiable.
    MGB_MARK_USED_VAR(out_grad);
    MGB_MARK_USED_VAR(opr);
    mgb_assert(!wrt_idx);
    return nullptr;
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin);
MEGDNN_OPR_INIT1(Argmin, "argmin")
  63. /* ================= ArgsortForward ================= */
  64. MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward);
  65. MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")
  66. std::array<SymbolVar, 2> ArgsortForward::make(
  67. SymbolVar in_tensor, const Param &param,
  68. const OperatorNodeConfig &config)
  69. {
  70. auto node = in_tensor.node()->owner_graph()->insert_opr(
  71. std::make_unique<ArgsortForward>(in_tensor.node(), param, config));
  72. mgb_assert(node->output().size() == 3);
  73. return {node->output(0), node->output(1)};
  74. }
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ArgsortForward) {
    // out_grad: {grad of sorted values, grad of indices, internal};
    // only the value output (wrt_idx == 0) is differentiable.
    mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
    if (!out_grad[0])
        return nullptr;
    // Scatter the incoming gradient back through the sort permutation
    // recorded in output(1).
    return ArgsortBackward::make(out_grad[0], opr.output(1)).node();
}
#endif

/* ================= ArgsortBackward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortBackward);
MEGDNN_OPR_INIT3(ArgsortBackward, "argsort_bwd", 2, false)
/* ================= Cumsum ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum);

Cumsum::Cumsum(VarNode* opr, const Param& param,
               const OperatorNodeConfig& config)
        : Super{opr->owner_graph(), config, "Cumsum", {opr}} {
    init_megdnn_opr(*this, param);
    // NOTE(review): CUR_ADDED presumably controls how this input is
    // ordered relative to previously added ones -- confirm against the
    // base class's add_input semantics.
    add_input({opr}, AddInputSortType::CUR_ADDED);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Cumsum) {
    mgb_assert(out_grad[0] && !out_grad[1]);  // out_grad[1]: workspace, no grad
    // The gradient of a cumulative sum is a cumulative sum in the
    // opposite direction along the same axis, so flip `reverse`.
    auto param = opr.param();
    param.reverse = !param.reverse;
    return Cumsum::make(out_grad[0], param).node();
}
#endif
  102. SymbolVar Cumsum::make(SymbolVar opr, const Param& param,
  103. const OperatorNodeConfig& config) {
  104. return opr.insert_single_output_opr<Cumsum>(opr.node(), param, config);
  105. }
void Cumsum::scn_do_execute() {
    // Forward to the MegDNN kernel; the last output var holds the
    // workspace whose size was registered by static shape inference.
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       output(0)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output().back()));
}
void Cumsum::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    // output(0): same shape as the input.
    auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
        auto ishp = iv.val.at(0).shape();
        dest = ishp;
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
    // output(1): 1-D workspace whose size is queried from MegDNN.
    auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
        auto dtype = input(0)->dtype();
        auto ishp = iv.val.at(0).shape();
        TensorLayout ily(ishp, dtype);
        // Canonicalize a negative axis before handing the param to the
        // MegDNN operator.
        Param real_param = param();
        if (real_param.axis < 0)
            real_param.axis += ishp.ndim;
        megdnn_opr()->param() = real_param;
        dest.ndim = 1;
        dest[0] = megdnn_opr()->get_workspace_in_bytes(ily, ily);
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(1),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}
/* ================= NvOf ================= */
#if MGB_CUDA
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);

NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
        : Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} {
    constexpr size_t NDIM = 5;
    mgb_assert(opr->dtype() == dtype::Uint8());
    add_input({opr});
    //! NvOf has only one output
    add_output(None);
    mgb_log_debug("init nvof engine with precision: %u", m_param.precision);
    auto input_shape = this->input()[0]->shape();
    //! nvof input format: nthwc4
    mgb_assert(input_shape.ndim == NDIM);
    //! now only support RGBA format channel data
    mgb_assert(input_shape[4] == 4);
    // Cache the static input shape; it is needed to construct the NvOF
    // engine lazily in scn_do_execute().
    for (size_t i = 0; i < NDIM; i++) {
        vshape.push_back(input_shape[i]);
    }
}
//! The flow output is emitted as Int16 regardless of the input dtype.
void NvOf::init_output_dtype() {
    output(0)->dtype(dtype::Int16());
}

SymbolVar NvOf::make(SymbolVar opr, const Param& param,
                     const OperatorNodeConfig& config) {
    return opr.insert_single_output_opr<NvOf>(opr.node(), param, config);
}
  164. void NvOf::scn_do_execute() {
  165. auto c = this->comp_node();
  166. //! comp_node may init on CUDA or CPU, eg: lar with --cpu
  167. //! if ON CUDA, need sync, caused by we use different stream
  168. if (CompNode::DeviceType::CUDA == c.device_type()) {
  169. c.sync();
  170. } else {
  171. mgb_log_warn(
  172. "NvOf opr on non CUDA comp_node, which will triger H2D and "
  173. "D2H!!");
  174. }
  175. //! create NvOF engine at same device id of comp_node, can not get
  176. //! comp_node device id, when NvOf:NvOf, so init at scn_do_execute
  177. std::lock_guard<std::mutex> lock(m_lock);
  178. if (init_flag == false) {
  179. //! nvof sdk do not imp p2p copy, so init nvof engine on the same
  180. //! device with mgb comp_node
  181. nv_flow_extractor = std::make_shared<NVFlowExtractor>(
  182. c.locator().device, vshape, m_param.precision, true, true);
  183. init_flag = true;
  184. }
  185. nv_flow_extractor->extract_flow(
  186. static_cast<unsigned char*>(
  187. input(0)->dev_tensor().as_megdnn().raw_ptr),
  188. vshape,
  189. reinterpret_cast<int16_t*>(
  190. output(0)->dev_tensor().as_megdnn().raw_ptr));
  191. }
void NvOf::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    // Map the NTHWC4 input shape to the flow output shape:
    //   [n, t - 1, h / 4, w / 4, c / 2]
    // t-1: one flow map per pair of consecutive frames. NOTE(review):
    // the /4 spatial and /2 channel factors presumably match the NvOF
    // engine's output granularity and 2-component (dx, dy) vectors --
    // confirm against the NvOF SDK documentation.
    auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
        auto ishp = iv.val.at(0).shape();
        SmallVector<size_t> tv;
        tv.push_back(ishp[0]);
        tv.push_back(ishp[1] - 1);
        tv.push_back(ishp[2] / 4);
        tv.push_back(ishp[3] / 4);
        tv.push_back(ishp[4] / 2);
        dest = tv;
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
}
#endif
/* ================= CondTake ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);

CondTake::CondTake(VarNode *data, VarNode *mask,
        const Param &param, const OperatorNodeConfig &config):
    Super(data->owner_graph(), config, "cond_take", {data, mask})
{
    init_megdnn_opr(*this, param);
    add_input({data, mask});
    // Let MegDNN deduce the output dtypes from the data/mask dtypes.
    auto dtypes = megdnn_opr()->infer_dtype(data->dtype(), mask->dtype());
    // Output sizes depend on the mask contents, so memory is allocated
    // dynamically at execution time and the outputs may legally be empty.
    for (int i = 0; i < 2; ++ i) {
        output(i)
                ->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .dtype(dtypes[i]);
    }
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondTake) {
    // out_grad: {grad of taken values, grad of indices, workspace};
    // indices carry no gradient.
    mgb_assert(out_grad.size() == 3 && !out_grad[2]);
    if (wrt_idx == 0 && out_grad[0]) {
        SymbolVar data_sym{opr.input(0)};
        // Scatter-add the incoming gradient into a zero-filled tensor of
        // the flattened data shape at the positions stored in output(1),
        // then reshape back to the original data shape.
        auto inp_set = IndexingIncrMultiAxisVec::make(
                data_sym.flatten().fill_retain_dtype(0), out_grad[0],
                {indexing::AxisIndexer::make_index(0, opr.output(1))});
        return inp_set.reshape(data_sym.symshape()).node();
    }
    // No gradient w.r.t. the mask input.
    return nullptr;
}
#endif
  239. std::array<SymbolVar, 2> CondTake::make(
  240. SymbolVar data, SymbolVar mask,
  241. const Param &param, const OperatorNodeConfig &config) {
  242. auto ov0 = data.insert_single_output_opr<CondTake>(
  243. data.node(), mask.node(), param, config);
  244. return {ov0, ov0.node()->owner_opr()->output(1)};
  245. }
void CondTake::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    // Only the workspace (output(2)) has a statically inferable shape;
    // outputs 0/1 are sized dynamically by the mask contents.
    auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
        auto dtype = input(0)->dtype();
        TensorLayout ily(iv.val[0].shape(), dtype);
        dest.ndim = 1;
        dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(ily);
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(2),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}
// The MegDNN kernel requires contiguous input layouts.
void CondTake::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void CondTake::scn_do_execute() {
    // Outputs are allocated on the fly via the dynamic-output allocator,
    // since their sizes depend on how many elements the mask selects.
    intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       input(1)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output().back()),
                       &dyn_malloc);
}
/* ================= TopK ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TopK);

TopK::TopK(VarNode* data, VarNode* k, const Param& param,
           const OperatorNodeConfig& config)
        : Super(data->owner_graph(), config, "top_k", {data, k}) {
    init_megdnn_opr(*this, param);
    add_input({data, k});
    if (param.mode == Param::Mode::KTH_ONLY) {
        // In KTH_ONLY mode the index output is unused: mark its content
        // volatile and allow it to stay empty (make() returns a null
        // SymbolVar for it in this mode).
        output(1)
                ->add_flag(VarNode::Flag::VOLATILE_CONTENT)
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    }
}
  282. std::array<SymbolVar, 2> TopK::make(SymbolVar data, SymbolVar k,
  283. const Param& param,
  284. const OperatorNodeConfig& config) {
  285. auto opr = data.node()->owner_graph()->insert_opr(
  286. std::make_unique<TopK>(data.node(), k.node(), param, config));
  287. auto o1 = opr->output(1);
  288. if (param.mode == Param::Mode::KTH_ONLY) {
  289. o1 = nullptr;
  290. }
  291. return {opr->output(0), o1};
  292. }
void TopK::init_output_dtype() {
    mgb_assert(input(1)->dtype() == dtype::Int32{}, "k must be int32, got %s",
               input(1)->dtype().name());
    output(0)->dtype(input(0)->dtype());  // values keep the input dtype
    output(1)->dtype(dtype::Int32{});     // indices are always int32
}

void TopK::add_input_layout_constraint() {
    // Require a 2-D input whose innermost dim is contiguous
    // (stride[1] == 1); the row stride may be arbitrary.
    auto check = [](const TensorLayout& layout) {
        mgb_assert(layout.ndim == 2, "top-k input must be two-dim, got %s",
                   layout.TensorShape::to_string().c_str());
        return layout.stride[1] == 1;
    };
    input(0)->add_layout_constraint(check);
}
void TopK::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    // output(0): ask MegDNN to deduce the value-output layout from the
    // input shape and the statically known value of k.
    auto infer_oshp0 = [this](TensorShape& dst, const InpVal& iv) {
        auto&& k_tensor = iv.val[1].value();
        mgb_assert(k_tensor.shape().is_scalar(), "k must be scalar, got %s",
                   k_tensor.shape().to_string().c_str());
        TensorLayout o0, o1;
        megdnn_opr()->deduce_layout(k_tensor.ptr<int>()[0],
                                    {iv.val[0].shape(), input(0)->dtype()}, o0,
                                    o1);
        dst = o0;
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::DEP,
                                         {{input(0), DepType::SHAPE},
                                          {input(1), DepType::VALUE}},
                                         infer_oshp0});
    if (param().mode == Param::Mode::KTH_ONLY) {
        // Index output is unused in this mode: constant empty shape.
        mgr.register_shape_infer(output(1), ShapeInferDesc::make_const({}));
    } else {
        // Indices have the same shape as the values.
        mgr.register_shape_infer(output(1),
                                 ShapeInferDesc::make_identity(output(0)));
    }
    // output(2): workspace size. The dep order below determines the
    // iv.val indices: [0] = input(0) shape, [1] = output(0) shape,
    // [2] = output(1) shape, [3] = value of k.
    auto infer_workspace = [this](TensorShape& dst, const InpVal& iv) {
        auto k = iv.val[3].value().ptr<int>()[0];
        auto size = megdnn_opr()->get_workspace_in_bytes(
                k, {iv.val[0].shape(), input(0)->dtype()},
                {iv.val[1].shape(), output(0)->dtype()},
                {iv.val[2].shape(), output(1)->dtype()});
        dst.ndim = 1;
        dst.shape[0] = size;
        return true;
    };
    mgr.register_shape_infer(output(2), {SourceType::DEP,
                                         {{input(0), DepType::SHAPE},
                                          {output(0), DepType::SHAPE},
                                          {output(1), DepType::SHAPE},
                                          {input(1), DepType::VALUE}},
                                         infer_workspace});
}
void TopK::scn_do_execute() {
    // k is read from the statically inferred value of input(1).
    auto&& mgr = owner_graph()->static_infer_manager();
    auto k = mgr.infer_value(input(1)).ptr<int>()[0];
    megdnn_opr()->exec(k, input(0)->dev_tensor().as_megdnn(),
                       output(0)->dev_tensor().as_megdnn(),
                       output(1)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output(2)));
}

// Keep the MegDNN operator alive for async execution dependencies.
void TopK::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TopK) {
    if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
        mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
        auto add_axis = [](SymbolVar x) {
            return opr::AxisAddRemove::make(
                    x, {opr::AxisAddRemove::AxisDesc::make_add(1)});
        };
        // The kth value may occur multiple times in a row: build an
        // equality mask against the input and distribute the gradient
        // evenly across all matching positions.
        SymbolVar mask = opr::eq(add_axis(opr.output(0)), opr.input(0)),
                  og = add_axis(out_grad[0]) / opr::reduce_ax_sum(mask, 1);
        return (og * mask).node();
    }
    if (!out_grad[0])
        return nullptr;
    // General mode: scatter the gradient back through the top-k index
    // permutation, same as argsort's backward pass.
    return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0))
            .node();
}
#endif
  377. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台