You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils.h 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
/**
 * \file dnn/include/megdnn/oprs/utils.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  14. //! base class for random number generators
  15. class RNGBase: public OperatorBase {
  16. DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
  17. public:
  18. virtual void exec(_megdnn_tensor_out dst,
  19. _megdnn_workspace workspace) = 0;
  20. virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;
  21. protected:
  22. virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0;
  23. };
  24. //! sample from poisson distribution
  25. class PoissonRNG: public OperatorBase {
  26. DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
  27. DEF_OPR_PARAM(PoissonRNG);
  28. public:
  29. virtual void exec(_megdnn_tensor_in lam,
  30. _megdnn_tensor_out dst,
  31. _megdnn_workspace workspace) = 0;
  32. virtual size_t get_workspace_in_bytes(const TensorLayout &lam,
  33. const TensorLayout &dst) = 0;
  34. protected:
  35. void check_exec(const TensorLayout &lam, const TensorLayout &dst,
  36. size_t workspace_in_bytes);
  37. };
  38. //! sample from beta distribution
  39. class BetaRNG: public OperatorBase {
  40. DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
  41. DEF_OPR_PARAM(BetaRNG);
  42. public:
  43. virtual void exec(_megdnn_tensor_in alpha,
  44. _megdnn_tensor_in beta,
  45. _megdnn_tensor_out dst,
  46. _megdnn_workspace workspace) = 0;
  47. virtual size_t get_workspace_in_bytes(const TensorLayout &alpha,
  48. const TensorLayout &beta, const TensorLayout &dst) = 0;
  49. protected:
  50. void check_exec(const TensorLayout &alpha, const TensorLayout &beta,
  51. const TensorLayout &dst, size_t workspace_in_bytes);
  52. };
  53. //! sample from gamma distribution
  54. class GammaRNG: public OperatorBase {
  55. DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
  56. DEF_OPR_PARAM(GammaRNG);
  57. public:
  58. virtual void exec(_megdnn_tensor_in shape,
  59. _megdnn_tensor_in scale,
  60. _megdnn_tensor_out dst,
  61. _megdnn_workspace workspace) = 0;
  62. virtual size_t get_workspace_in_bytes(const TensorLayout &shape,
  63. const TensorLayout &scale, const TensorLayout &dst) = 0;
  64. protected:
  65. void check_exec(const TensorLayout &shape, const TensorLayout &scale,
  66. const TensorLayout &dst, size_t workspace_in_bytes);
  67. };
  68. //! sample from uniform distribution on the interval (0, 1]
  69. class UniformRNG: public RNGBase {
  70. DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
  71. DEF_OPR_PARAM(UniformRNG);
  72. protected:
  73. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  74. };
  75. //! sample from gaussian distribution
  76. class GaussianRNG: public RNGBase {
  77. DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
  78. DEF_OPR_PARAM(GaussianRNG);
  79. protected:
  80. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  81. };
  82. class PermutationRNG: public RNGBase {
  83. DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
  84. DEF_OPR_PARAM(PermutationRNG);
  85. protected:
  86. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  87. };
  88. /*!
  89. * \brief sleep for specific time on the computing device; useful for testing
  90. * async problems
  91. */
  92. class SleepForward: public OperatorBase {
  93. DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
  94. DEF_OPR_PARAM(Sleep);
  95. public:
  96. virtual void exec() = 0;
  97. };
  98. using Sleep = SleepForward;
  99. /*!
  100. * \brief calculating checksum of a tensor
  101. *
  102. * data must be a one-dimensional contiguous tensor with dtype byte
  103. */
  104. class ChecksumForward: public OperatorBase {
  105. DEF_OPR_PARAM(Empty);
  106. DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
  107. public:
  108. using Result = opr_result::Checksum;
  109. virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0;
  110. virtual Result exec(_megdnn_tensor_in data,
  111. _megdnn_workspace workspace) = 0;
  112. protected:
  113. void check_exec(const TensorLayout &layout, size_t workspace_in_bytes);
  114. };
  115. using Checksum = ChecksumForward;
  116. /*!
  117. * \brief calculating max absolute difference of the two input tensors
  118. *
  119. * src1 and src2 must be a one-dimensional contiguous tensor.
  120. */
  121. class MaxTensorDiff : public OperatorBase {
  122. DEF_OPR_PARAM(Empty);
  123. DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
  124. public:
  125. virtual size_t get_workspace_in_bytes(const TensorLayout& layout1,
  126. const TensorLayout& layout2) = 0;
  127. virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2,
  128. _megdnn_workspace workspace) = 0;
  129. protected:
  130. void check_exec(const TensorLayout& layout1,
  131. const TensorLayout& layout2, size_t workspace_in_bytes);
  132. };
  133. bool check_bias_share_in_channel(const TensorLayout& bias,
  134. const param::ConvBias::Format format);
  135. } // namespace megdnn
  136. #include "megdnn/internal/opr_header_epilogue.h"
  137. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台