/**
 * \file dnn/src/cuda/cond_take/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/utils.h"
#include "src/common/cond_take/predicate.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace cuda::cond_take;
using namespace megdnn::cond_take;

using Param = CondTake::Param;

/*!
 * \brief Describe the workspace layout used by exec() for \p nr_item elements.
 *
 * Chunk 0: (nr_item + 1) IdxType slots. gen_idx() writes into them and
 *          copy_output() reads from them. NOTE(review): the purpose of the
 *          extra slot (likely a count / scan tail used by gen_idx) is not
 *          visible here -- confirm in kern.cu.
 * Chunk 1: scratch space for gen_idx(), sized by gen_idx_get_workspace_size().
 *
 * The returned bundle has a null base pointer. Callers must bind real memory
 * with WorkspaceBundle::set() before calling get() on it.
 *
 * Side effect: makes the handle's device current via cudaSetDevice(). This
 * happens even when the bundle is only built to report a size
 * (see get_workspace_in_bytes()).
 */
WorkspaceBundle CondTakeImpl::make_bundle(size_t nr_item) {
    // Make the handle's device current before querying gen_idx's requirement.
    cuda_check(cudaSetDevice(concrete_handle(handle())->device_id()));
    auto gen_idx_wk_size = gen_idx_get_workspace_size(nr_item);
    return {nullptr,
            {(nr_item + 1) * sizeof(IdxType), gen_idx_wk_size},
            handle()->alignment_requirement()};
}

/*!
 * \brief Total workspace bytes exec() needs for a tensor with \p data's
 *        element count (sum of both aligned chunks of make_bundle()).
 */
size_t CondTakeImpl::get_workspace_in_bytes(const TensorLayout& data) {
    return make_bundle(data.total_nr_elems()).total_size_in_bytes();
}

/*!
 * \brief Take the elements of \p data selected by the predicate that
 *        param().mode (packed into KParam) evaluates over \p mask.
 *
 * Runs in two stages on the handle's CUDA stream:
 *  1. gen_idx() evaluates the predicate over \p mask. It writes into workspace
 *     chunk 0 and returns the number of output elements (out_size) on the
 *     host. NOTE(review): a host-visible count implies gen_idx waits on
 *     \p stream -- confirm in kern.cu.
 *  2. Two outputs of length out_size are allocated through \p malloc_policy.
 *     copy_output() fills them from \p data and the staged chunk-0 entries.
 *
 * \return {output 0: values, same dtype as \p data;
 *          output 1 (out_idx): Int32}
 * \throw megdnn error "bad mask dtype" / "bad data dtype" when the respective
 *        dtype is neither a computing dtype nor Bool.
 */
CondTakeImpl::Output CondTakeImpl::exec(
        _megdnn_tensor_in data, _megdnn_tensor_in mask,
        _megdnn_workspace workspace, DynOutMallocPolicyCall malloc_policy) {
    size_t size =
            check_exec_get_size(data.layout, mask.layout, workspace.size);
    auto wk_bundle = make_bundle(size);
    wk_bundle.set(workspace.raw_ptr);

    auto idx_tmp = static_cast<IdxType*>(wk_bundle.get(0));

    KParam kparam(param());
    auto stream = cuda_stream(handle());

    // Every switch arm below either assigns this or throws; zero-initialize
    // anyway so no path can read an indeterminate value.
    size_t out_size = 0;

    // Stage 1: dispatch on the mask dtype and generate the selection.
    switch (mask.layout.dtype.enumv()) {
#define cb(_dt)                                                              \
    case DTypeTrait<_dt>::enumv: {                                           \
        using ctype = DTypeTrait<_dt>::ctype;                                \
        out_size = gen_idx(wk_bundle.get(1), wk_bundle.get_size(1), idx_tmp, \
                           mask.ptr<ctype>(), size,                          \
                           static_cast<uint32_t>(param().mode), kparam,      \
                           stream);                                          \
        break;                                                               \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
        default:
            megdnn_throw("bad mask dtype");
    }

    // Stage 2: allocate outputs sized by the selection and fill them.
    auto out_data =
            malloc_policy.alloc_output(0, data.layout.dtype, {out_size});
    auto out_idx = malloc_policy.alloc_output(1, dtype::Int32(), {out_size});
    auto out_idx_ptr = out_idx.ptr<dt_int32>();

    switch (data.layout.dtype.enumv()) {
#define cb(_dt)                                                          \
    case DTypeTrait<_dt>::enumv: {                                       \
        using ctype = DTypeTrait<_dt>::ctype;                            \
        auto out_data_ptr = out_data.ptr<ctype>();                       \
        auto data_ptr = data.ptr<ctype>();                               \
        copy_output<ctype>(out_data_ptr, out_idx_ptr, data_ptr, idx_tmp, \
                           size, stream);                                \
        break;                                                           \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
        default:
            megdnn_throw("bad data dtype");
    }

    return {{out_data, out_idx}};
}

// vim: syntax=cpp.doxygen