@@ -36,6 +36,7 @@
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
#include "src/rocm/param_pack/opr_impl.h"
#include <miopen/version.h>
#include <hip/hip_version.h>
@@ -174,6 +175,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
@@ -0,0 +1,65 @@
/**
 * \file dnn/src/rocm/param_pack/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "src/rocm/param_pack/opr_impl.h" | |||||
#include "src/rocm/param_pack/param_pack.h.hip" | |||||
#include "src/rocm/utils.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs, | |||||
const TensorShape&, | |||||
const TensorShape&) { | |||||
return sizeof(size_t) * srcs.size(); | |||||
} | |||||
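
// exec_internal copies the host-side array of source pointers into the
// workspace (device memory) and then launches the concat kernel, which
// writes each source tensor into its slot of the packed output.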
template <typename T>
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
                                        _megdnn_tensor_in offsets,
                                        _megdnn_tensor_out dst,
                                        _megdnn_workspace workspace) {
    size_t inp_size = srcs.layout.shape[0],
           out_size = dst.layout.total_nr_elems();
    auto stream = hip_stream(this->handle());

    auto src_cpu = static_cast<const T**>(srcs.raw_ptr);
    megdnn_assert_internal(src_cpu);
    auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr);

    auto offsets_gpu = offsets.ptr<int32_t>();

    hip_check(hipMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size,
                             hipMemcpyHostToDevice, stream));

    param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size,
                                offsets_gpu, stream);
}
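
// exec dispatches on the destination dtype: each computing dtype expands the
// cb() macro below into a typed call to exec_internal.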
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
                               _megdnn_tensor_in offsets,
                               _megdnn_tensor_out dst,
                               _megdnn_workspace workspace) {
    check_exec(dst.layout, offsets.layout, srcs.layout);
#define cb(DType)                                            \
    if (dst.layout.dtype == DType()) {                       \
        using ctype = typename DTypeTrait<DType>::ctype;     \
        exec_internal<ctype>(srcs, offsets, dst, workspace); \
        return;                                              \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    megdnn_throw("bad type");
#undef cb
}

} // namespace rocm
} // namespace megdnn
@@ -0,0 +1,35 @@
/**
 * \file dnn/src/rocm/param_pack/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/oprs.h"

namespace megdnn {
namespace rocm {
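
// Concatenates a set of parameter tensors into one contiguous, aligned device
// buffer; the kernel is declared in param_pack.h.hip.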
class ParamPackConcatImpl final : public ParamPackConcat {
public:
    using ParamPackConcat::ParamPackConcat;
    void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
              _megdnn_tensor_out dst, _megdnn_workspace workspace) override;

    size_t get_workspace_in_bytes(const TensorShapeArray& srcs,
                                  const TensorShape& table,
                                  const TensorShape& dst) override;

private:
    template <typename T>
    void exec_internal(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
                       _megdnn_tensor_out dst, _megdnn_workspace workspace);
};

} // namespace rocm
} // namespace megdnn
@@ -0,0 +1,67 @@
/**
 * \file dnn/src/rocm/param_pack/param_pack.cpp.hip
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "src/rocm/param_pack/param_pack.h.hip" | |||||
#include "megdnn/dtype.h" | |||||
#include "src/rocm/utils.h.hip" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
namespace param_pack { | |||||
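
// One thread per output element: binary-search the (start, end) offset pairs
// to find the source tensor covering this position, copy its element, and
// write zero into the alignment padding between tensors.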
template <typename T>
__global__ void concat_kernel(const T** srcs, T* dst,
                              const int32_t* offsets,
                              size_t srcs_size,
                              size_t total_size) {
    size_t addr = threadIdx.x + blockIdx.x * blockDim.x;
    if (addr < total_size) {
        size_t l = 0, r = srcs_size - 1, mid;
        while (l < r) {
            mid = (l + r) >> 1;
            if (offsets[(mid << 1) + 1] > addr) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        if (addr < offsets[l << 1])
            dst[addr] = 0;
        else
            dst[addr] = srcs[l][addr - offsets[l << 1]];
    }
}
template <typename T>
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
                  const int32_t* offsets,
                  hipStream_t stream) {
    size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS);
    hipLaunchKernelGGL(concat_kernel<T>, NR_BLOCKS, NR_THREADS, 0, stream,
                       srcs, dst, offsets, srcs_size, total_size);
    after_kernel_launch();
}
#define INST(T)                                                   \
    template void concat_proxy<T>(const T**, T*, size_t, size_t, \
                                  const int32_t*,                \
                                  hipStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
#undef INST

} // namespace param_pack
} // namespace rocm
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,30 @@
/**
 * \file dnn/src/rocm/param_pack/param_pack.h.hip
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "hip_header.h"
#include <stdint.h>
#include <stdio.h>

namespace megdnn {
namespace rocm {
namespace param_pack {
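
// Copy srcs_size source tensors, described by (start, end) element offsets,
// into the packed buffer dst of total_size elements; gaps introduced by
// alignment are zero-filled.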
template <typename T>
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
                  const int32_t* offsets, hipStream_t stream);

} // namespace param_pack
} // namespace rocm
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,156 @@
/**
 * \file dnn/test/rocm/param_pack.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "test/rocm/fixture.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
namespace { | |||||
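
// Build the (start, end) element-offset pair for each shape; each start is
// rounded up to the handle's alignment requirement (converted from bytes to
// elements and assumed to be a power of two by the bit masking below).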
template <class T>
std::vector<int32_t> create_offsets(const TensorShapeArray& shapes,
                                    size_t alignment) {
    size_t dtype_size = sizeof(T);
    if (alignment < dtype_size)
        alignment = dtype_size;
    alignment /= dtype_size;

    auto get_aligned = [alignment](size_t v) {
        auto mod = v & (alignment - 1);
        return v + ((alignment - mod) & (alignment - 1));
    };

    std::vector<dt_int32> offsets(shapes.size() << 1);
    size_t offset = 0;
    for (size_t i = 0; i < shapes.size(); i++) {
        offset = get_aligned(offset);
        offsets[i << 1] = offset;
        offset += shapes[i].total_nr_elems();
        offsets[(i << 1) + 1] = offset;
    }
    return offsets;
}
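
// Compute the expected packed buffer on the host: copy every parameter into
// its [start, end) slot and leave the alignment padding zero-initialized.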
template <class T>
std::vector<T> create_pack(size_t pack_size,
                           const std::vector<int32_t>& offsets,
                           const std::vector<std::vector<T>>& ptr) {
    megdnn_assert(pack_size == static_cast<size_t>(offsets.back()));
    std::vector<T> data(pack_size, 0);
    for (size_t i = 0; i * 2 < offsets.size(); ++i) {
        size_t begin = offsets[i * 2], end = offsets[i * 2 + 1];
        for (size_t j = 0; j < end - begin; j++)
            data[begin + j] = ptr[i][j];
    }
    return data;
}
template <class T>
std::vector<std::vector<T>> create_params(size_t nr_params,
                                          const TensorShapeArray& shapes) {
    std::vector<std::vector<T>> params;
    for (size_t i = 0; i < nr_params; ++i) {
        std::vector<T> expected_data;
        for (size_t x = 0; x < shapes[i].total_nr_elems(); ++x) {
            expected_data.push_back(rand());
        }
        params.push_back(std::move(expected_data));
    }
    return params;
}

template <class T>
T* create_device_data(Handle* handle, const T* data, size_t size) {
    T* data_device =
            static_cast<T*>(test::megdnn_malloc(handle, size * sizeof(T)));
    if (data)
        test::megdnn_memcpy_H2D(handle, data_device, data, size * sizeof(T));
    return data_device;
}
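
// End-to-end check: upload each parameter and the offset table, run
// ParamPackConcat on the device, download the packed result and compare it
// element-wise with the host reference from create_pack. The srcs tensor
// holds host pointers to the device buffers; its Int32 layout only describes
// the length of that pointer table.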
template <class T>
void test_param_pack_concat(Handle* handle, const TensorShapeArray& shapes,
                            DType type) {
    auto concat = handle->create_operator<ParamPackConcat>();
    size_t nr_params = shapes.size();

    std::vector<T*> param_ptrs;
    std::vector<std::vector<T>> params = create_params<T>(nr_params, shapes);
    for (size_t i = 0; i < nr_params; ++i) {
        param_ptrs.push_back(create_device_data<T>(handle, params[i].data(),
                                                   shapes[i].total_nr_elems()));
    }

    std::vector<int32_t> offsets =
            create_offsets<T>(shapes, handle->alignment_requirement());
    size_t pack_size = offsets.back();
    int32_t* offsets_gpu =
            create_device_data<int32_t>(handle, offsets.data(), offsets.size());

    std::vector<T> expected_pack = create_pack<T>(pack_size, offsets, params);
    T* pack_gpu = create_device_data<T>(handle, nullptr, expected_pack.size());

    TensorLayout dst_layout({pack_size}, type);
    TensorND dst_tensor(pack_gpu, dst_layout);

    TensorLayout offsets_layout({offsets.size()}, dtype::Int32());
    TensorND offsets_tensor(offsets_gpu, offsets_layout);

    test::WorkspaceWrapper workspace(
            handle, concat->get_workspace_in_bytes(shapes, offsets_layout,
                                                   {pack_size}));
    TensorND src_tensor(param_ptrs.data(),
                        TensorLayout({nr_params}, dtype::Int32()));

    concat->exec(src_tensor, offsets_tensor, dst_tensor, workspace.workspace());

    // check
    T* actual_pack = static_cast<T*>(malloc(pack_size * sizeof(T)));
    test::megdnn_memcpy_D2H(handle, actual_pack, pack_gpu,
                            sizeof(T) * pack_size);
    for (size_t i = 0; i < pack_size; ++i) {
        ASSERT_EQ(actual_pack[i], expected_pack[i]);
    }

    free(actual_pack);
    test::megdnn_free(handle, pack_gpu);
    test::megdnn_free(handle, offsets_gpu);
    for (auto ptr : param_ptrs) {
        test::megdnn_free(handle, ptr);
    }
}
} // namespace

TEST_F(ROCM, PARAM_PACK) {
    SmallVector<TensorShapeArray> shapes_vec;
    shapes_vec.push_back({{1}});
    shapes_vec.push_back({{129}, {21}});
    shapes_vec.push_back({{15}, {21}, {34}});
    shapes_vec.push_back({{1, 2}, {3, 5}, {5, 8}, {7, 11}, {9, 14}});
    shapes_vec.push_back({{1, 2},
                          {3, 5},
                          {1},
                          {3, 3, 3, 4},
                          {71},
                          {9, 14},
                          {111, 111, 111},
                          {128, 128, 128}});
    for (auto shapes : shapes_vec) {
        test_param_pack_concat<int32_t>(handle_rocm(), shapes, dtype::Int32());
        test_param_pack_concat<int16_t>(handle_rocm(), shapes, dtype::Int16());
        test_param_pack_concat<float>(handle_rocm(), shapes, dtype::Float32());
    }
}

// vim: syntax=cpp.doxygen