You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
/**
 * \file dnn/include/megdnn/oprs/utils.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  14. //! base class for random number generators
  15. class RNGBase: public OperatorBase {
  16. DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
  17. public:
  18. virtual void exec(_megdnn_tensor_out dst,
  19. _megdnn_workspace workspace) = 0;
  20. virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;
  21. protected:
  22. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  23. };
  24. //! sample from uniform distribution on the interval (0, 1]
  25. class UniformRNG: public RNGBase {
  26. DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
  27. DEF_OPR_PARAM(UniformRNG);
  28. };
  29. //! sample from gaussian distribution
  30. class GaussianRNG: public RNGBase {
  31. DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
  32. DEF_OPR_PARAM(GaussianRNG);
  33. };
  34. /*!
  35. * \brief sleep for specific time on the computing device; useful for testing
  36. * async problems
  37. */
  38. class SleepForward: public OperatorBase {
  39. DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
  40. DEF_OPR_PARAM(Sleep);
  41. public:
  42. virtual void exec() = 0;
  43. };
  44. using Sleep = SleepForward;
  45. /*!
  46. * \brief calculating checksum of a tensor
  47. *
  48. * data must be a one-dimensional contiguous tensor with dtype byte
  49. */
  50. class ChecksumForward: public OperatorBase {
  51. DEF_OPR_PARAM(Empty);
  52. DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
  53. public:
  54. using Result = opr_result::Checksum;
  55. virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0;
  56. virtual Result exec(_megdnn_tensor_in data,
  57. _megdnn_workspace workspace) = 0;
  58. protected:
  59. void check_exec(const TensorLayout &layout, size_t workspace_in_bytes);
  60. };
  61. using Checksum = ChecksumForward;
  62. /*!
  63. * \brief calculating max absolute difference of the two input tensors
  64. *
  65. * src1 and src2 must be a one-dimensional contiguous tensor.
  66. */
  67. class MaxTensorDiff : public OperatorBase {
  68. DEF_OPR_PARAM(Empty);
  69. DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
  70. public:
  71. virtual size_t get_workspace_in_bytes(const TensorLayout& layout1,
  72. const TensorLayout& layout2) = 0;
  73. virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2,
  74. _megdnn_workspace workspace) = 0;
  75. protected:
  76. void check_exec(const TensorLayout& layout1,
  77. const TensorLayout& layout2, size_t workspace_in_bytes);
  78. };
  79. /*!
  80. * \brief winograd preprocess opr.
  81. *
  82. * for the detail \see src/fallback/conv_bias/winograd/winograd.h
  83. *
  84. */
  85. class WinogradFilterPreprocess : public OperatorBase {
  86. DEF_OPR_PARAM(Winograd);
  87. DEF_OPR_IMPL(WinogradFilterPreprocess, OperatorBase, 1, 1);
  88. public:
  89. virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  90. _megdnn_workspace) = 0;
  91. size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&);
  92. void deduce_layout(const TensorLayout& src, TensorLayout& dst);
  93. protected:
  94. void check_exec(const TensorLayout& src, const TensorLayout& dst,
  95. size_t workspace_in_bytes);
  96. };
  97. } // namespace megdnn
  98. #include "megdnn/internal/opr_header_epilogue.h"
  99. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。

Contributors (1)