@@ -0,0 +1,52 @@ | |||||
/** | |||||
* \file dnn/src/rocm/adaptive_pooling/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ | |||||
#include "src/rocm/adaptive_pooling/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
/*!
 * \brief Run adaptive pooling forward by delegating to a PoolingForward
 *      operator created from this operator's handle.
 *
 * The delegate's parameter is whatever deduce_pooling_param() returns for the
 * layouts of \p src and \p dst; \p src, \p dst and \p workspace are then
 * passed to the delegate unchanged.
 */
void AdaptivePoolingForwardImpl::exec(_megdnn_tensor_in src,
                                      _megdnn_tensor_out dst,
                                      _megdnn_workspace workspace) {
    auto pooling = handle()->create_operator<PoolingForward>();
    const auto deduced_param = deduce_pooling_param(src.layout, dst.layout);
    pooling->param() = deduced_param;
    pooling->exec(src, dst, workspace);
}
size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) { | |||||
auto opr = handle()->create_operator<PoolingForward>(); | |||||
opr->param() = deduce_pooling_param(src, dst); | |||||
return opr->get_workspace_in_bytes(src, dst); | |||||
} | |||||
/*!
 * \brief Run adaptive pooling backward by delegating to a PoolingBackward
 *      operator created from this operator's handle.
 *
 * The delegate's parameter is deduced from the layouts of \p src and \p dst
 * via deduce_pooling_param(); all four tensors and \p workspace are passed to
 * the delegate unchanged.
 */
