You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. /**
  2. * \file dnn/src/cuda/dropout/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/cuda/dropout/opr_impl.h"
  13. namespace megdnn {
  14. namespace cuda {
  //! Shorthand for the dropout operator's parameter struct.
  //! NOTE(review): not referenced by name in the code visible here — confirm it
  //! is still needed.
  15. using Param = megdnn::Dropout::Param;
  16. struct DropoutTensorDesc : public TensorDesc {
  17. public:
  18. DropoutTensorDesc(const TensorLayout& layout) : TensorDesc() {
  19. set_dropout_desc(layout);
  20. }
  21. void set_dropout_desc(const TensorLayout& layout) {
  22. cudnnDataType_t cudnn_dtype;
  23. switch (layout.dtype.enumv()) {
  24. case DTypeEnum::Float32:
  25. cudnn_dtype = CUDNN_DATA_FLOAT;
  26. break;
  27. case DTypeEnum::Float16:
  28. cudnn_dtype = CUDNN_DATA_HALF;
  29. break;
  30. default:
  31. megdnn_throw("dtype must be float16/float32");
  32. }
  33. cudnn_check(cudnnSetTensor4dDescriptor(
  34. desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, 1, 1,
  35. layout.total_nr_elems()));
  36. }
  37. };
  38. size_t DropoutForwardImpl::get_mask_size_in_bytes(const TensorLayout& inp) {
  39. size_t reserve_space_size_in_bytes = 0;
  40. DropoutTensorDesc ddesc(inp);
  41. cudnn_check(
  42. cudnnDropoutGetReserveSpaceSize(ddesc.desc, &reserve_space_size_in_bytes));
  43. return reserve_space_size_in_bytes;
  44. }
  45. void DropoutForwardImpl::exec(
  46. _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
  47. _megdnn_workspace workspace) {
  48. check_exec(inp.layout, oup.layout, mask.layout, workspace.size);
  49. uint64_t seed = param().seed;
  50. float drop_prob = param().drop_prob;
  51. if (!dropout_status.initialized()) {
  52. dropout_status.set(cudnn_handle(this->handle()), seed, drop_prob);
  53. }
  54. if (dropout_status.drop_prob != drop_prob) {
  55. dropout_status.drop_prob = drop_prob;
  56. dropout_status.restore_desc(cudnn_handle(this->handle()));
  57. }
  58. megdnn_assert(dropout_status.seed == seed);
  59. DropoutTensorDesc inp_desc(inp.layout), oup_desc(oup.layout);
  60. auto&& op_desc = dropout_status.desc;
  61. cudnn_check(cudnnDropoutForward(
  62. cudnn_handle(this->handle()), op_desc.desc, inp_desc.desc, inp.raw_ptr(),
  63. oup_desc.desc, oup.raw_ptr(), mask.raw_ptr(),
  64. mask.layout.total_nr_elems()));
  65. }
  66. void DropoutBackwardImpl::exec(
  67. _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
  68. _megdnn_workspace workspace) {
  69. check_exec(doup.layout, mask.layout, dinp.layout, workspace.size);
  70. #if CUDNN_VERSION >= 7000
  71. size_t status_size_in_bytes = 0;
  72. cudnn_check(cudnnDropoutGetStatesSize(
  73. cudnn_handle(this->handle()), &status_size_in_bytes));
  74. DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout);
  75. op_desc.restore(
  76. cudnn_handle(this->handle()), param().drop_prob, nullptr,
  77. status_size_in_bytes, 0);
  78. cudnn_check(cudnnDropoutBackward(
  79. cudnn_handle(this->handle()), op_desc.desc, doup_desc.desc, doup.raw_ptr(),
  80. dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(),
  81. mask.layout.total_nr_elems()));
  82. #else
  83. uint64_t seed = param().seed;
  84. float drop_prob = param().drop_prob;
  85. if (!dropout_status.initialized()) {
  86. dropout_status.set(cudnn_handle(this->handle()), seed, drop_prob);
  87. }
  88. if (dropout_status.drop_prob != drop_prob) {
  89. dropout_status.drop_prob = drop_prob;
  90. dropout_status.restore_desc(cudnn_handle(this->handle()));
  91. }
  92. auto&& op_desc = dropout_status.desc;
  93. DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout);
  94. cudnn_check(cudnnDropoutBackward(
  95. cudnn_handle(this->handle()), op_desc.desc, doup_desc.desc, doup.raw_ptr(),
  96. dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(),
  97. mask.layout.total_nr_elems()));
  98. #endif
  99. }
  100. } // namespace cuda
  101. } // namespace megdnn
  102. // vim: syntax=cpp.doxygen