You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

rng.cpp 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666
  1. /**
  2. * \file imperative/src/impl/ops/rng.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/imperative/ops/rng.h"
  13. #include "megbrain/comp_node_env.h"
  14. #include "megbrain/graph/helper.h"
  15. #include "megbrain/opr/rand.h"
  16. #include "../dnn_op_helper.h"
  17. #include "../op_trait.h"
  18. namespace mgb::imperative::rng {
  19. namespace {
/*!
 * \brief CRTP base class that owns a cache of megdnn operators keyed by
 * (handle, op-type).
 *
 * The derived \p HandleFactory supplies do_new_handle()/do_delete_handle();
 * this base maintains m_handle2ops, a per-handle map from an op typeinfo to
 * a lazily created megdnn operator, each guarded by its own mutex. As a
 * CompNodeDepedentObject the whole cache is dropped on comp-node
 * finalization (see on_comp_node_finalize()).
 */
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    //! forward handle construction to the concrete factory (CRTP)
    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    //! erase all cached ops of \p handle (skipped after finalization) and
    //! destroy the handle itself; returns the number of cache entries
    //! removed (0 or 1)
    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    /*!
     * \brief get (or lazily create) the cached dnn op of type \p DnnOp for
     * (handle, tpinfo) on comp node \p cn.
     * \return tuple {initialized, op pointer, held unique_lock}. The lock
     * keeps the op exclusive to the caller until it goes out of scope;
     * `initialized` tells whether the op already existed (so its current
     * param is meaningful).
     */
    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            // only the map lookup needs the global lock
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            // a cached op must have been created from the same megdnn handle
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    //! a megdnn operator together with the mutex that serializes its use
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex() : op{nullptr} {}
    };

    //! drop every cached op when comp nodes are finalized
    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
            m_handle2ops;  // guarded by m_mtx
    std::mutex m_mtx;
};
/*!
 * \brief singleton manager of RNG handles.
 *
 * A Handle is a pooled HandleData {comp_node, seed}. Besides explicitly
 * created handles, a per-comp-node "default handle" seeded with the global
 * default seed is kept in glob_default_handles; set_glob_default_seed()
 * re-creates those default handles so they pick up the new seed.
 */
class RNGDnnOpManager final : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    //! allocate HandleData from the pool; the returned pointer doubles as
    //! the opaque Handle value
    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    //! seed of \p handle; a null handle means "default" and yields the
    //! global default seed
    static uint64_t get_seed(Handle handle) {
        if (!handle) {
            return glob_default_seed;
        }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    //! lazily created per-comp-node handle carrying glob_default_seed
    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    //! change the global seed: every existing default handle is destroyed
    //! and re-created so that it carries the new seed
    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    //! payload a Handle points at
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;
    static std::mutex sm_mtx;  // guards the static members below
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;
//! maps an imperative RNG opdef type to its megdnn operator, param type and
//! computing-graph operator
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    //! build the megdnn param, checking that the opdef's seed matches the
    //! seed recorded in its handle
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // NOTE(review): "%lu" assumes unsigned long is 64-bit; PRIu64 would
        // be portable — TODO confirm supported targets
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};
  166. template <>
  167. struct OpMeth<PoissonRNG> {
  168. using DnnOp = megdnn::PoissonRNG;
  169. using Param = DnnOp::Param;
  170. using OpNode = mgb::opr::PoissonRNG;
  171. static Param make_param(const PoissonRNG& rng) {
  172. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  173. mgb_assert(
  174. handle_seed == rng.seed,
  175. "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
  176. rng.seed);
  177. return {handle_seed};
  178. }
  179. };
template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    //! megdnn param {seed, mean, std, dtype}; the opdef seed must match the
    //! seed recorded in its handle
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};
  194. template <>
  195. struct OpMeth<GammaRNG> {
  196. using DnnOp = megdnn::GammaRNG;
  197. using Param = DnnOp::Param;
  198. using OpNode = mgb::opr::GammaRNG;
  199. static Param make_param(const GammaRNG& rng) {
  200. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  201. mgb_assert(
  202. handle_seed == rng.seed,
  203. "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
  204. rng.seed);
  205. return {handle_seed};
  206. }
  207. };
template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    //! megdnn param {seed, dtype}; the opdef seed must match the seed
    //! recorded in its handle
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};
  222. template <>
  223. struct OpMeth<BetaRNG> {
  224. using DnnOp = megdnn::BetaRNG;
  225. using Param = DnnOp::Param;
  226. using OpNode = mgb::opr::BetaRNG;
  227. static Param make_param(const BetaRNG& rng) {
  228. auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
  229. mgb_assert(
  230. handle_seed == rng.seed,
  231. "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
  232. rng.seed);
  233. return {handle_seed};
  234. }
  235. };
template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    //! megdnn param {seed}; the opdef seed must match the seed recorded in
    //! its handle
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};
template <>
struct OpMeth<Dropout> {
    using DnnOp = megdnn::Dropout;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::Dropout;
    //! megdnn param {drop_prob, seed}; the opdef seed must match the seed
    //! recorded in its handle
    static Param make_param(const Dropout& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent dropout seed: dropout op: %lu handle: %lu", handle_seed,
                opdef.seed);
        return {opdef.drop_prob, handle_seed};
    }
};
//! infer the output layout either from a shape-valued input tensor (true)
//! or from the input tensor's own layout (false)
template <bool>
struct _InferLayout;

//! build the computing-graph operator for an RNG opdef of given input arity
template <int nr_in>
struct _RNGOprMaker;

//! execute the megdnn op for the given input/output arity
template <int nr_in, int nr_out>
struct _RNGOprInvoker;

template <>
struct _InferLayout<true> {
    //! eager path: the input holds the target shape; copy it to host and
    //! build a layout with the opdef's dtype
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    //! fallible path: derive the layout from a logical desc; yields ndim=0
    //! (unknown) when the shape value is not yet available
    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        // the shape tensor itself must be rank-1
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        out_layout.init_contiguous_stride();
        return out_layout;
    }
};
template <>
struct _InferLayout<false> {
    //! output layout simply mirrors the input tensor's layout
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    //! fallible path: input layout must already be known (ndim != 0)
    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};
// _RNGOprInvoker<nr_in, nr_out>: query the workspace size, allocate it as a
// Blob on the output's comp node, then run the dnn op. The _FOR_EACH_IN /
// _FOR_EACH_OUT helper macros are (re)defined per arity below to expand to
// the proper input/output argument lists.
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                  \
    template <>                                                            \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                 \
        template <typename Opr>                                            \
        static void exec(                                                  \
                Opr* dnn_op, const SmallVector<TensorPtr>& inputs,         \
                const SmallVector<TensorPtr>& outputs) {                   \
            size_t wk_size = 0;                                            \
            wk_size = dnn_op->get_workspace_in_bytes(                      \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));   \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
            dnn_op->exec(                                                  \
                    _FOR_EACH_IN(->dev_tensor().as_megdnn())               \
                    _FOR_EACH_OUT(->dev_tensor().as_megdnn()),             \
                    dnn_wk);                                               \
        }                                                                  \
    };

// _RNGOprMaker<nr_in>: build the graph operator, pinning the comp node to
// the handle's device when a handle is present.
#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                            \
    template <>                                                                   \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                          \
        template <typename Op>                                                    \
        static auto make(const VarNodeArray& inputs, const Op& rng) {             \
            auto param = OpMeth<Op>::make_param(rng);                             \
            OperatorNodeConfig config;                                            \
            if (rng.handle) {                                                     \
                config = {                                                        \
                        rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
            } else {                                                              \
                config = {rng.make_name()};                                       \
            }                                                                     \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);        \
        }                                                                         \
    };

// instantiate for every (nr_in, nr_out) combination actually used;
// shape-input RNGs (dnn arity 0) take no dnn inputs but one graph input
#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT

#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER
/*!
 * \brief run the megdnn op of \p Op on concrete tensors.
 *
 * Fetches the cached dnn op for the opdef's handle (falling back to the
 * per-comp-node default handle), re-applies the param, and executes with a
 * freshly allocated workspace. Returns immediately when the output layout
 * is empty (nothing to generate).
 */
template <typename Op>
void exec(
        const OpDef& op, const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspace) {
    auto&& rng = op.cast_final_safe<Op>();
    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }
    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        // a previously cached op must still carry the handle's seed
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(
                dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
                dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(
            dnn_op, inputs, outputs);
}
  400. template <typename Op>
  401. SmallVector<LogicalTensorDesc> infer_output_attrs(
  402. const OpDef& op, const SmallVector<TensorPtr>& inputs) {
  403. LogicalTensorDesc dest;
  404. auto&& rng = op.cast_final_safe<Op>();
  405. auto handle = rng.handle;
  406. if (handle) {
  407. dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
  408. } else {
  409. dest.comp_node = inputs[0]->comp_node();
  410. }
  411. constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
  412. if (!rng_with_shape) {
  413. for (int i = 0; i < inputs.size(); ++i) {
  414. mgb_assert(
  415. inputs[i]->comp_node() == dest.comp_node,
  416. "%s expects the device of inputs[%d] to be same as the device of "
  417. "handle; "
  418. "got %s and %s actually",
  419. rng.dyn_typeinfo()->name, i,
  420. inputs[i]->comp_node().to_string().c_str(),
  421. dest.comp_node.to_string().c_str());
  422. }
  423. }
  424. dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
  425. return {dest};
  426. }
  427. template <>
  428. SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
  429. const OpDef& op, const SmallVector<TensorPtr>& inputs) {
  430. SmallVector<LogicalTensorDesc> dests(2);
  431. auto&& rng = op.cast_final_safe<ShuffleRNG>();
  432. auto handle = rng.handle;
  433. if (handle) {
  434. dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
  435. dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
  436. } else {
  437. dests[0].comp_node = inputs[0]->comp_node();
  438. dests[1].comp_node = inputs[0]->comp_node();
  439. }
  440. dests[0].layout = TensorLayout(inputs[0]->layout());
  441. dests[0].layout.dtype = inputs[0]->layout().dtype;
  442. dests[1].layout =
  443. TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
  444. return dests;
  445. }
  446. template <>
  447. SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
  448. const OpDef& op, const SmallVector<TensorPtr>& inputs) {
  449. SmallVector<LogicalTensorDesc> dests(2);
  450. auto&& cn = inputs[0]->comp_node();
  451. dests[0].comp_node = cn;
  452. dests[0].layout = TensorLayout(inputs[0]->layout());
  453. dests[0].layout.dtype = inputs[0]->layout().dtype;
  454. auto get_mask_size = [&]() -> size_t {
  455. auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
  456. return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
  457. inputs[0]->layout());
  458. };
  459. dests[1].comp_node = cn;
  460. dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
  461. return dests;
  462. }
  463. template <typename Op>
  464. SmallVector<TensorPtr> apply_on_physical_tensor(
  465. const OpDef& def, const SmallVector<TensorPtr>& inputs,
  466. SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
  467. SmallVector<TensorPtr> outputs;
  468. SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
  469. for (auto&& i : desc) {
  470. outputs.push_back(Tensor::make(i.layout, i.comp_node));
  471. }
  472. exec<Op>(def, inputs, outputs, {});
  473. return outputs;
  474. }
//! lower an RNG opdef to a computing-graph operator; shape-input RNGs
//! (dnn NR_INPUTS == 0) still take exactly one var — the target shape
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    if (dnn_nr_inp == 0) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 inputs; got %lu actually",
                rng.dyn_typeinfo()->name, nr_inp);
    }
    // dnn arity 0 maps to 1 graph input (the shape var)
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}
  488. template <typename Op>
  489. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  490. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  491. bool success = inputs[0].layout.ndim != 0;
  492. LogicalTensorDesc dest;
  493. auto&& xxx_rng_def = def.cast_final_safe<Op>();
  494. size_t nr_inp = inputs.size();
  495. constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
  496. if (rng_with_shape) {
  497. mgb_assert(
  498. nr_inp == 1, "%s expects 1 inputs; got %lu actually",
  499. xxx_rng_def.dyn_typeinfo()->name, nr_inp);
  500. }
  501. dest.comp_node = inputs[0].comp_node;
  502. if (success) {
  503. dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
  504. } else {
  505. dest.layout = TensorLayout(inputs[0].layout.dtype);
  506. }
  507. return {{dest}, inputs[0].layout.ndim != 0};
  508. }
  509. template <>
  510. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
  511. ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  512. bool success = inputs[0].layout.ndim != 0;
  513. SmallVector<LogicalTensorDesc> dests(2);
  514. dests[0].comp_node = inputs[0].comp_node;
  515. dests[0].layout = TensorLayout(inputs[0].layout);
  516. dests[0].layout.dtype = inputs[0].layout.dtype;
  517. dests[1].comp_node = inputs[0].comp_node;
  518. if (success) {
  519. dests[1].layout =
  520. TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
  521. } else {
  522. dests[1].layout = TensorLayout(dtype::Int32());
  523. }
  524. return {dests, success};
  525. }
//! Dropout yields two outputs: the dropped-out data (same layout as the
//! input) and a byte mask whose size is queried from the megdnn op
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
        const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;
    SmallVector<LogicalTensorDesc> dests(2);
    auto cn = inputs[0].comp_node;
    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;
    // mask size depends on the input layout, so query it only when the
    // layout is known
    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0].layout);
    };
    dests[1].comp_node = cn;
    if (success) {
        dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    } else {
        dests[1].layout = TensorLayout(dtype::Byte());
    }
    return {dests, success};
}
  548. template <typename Op>
  549. SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
  550. const OpDef& def, const SmallVector<TensorPtr>& inputs) {
  551. SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
  552. return layout_checker;
  553. }
  554. } // anonymous namespace
//! create an RNG handle bound to \p comp_node carrying the given seed
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}
//! destroy a handle created by new_handle(); returns the number of cached
//! dnn-op entries dropped together with it
size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}
//! set the global default RNG seed (re-creates all default handles)
void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}
//! current global default RNG seed
uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}
//! comp node a (non-null) handle is bound to
CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}
//! register OpDef traits for an RNG op: graph lowering, eager execution,
//! fallible shape inference and (no-op) input layout constraints
#define REG_RNG_OP(NAME, Output)                                            \
    namespace {                                                             \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .get_input_layout_constraint(get_input_layout_constraint<NAME>) \
            .fallback();                                                    \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)  // two outputs: data + indices
REG_RNG_OP(Dropout, SymbolVarArray)     // two outputs: data + mask
#undef REG_RNG_OP
  588. } // namespace mgb::imperative::rng
  589. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}