/**
 * \file dnn/src/rocm/param_pack/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h"

#include "src/rocm/param_pack/opr_impl.h"
#include "src/rocm/param_pack/param_pack.h.hip"
#include "src/rocm/utils.h"

namespace megdnn {
namespace rocm {

size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs,
                                                   const TensorShape&,
                                                   const TensorShape&) {
    // workspace holds one device pointer slot per source tensor
    return sizeof(size_t) * srcs.size();
}

template <typename T>
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
                                        _megdnn_tensor_in offsets,
                                        _megdnn_tensor_out dst,
                                        _megdnn_workspace workspace) {
    size_t inp_size = srcs.layout.shape[0],
           out_size = dst.layout.total_nr_elems();
    auto stream = hip_stream(this->handle());

    // srcs holds a host-side array of device pointers to the source tensors
    auto src_cpu = static_cast<const T**>(srcs.raw_ptr);
    megdnn_assert_internal(src_cpu);
    auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr);

    auto offsets_gpu = offsets.ptr<int32_t>();

    // copy the pointer table to device memory, then launch the concat kernel
    hip_check(hipMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size,
                             hipMemcpyHostToDevice, stream));

    param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size,
                                offsets_gpu, stream);
}

void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets,
                               _megdnn_tensor_out dst,
                               _megdnn_workspace workspace) {
    check_exec(dst.layout, offsets.layout, srcs.layout);
    // dispatch to the typed implementation based on the output dtype
#define cb(DType)                                                \
    if (dst.layout.dtype == DType()) {                           \
        using ctype = typename DTypeTrait<DType>::ctype;         \
        exec_internal<ctype>(srcs, offsets, dst, workspace);     \
        return;                                                  \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    megdnn_throw("bad type");
#undef cb
}

}  // namespace rocm
}  // namespace megdnn