You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
/**
 * \file dnn/include/megdnn/oprs/utils.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/internal/opr_header_prologue.h"
  13. namespace megdnn {
  14. //! base class for random number generators
  15. class RNGBase: public OperatorBase {
  16. DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
  17. public:
  18. virtual void exec(_megdnn_tensor_out dst,
  19. _megdnn_workspace workspace) = 0;
  20. virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;
  21. protected:
  22. void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
  23. };
  24. //! sample from uniform distribution on the interval (0, 1]
  25. class UniformRNG: public RNGBase {
  26. DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
  27. DEF_OPR_PARAM(UniformRNG);
  28. };
  29. //! sample from gaussian distribution
  30. class GaussianRNG: public RNGBase {
  31. DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
  32. DEF_OPR_PARAM(GaussianRNG);
  33. };
  34. /*!
  35. * \brief sleep for specific time on the computing device; useful for testing
  36. * async problems
  37. */
  38. class SleepForward: public OperatorBase {
  39. DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
  40. DEF_OPR_PARAM(Sleep);
  41. public:
  42. virtual void exec() = 0;
  43. };
  44. using Sleep = SleepForward;
  45. /*!
  46. * \brief calculating checksum of a tensor
  47. *
  48. * data must be a one-dimensional contiguous tensor with dtype byte
  49. */
  50. class ChecksumForward: public OperatorBase {
  51. DEF_OPR_PARAM(Empty);
  52. DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);
  53. public:
  54. using Result = opr_result::Checksum;
  55. virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0;
  56. virtual Result exec(_megdnn_tensor_in data,
  57. _megdnn_workspace workspace) = 0;
  58. protected:
  59. void check_exec(const TensorLayout &layout, size_t workspace_in_bytes);
  60. };
  61. using Checksum = ChecksumForward;
  62. /*!
  63. * \brief calculating max absolute difference of the two input tensors
  64. *
  65. * src1 and src2 must be a one-dimensional contiguous tensor.
  66. */
  67. class MaxTensorDiff : public OperatorBase {
  68. DEF_OPR_PARAM(Empty);
  69. DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);
  70. public:
  71. virtual size_t get_workspace_in_bytes(const TensorLayout& layout1,
  72. const TensorLayout& layout2) = 0;
  73. virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2,
  74. _megdnn_workspace workspace) = 0;
  75. protected:
  76. void check_exec(const TensorLayout& layout1,
  77. const TensorLayout& layout2, size_t workspace_in_bytes);
  78. };
  79. } // namespace megdnn
  80. #include "megdnn/internal/opr_header_epilogue.h"
// vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。