You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. #pragma once
  2. #include "megdnn/internal/opr_header_prologue.h"
  3. namespace megdnn {
  4. //! base class for random number generators
  5. class RNGBase : public OperatorBase {
  6. DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
  7. public:
  8. virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
  9. virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;
  10. protected:
  11. virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0;
  12. };
  13. //! sample from poisson distribution
  14. class PoissonRNG : public OperatorBase {
  15. DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
  16. DEF_OPR_PARAM(PoissonRNG);
  17. public:
  18. virtual void exec(
  19. _megdnn_tensor_in lam, _megdnn_tensor_out dst,
  20. _megdnn_workspace workspace) = 0;
  21. virtual size_t get_workspace_in_bytes(
  22. const TensorLayout& lam, const TensorLayout& dst) = 0;
  23. protected:
  24. void check_exec(
  25. const TensorLayout& lam, const TensorLayout& dst,
  26. size_t workspace_in_bytes);
  27. };
  28. //! sample from beta distribution
  29. class BetaRNG : public OperatorBase {
  30. DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
  31. DEF_OPR_PARAM(BetaRNG);
  32. public:
  33. virtual void exec(
  34. _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst,
  35. _megdnn_workspace workspace) = 0;
  36. virtual size_t get_workspace_in_bytes(
  37. const TensorLayout& alpha, const TensorLayout& beta,
  38. const TensorLayout& dst) = 0;
  39. protected:
  40. void check_exec(
  41. const TensorLayout& alpha, const TensorLayout& beta,
  42. const TensorLayout& dst, size_t workspace_in_bytes);
  43. };
  44. //! sample from gamma distribution
  45. class GammaRNG : public OperatorBase {
  46. DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
  47. DEF_OPR_PARAM(GammaRNG);
  48. public:
  49. virtual void exec(
  50. _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst,
  51. _megdnn_workspace workspace) = 0;
  52. virtual size_t get_workspace_in_bytes(
  53. const TensorLayout& shape, const TensorLayout& scale,
  54. const TensorLayout& dst) = 0;
  55. protected:
  56. void check_exec(
  57. const TensorLayout& shape, const TensorLayout& scale,
  58. const TensorLayout& dst, size_t workspace_in_bytes);
  59. };
  60. //! sample from uniform distribution on the interval (0, 1]
  61. class UniformRNG : public RNGBase {
  62. DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
  63. DEF_OPR_PARAM(UniformRNG);
  64. protected:
  65. void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
  66. };
  67. //! sample from gaussian distribution
  68. class GaussianRNG : public RNGBase {
  69. DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
  70. DEF_OPR_PARAM(GaussianRNG);
  71. protected:
  72. void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
  73. };
  74. class PermutationRNG : public RNGBase {
  75. DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
  76. DEF_OPR_PARAM(PermutationRNG);
  77. protected:
  78. void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
  79. };
  80. class ShuffleRNGForward : public OperatorBase {
  81. DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
  82. DEF_OPR_PARAM(ShuffleRNG);
  83. public:
  84. virtual void exec(
  85. _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
  86. _megdnn_workspace workspace) = 0;
  87. void deduce_layout(
  88. const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
  89. virtual size_t get_workspace_in_bytes(
  90. const TensorLayout& src, const TensorLayout& dst,
  91. const TensorLayout& indices) = 0;
  92. protected:
  93. void check_exec(
  94. const TensorLayout& src, const TensorLayout& dst,
  95. const TensorLayout& indices, size_t workspace_in_bytes);
  96. };
  97. using ShuffleRNG = ShuffleRNGForward;
  98. class ShuffleRNGBackward : public OperatorBase {
  99. DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
  100. DEF_OPR_PARAM(ShuffleRNG);
  101. public:
  102. virtual void exec(
  103. _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
  104. _megdnn_workspace workspace) = 0;
  105. virtual size_t get_workspace_in_bytes(
  106. const TensorLayout& diff, const TensorLayout& indices,
  107. const TensorLayout& grad) = 0;
  108. protected:
  109. void check_exec(
  110. const TensorLayout& diff, const TensorLayout& indices,
  111. const TensorLayout& grad, size_t workspace_in_bytes);
  112. };
  113. /*!
  114. * \brief sleep for specific time on the computing device; useful for testing
  115. * async problems
  116. */
  117. class SleepForward : public OperatorBase {
  118. DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
  119. DEF_OPR_PARAM(Sleep);
  120. public:
  121. virtual void exec() = 0;
  122. };
  123. using Sleep = SleepForward;
  124. /*!
  125. * \brief calculating checksum of a tensor
  126. *
  127. * data must be a one-dimensional contiguous tensor with dtype byte
  128. */
  129. class ChecksumForward : public OperatorBase {
  130. DEF_OPR_PARAM(Empty);
  131. DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
  132. public:
  133. using Result = opr_result::Checksum;
  134. virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
  135. virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0;
  136. protected:
  137. void check_exec(const TensorLayout& layout, size_t workspace_in_bytes);
  138. };
  139. using Checksum = ChecksumForward;
  140. /*!
  141. * \brief calculating max absolute difference of the two input tensors
  142. *
  143. * src1 and src2 must be a one-dimensional contiguous tensor.
  144. */
  145. class MaxTensorDiff : public OperatorBase {
  146. DEF_OPR_PARAM(Empty);
  147. DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
  148. public:
  149. virtual size_t get_workspace_in_bytes(
  150. const TensorLayout& layout1, const TensorLayout& layout2) = 0;
  151. virtual float exec(
  152. _megdnn_tensor_in src1, _megdnn_tensor_in src2,
  153. _megdnn_workspace workspace) = 0;
  154. protected:
  155. void check_exec(
  156. const TensorLayout& layout1, const TensorLayout& layout2,
  157. size_t workspace_in_bytes);
  158. };
  159. bool check_bias_share_in_channel(
  160. const TensorLayout& bias, const param::ConvBias::Format format);
  161. } // namespace megdnn
  162. #include "megdnn/internal/opr_header_epilogue.h"
  163. // vim: syntax=cpp.doxygen