void AdaptivePoolingBackwardImpl::exec(_megdnn_tensor_in src,
                                       _megdnn_tensor_in dst,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    auto pooling_bwd = handle()->create_operator<PoolingBackward>();
    const auto deduced_param = deduce_pooling_param(src.layout, dst.layout);
    pooling_bwd->param() = deduced_param;
    pooling_bwd->exec(src, dst, diff, grad, workspace);
}
size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) { | |||||
auto opr = handle()->create_operator<PoolingBackward>(); | |||||
opr->param() = deduce_pooling_param(src, dst); | |||||
return opr->get_workspace_in_bytes(src, dst, diff, grad); | |||||
} | |||||
} // namespace rocm | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file dnn/src/rocm/adaptive_pooling/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
class AdaptivePoolingForwardImpl final : public AdaptivePoolingForward { | |||||
public: | |||||
using AdaptivePoolingForward::AdaptivePoolingForward; | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) override; | |||||
}; | |||||
class AdaptivePoolingBackwardImpl final : public AdaptivePoolingBackward { | |||||
public: | |||||
using AdaptivePoolingBackward::AdaptivePoolingBackward; | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
}; | |||||
} // namespace rocm | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -22,6 +22,7 @@ | |||||
#include "src/rocm/elemwise/opr_impl.h" | #include "src/rocm/elemwise/opr_impl.h" | ||||
#include "src/rocm/eye/opr_impl.h" | #include "src/rocm/eye/opr_impl.h" | ||||
#include "src/rocm/pooling/opr_impl.h" | #include "src/rocm/pooling/opr_impl.h" | ||||
#include "src/rocm/adaptive_pooling/opr_impl.h" | |||||
#include "src/rocm/reduce/opr_impl.h" | #include "src/rocm/reduce/opr_impl.h" | ||||
#include "src/rocm/type_cvt/opr_impl.h" | #include "src/rocm/type_cvt/opr_impl.h" | ||||
#include "src/rocm/topk/opr_impl.h" | #include "src/rocm/topk/opr_impl.h" | ||||
@@ -160,6 +161,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK); | ||||
@@ -0,0 +1,98 @@ | |||||
/** | |||||
* \file dnn/test/rocm/adaptive_pooling.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ | |||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "test/rocm/fixture.h" | |||||
#include "megdnn/tensor_iter.h" | |||||
#include "test/common/adaptive_pooling.h" | |||||
#include "test/common/checker.h" | |||||
#include "src/common/utils.h" | |||||
#include "test/rocm/utils.h" | |||||
#include "test/rocm/benchmarker.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
// Runs the Checker for AdaptivePooling on the ROCm handle once per argument
// set returned by adaptive_pooling::get_args(), with Float32 tensors, the
// format forced to NCHW and a comparison epsilon of 1e-2.
TEST_F(ROCM, ADAPTIVE_POOLING_FORWARD) {
    auto cases = adaptive_pooling::get_args();
    using Format = param::AdaptivePooling::Format;
    DType f32 = dtype::Float32();
    for (auto&& test_case : cases) {
        // Work on a copy so the shared argument list is left untouched.
        auto pooling_param = test_case.param;
        auto src_shape = test_case.ishape;
        auto dst_shape = test_case.oshape;
        pooling_param.format = Format::NCHW;

        Checker<AdaptivePooling> checker(handle_rocm());
        checker.set_epsilon(1e-2);
        checker.set_param(pooling_param);
        checker.set_dtype(0, f32);
        checker.set_dtype(1, f32);
        checker.exec(TensorShapeArray{src_shape, dst_shape, {}});
    }
}
// Runs the Checker for AdaptivePoolingBackward on the ROCm handle once per
// argument set returned by adaptive_pooling::get_args(), with Float32 tensors.
// A tensor constraint overwrites tensors_orig[1] with the output of the ROCm
// AdaptivePoolingForward operator applied to tensors_orig[0].
TEST_F(ROCM, ADAPTIVE_POOLING_BACKWARD) {
    auto cases = adaptive_pooling::get_args();
    for (auto&& test_case : cases) {
        Checker<AdaptivePoolingBackward> checker(handle_rocm());
        TensorLayout src_layout =
                TensorLayout(test_case.ishape, dtype::Float32());
        TensorLayout dst_layout =
                TensorLayout(test_case.oshape, dtype::Float32());

        auto run_forward = [this, test_case](
                                   CheckerHelper::TensorValueArray&
                                           tensors_orig) {
            megdnn_assert(tensors_orig.size() == 4);
            auto forward =
                    handle_rocm()->create_operator<AdaptivePoolingForward>();
            forward->param() = test_case.param;

            // Tensors allocated through the ROCm handle, mirroring the
            // layouts of the first two checker tensors.
            auto rocm_storage = CheckerHelper::alloc_tensors(
                    handle_rocm(),
                    {tensors_orig[0].layout, tensors_orig[1].layout}, 0);
            auto&& rocm_tensors = *rocm_storage;

            // Copy tensors_orig[0] into rocm_tensors[0] (H2D).
            auto in_span = rocm_tensors[0].layout.span();
            auto rocm_in = static_cast<dt_byte*>(rocm_tensors[0].raw_ptr) +
                           in_span.low_byte;
            auto orig_in =
                    static_cast<const dt_byte*>(tensors_orig[0].raw_ptr) +
                    in_span.low_byte;
            megdnn_memcpy_H2D(handle_rocm(), rocm_in, orig_in,
                              in_span.dist_byte());

            // Run the forward operator with a temporary workspace.
            auto ws_bytes = forward->get_workspace_in_bytes(
                    rocm_tensors[0].layout, rocm_tensors[1].layout);
            auto ws_ptr = megdnn_malloc(handle_rocm(), ws_bytes);
            Workspace ws{static_cast<dt_byte*>(ws_ptr), ws_bytes};
            forward->exec(rocm_tensors[0], rocm_tensors[1], ws);
            megdnn_free(handle_rocm(), ws_ptr);

            // Copy rocm_tensors[1] back into tensors_orig[1] (D2H).
            auto out_span = rocm_tensors[1].layout.span();
            auto orig_out = static_cast<dt_byte*>(tensors_orig[1].raw_ptr) +
                            out_span.low_byte;
            auto rocm_out =
                    static_cast<const dt_byte*>(rocm_tensors[1].raw_ptr) +
                    out_span.low_byte;
            megdnn_memcpy_D2H(handle_rocm(), orig_out, rocm_out,
                              out_span.dist_byte());
        };

        DType f32 = dtype::Float32();
        checker.set_tensors_constraint(run_forward);
        for (int idx = 0; idx < 4; ++idx) {
            checker.set_dtype(idx, f32);
        }
        checker.set_param(test_case.param);
        checker.exec(TensorShapeArray{src_layout, dst_layout, dst_layout,
                                      src_layout});
    }
}
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |