You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. /**
  2. * \file dnn/include/megdnn/oprs/utils.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
  15. //! base class for random number generators
  16. class RNGBase : public OperatorBase {
  17. DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
  18. public:
  19. virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  20. virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;
  21. protected:
  22. virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0;
  23. };
  24. //! sample from poisson distribution
  25. class PoissonRNG : public OperatorBase {
  26. DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
  27. DEF_OPR_PARAM(PoissonRNG);
  28. public:
  29. virtual void exec(
  30. _megdnn_tensor_in lam, _megdnn_tensor_out dst,
  31. _megdnn_workspace workspace) = 0;
  32. virtual size_t get_workspace_in_bytes(
  33. const TensorLayout& lam, const TensorLayout& dst) = 0;
  34. protected:
  35. void check_exec(
  36. const TensorLayout& lam, const TensorLayout& dst,
  37. size_t workspace_in_bytes);
  38. };
  39. //! sample from beta distribution
  40. class BetaRNG : public OperatorBase {
  41. DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
  42. DEF_OPR_PARAM(BetaRNG);
  43. public:
  44. virtual void exec(
  45. _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst,
  46. _megdnn_workspace workspace) = 0;
  47. virtual size_t get_workspace_in_bytes(
  48. const TensorLayout& alpha, const TensorLayout& beta,
  49. const TensorLayout& dst) = 0;
  50. protected:
  51. void check_exec(
  52. const TensorLayout& alpha, const TensorLayout& beta,
  53. const TensorLayout& dst, size_t workspace_in_bytes);
  54. };
  55. //! sample from gamma distribution
  56. class GammaRNG : public OperatorBase {
  57. DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
  58. DEF_OPR_PARAM(GammaRNG);
  59. public:
  60. virtual void exec(
  61. _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst,
  62. _megdnn_workspace workspace) = 0;
  63. virtual size_t get_workspace_in_bytes(
  64. const TensorLayout& shape, const TensorLayout& scale,
  65. const TensorLayout& dst) = 0;
  66. protected:
  67. void check_exec(
  68. const TensorLayout& shape, const TensorLayout& scale,
  69. const TensorLayout& dst, size_t workspace_in_bytes);
  70. };
  71. //! sample from uniform distribution on the interval (0, 1]
  72. class UniformRNG : public RNGBase {
  73. DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
  74. DEF_OPR_PARAM(UniformRNG);
  75. protected:
  76. void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
  77. };
  78. //! sample from gaussian distribution
  79. class GaussianRNG : public RNGBase {
  80. DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
  81. DEF_OPR_PARAM(GaussianRNG);
  82. protected:
  83. void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
  84. };
  85. class PermutationRNG : public RNGBase {
  86. DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
  87. DEF_OPR_PARAM(PermutationRNG);
  88. protected:
  89. void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
  90. };
  91. class ShuffleRNGForward : public OperatorBase {
  92. DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
  93. DEF_OPR_PARAM(ShuffleRNG);
  94. public:
  95. virtual void exec(
  96. _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
  97. _megdnn_workspace workspace) = 0;
  98. void deduce_layout(
  99. const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
  100. virtual size_t get_workspace_in_bytes(
  101. const TensorLayout& src, const TensorLayout& dst,
  102. const TensorLayout& indices) = 0;
  103. protected:
  104. void check_exec(
  105. const TensorLayout& src, const TensorLayout& dst,
  106. const TensorLayout& indices, size_t workspace_in_bytes);
  107. };
  108. using ShuffleRNG = ShuffleRNGForward;
  109. class ShuffleRNGBackward : public OperatorBase {
  110. DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
  111. DEF_OPR_PARAM(ShuffleRNG);
  112. public:
  113. virtual void exec(
  114. _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
  115. _megdnn_workspace workspace) = 0;
  116. virtual size_t get_workspace_in_bytes(
  117. const TensorLayout& diff, const TensorLayout& indices,
  118. const TensorLayout& grad) = 0;
  119. protected:
  120. void check_exec(
  121. const TensorLayout& diff, const TensorLayout& indices,
  122. const TensorLayout& grad, size_t workspace_in_bytes);
  123. };
  124. /*!
  125. * \brief sleep for specific time on the computing device; useful for testing
  126. * async problems
  127. */
  128. class SleepForward : public OperatorBase {
  129. DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
  130. DEF_OPR_PARAM(Sleep);
  131. public:
  132. virtual void exec() = 0;
  133. };
  134. using Sleep = SleepForward;
  135. /*!
  136. * \brief calculating checksum of a tensor
  137. *
  138. * data must be a one-dimensional contiguous tensor with dtype byte
  139. */
  140. class ChecksumForward : public OperatorBase {
  141. DEF_OPR_PARAM(Empty);
  142. DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
  143. public:
  144. using Result = opr_result::Checksum;
  145. virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
  146. virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0;
  147. protected:
  148. void check_exec(const TensorLayout& layout, size_t workspace_in_bytes);
  149. };
  150. using Checksum = ChecksumForward;
  151. /*!
  152. * \brief calculating max absolute difference of the two input tensors
  153. *
  154. * src1 and src2 must be a one-dimensional contiguous tensor.
  155. */
  156. class MaxTensorDiff : public OperatorBase {
  157. DEF_OPR_PARAM(Empty);
  158. DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
  159. public:
  160. virtual size_t get_workspace_in_bytes(
  161. const TensorLayout& layout1, const TensorLayout& layout2) = 0;
  162. virtual float exec(
  163. _megdnn_tensor_in src1, _megdnn_tensor_in src2,
  164. _megdnn_workspace workspace) = 0;
  165. protected:
  166. void check_exec(
  167. const TensorLayout& layout1, const TensorLayout& layout2,
  168. size_t workspace_in_bytes);
  169. };
  170. bool check_bias_share_in_channel(
  171. const TensorLayout& bias, const param::ConvBias::Format format);
  172. } // namespace megdnn
  173. #include "megdnn/internal/opr_header_epilogue.h"
  174. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台