|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- /**
- * \file dnn/src/cuda/dropout/opr_impl.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
#include "src/cuda/dropout/opr_impl.h"

#include <climits>
-
- namespace megdnn {
- namespace cuda {
-
- using Param = megdnn::Dropout::Param;
-
- struct DropoutTensorDesc : public TensorDesc {
- public:
- DropoutTensorDesc(const TensorLayout& layout) : TensorDesc() {
- set_dropout_desc(layout);
- }
- void set_dropout_desc(const TensorLayout& layout) {
- cudnnDataType_t cudnn_dtype;
- switch (layout.dtype.enumv()) {
- case DTypeEnum::Float32:
- cudnn_dtype = CUDNN_DATA_FLOAT;
- break;
- case DTypeEnum::Float16:
- cudnn_dtype = CUDNN_DATA_HALF;
- break;
- default:
- megdnn_throw("dtype must be float16/float32");
- }
- cudnn_check(cudnnSetTensor4dDescriptor(
- desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, 1, 1,
- layout.total_nr_elems()));
- }
- };
-
- size_t DropoutForwardImpl::get_mask_size_in_bytes(const TensorLayout& inp) {
- size_t reserve_space_size_in_bytes = 0;
- DropoutTensorDesc ddesc(inp);
- cudnn_check(
- cudnnDropoutGetReserveSpaceSize(ddesc.desc, &reserve_space_size_in_bytes));
- return reserve_space_size_in_bytes;
- }
-
- void DropoutForwardImpl::exec(
- _megdnn_tensor_in inp, _megdnn_tensor_out oup, _megdnn_tensor_out mask,
- _megdnn_workspace workspace) {
- check_exec(inp.layout, oup.layout, mask.layout, workspace.size);
- uint64_t seed = param().seed;
- float drop_prob = param().drop_prob;
-
- if (!dropout_status.initialized()) {
- dropout_status.set(cudnn_handle(this->handle()), seed, drop_prob);
- }
- if (dropout_status.drop_prob != drop_prob) {
- dropout_status.drop_prob = drop_prob;
- dropout_status.restore_desc(cudnn_handle(this->handle()));
- }
- megdnn_assert(dropout_status.seed == seed);
-
- DropoutTensorDesc inp_desc(inp.layout), oup_desc(oup.layout);
- auto&& op_desc = dropout_status.desc;
-
- cudnn_check(cudnnDropoutForward(
- cudnn_handle(this->handle()), op_desc.desc, inp_desc.desc, inp.raw_ptr(),
- oup_desc.desc, oup.raw_ptr(), mask.raw_ptr(),
- mask.layout.total_nr_elems()));
- }
-
- void DropoutBackwardImpl::exec(
- _megdnn_tensor_in doup, _megdnn_tensor_in mask, _megdnn_tensor_out dinp,
- _megdnn_workspace workspace) {
- check_exec(doup.layout, mask.layout, dinp.layout, workspace.size);
-
- #if CUDNN_VERSION >= 7000
- size_t status_size_in_bytes = 0;
- cudnn_check(cudnnDropoutGetStatesSize(
- cudnn_handle(this->handle()), &status_size_in_bytes));
-
- DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout);
- op_desc.restore(
- cudnn_handle(this->handle()), param().drop_prob, nullptr,
- status_size_in_bytes, 0);
- cudnn_check(cudnnDropoutBackward(
- cudnn_handle(this->handle()), op_desc.desc, doup_desc.desc, doup.raw_ptr(),
- dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(),
- mask.layout.total_nr_elems()));
- #else
- uint64_t seed = param().seed;
- float drop_prob = param().drop_prob;
-
- if (!dropout_status.initialized()) {
- dropout_status.set(cudnn_handle(this->handle()), seed, drop_prob);
- }
- if (dropout_status.drop_prob != drop_prob) {
- dropout_status.drop_prob = drop_prob;
- dropout_status.restore_desc(cudnn_handle(this->handle()));
- }
-
- auto&& op_desc = dropout_status.desc;
- DropoutTensorDesc doup_desc(doup.layout), dinp_desc(dinp.layout);
-
- cudnn_check(cudnnDropoutBackward(
- cudnn_handle(this->handle()), op_desc.desc, doup_desc.desc, doup.raw_ptr(),
- dinp_desc.desc, dinp.raw_ptr(), mask.raw_ptr(),
- mask.layout.total_nr_elems()));
- #endif
- }
-
- } // namespace cuda
- } // namespace megdnn
- // vim: syntax=cpp.doxygen
|