You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. /**
  2. * \file dnn/include/megdnn/oprs/utils.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/internal/opr_header_prologue.h"
  14. namespace megdnn {
  15. //! base class for random number generators
  16. class RNGBase: public OperatorBase {
  17. DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
  18. public:
  19. virtual void exec(_megdnn_tensor_out dst,
  20. _megdnn_workspace workspace) = 0;
  21. virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;
  22. protected:
  23. virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0;
  24. };
  25. //! sample from poisson distribution
  26. class PoissonRNG: public OperatorBase {
  27. DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
  28. DEF_OPR_PARAM(PoissonRNG);
  29. public:
  30. virtual void exec(_megdnn_tensor_in lam,
  31. _megdnn_tensor_out dst,
  32. _megdnn_workspace workspace) = 0;
  33. virtual size_t get_workspace_in_bytes(const TensorLayout &lam,
  34. const TensorLayout &dst) = 0;
  35. protected:
  36. void check_exec(const TensorLayout &lam, const TensorLayout &dst,
  37. size_t workspace_in_bytes);
  38. };
  39. //! sample from beta distribution
  40. class BetaRNG: public OperatorBase {
  41. DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
  42. DEF_OPR_PARAM(BetaRNG);
  43. public:
  44. virtual void exec(_megdnn_tensor_in alpha,
  45. _megdnn_tensor_in beta,
  46. _megdnn_tensor_out dst,
  47. _megdnn_workspace workspace) = 0;
  48. virtual size_t get_workspace_in_bytes(const TensorLayout &alpha,
  49. const TensorLayout &beta, const TensorLayout &dst) = 0;
  50. protected:
  51. void check_exec(const TensorLayout &alpha, const TensorLayout &beta,
  52. const TensorLayout &dst, size_t workspace_in_bytes);
  53. };
  54. //! sample from gamma distribution
  55. class GammaRNG: public OperatorBase {
  56. DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
  57. DEF_OPR_PARAM(GammaRNG);
  58. public:
  59. virtual void exec(_megdnn_tensor_in shape,
  60. _megdnn_tensor_in scale,
  61. _megdnn_tensor_out dst,
  62. _megdnn_workspace workspace) = 0;
  63. virtual size_t get_workspace_in_bytes(const TensorLayout &shape,
  64. const TensorLayout &scale, const TensorLayout &dst) = 0;
  65. protected:
  66. void check_exec(const TensorLayout &shape, const TensorLayout &scale,
  67. const TensorLayout &dst, size_t workspace_in_bytes);
  68. };
  69. //! sample from uniform distribution on the interval (0, 1]
  70. class UniformRNG: public RNGBase {
  71. DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
  72. DEF_OPR_PARAM(UniformRNG);
  73. protected:
  74. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  75. };
  76. //! sample from gaussian distribution
  77. class GaussianRNG: public RNGBase {
  78. DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
  79. DEF_OPR_PARAM(GaussianRNG);
  80. protected:
  81. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  82. };
  83. class PermutationRNG: public RNGBase {
  84. DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
  85. DEF_OPR_PARAM(PermutationRNG);
  86. protected:
  87. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  88. };
  89. class ShuffleRNGForward : public OperatorBase {
  90. DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
  91. DEF_OPR_PARAM(ShuffleRNG);
  92. public:
  93. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  94. _megdnn_tensor_out indices,
  95. _megdnn_workspace workspace) = 0;
  96. void deduce_layout(const TensorLayout& src, TensorLayout& dst,
  97. TensorLayout& indices);
  98. virtual size_t get_workspace_in_bytes(const TensorLayout& src,
  99. const TensorLayout& dst,
  100. const TensorLayout& indices) = 0;
  101. protected:
  102. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  103. const TensorLayout& indices, size_t workspace_in_bytes);
  104. };
  105. using ShuffleRNG = ShuffleRNGForward;
  106. class ShuffleRNGBackward : public OperatorBase {
  107. DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
  108. DEF_OPR_PARAM(ShuffleRNG);
  109. public:
  110. virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
  111. _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
  112. virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
  113. const TensorLayout& indices,
  114. const TensorLayout& grad) = 0;
  115. protected:
  116. void check_exec(const TensorLayout& diff, const TensorLayout& indices,
  117. const TensorLayout& grad, size_t workspace_in_bytes);
  118. };
  119. /*!
  120. * \brief sleep for specific time on the computing device; useful for testing
  121. * async problems
  122. */
  123. class SleepForward: public OperatorBase {
  124. DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
  125. DEF_OPR_PARAM(Sleep);
  126. public:
  127. virtual void exec() = 0;
  128. };
  129. using Sleep = SleepForward;
  130. /*!
  131. * \brief calculating checksum of a tensor
  132. *
  133. * data must be a one-dimensional contiguous tensor with dtype byte
  134. */
  135. class ChecksumForward: public OperatorBase {
  136. DEF_OPR_PARAM(Empty);
  137. DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
  138. public:
  139. using Result = opr_result::Checksum;
  140. virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0;
  141. virtual Result exec(_megdnn_tensor_in data,
  142. _megdnn_workspace workspace) = 0;
  143. protected:
  144. void check_exec(const TensorLayout &layout, size_t workspace_in_bytes);
  145. };
  146. using Checksum = ChecksumForward;
  147. /*!
  148. * \brief calculating max absolute difference of the two input tensors
  149. *
  150. * src1 and src2 must be a one-dimensional contiguous tensor.
  151. */
  152. class MaxTensorDiff : public OperatorBase {
  153. DEF_OPR_PARAM(Empty);
  154. DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
  155. public:
  156. virtual size_t get_workspace_in_bytes(const TensorLayout& layout1,
  157. const TensorLayout& layout2) = 0;
  158. virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2,
  159. _megdnn_workspace workspace) = 0;
  160. protected:
  161. void check_exec(const TensorLayout& layout1,
  162. const TensorLayout& layout2, size_t workspace_in_bytes);
  163. };
  164. bool check_bias_share_in_channel(const TensorLayout& bias,
  165. const param::ConvBias::Format format);
  166. } // namespace megdnn
  167. #include "megdnn/internal/opr_header_epilogue.h"
  168. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。