GitOrigin-RevId: 95be929841
release-1.5
@@ -96,34 +96,6 @@ void ConvDesc::set(const param::Convolution& param, const size_t nr_group, | |||||
//! not supported | //! not supported | ||||
} | } | ||||
PoolingDesc::PoolingDesc() { | |||||
miopen_check(miopenCreatePoolingDescriptor(&desc)); | |||||
} | |||||
PoolingDesc::~PoolingDesc() { | |||||
miopen_check(miopenDestroyPoolingDescriptor(desc)); | |||||
} | |||||
void PoolingDesc::set(const param::Pooling& param) { | |||||
miopenPoolingMode_t mode; | |||||
switch (param.mode) { | |||||
case param::Pooling::Mode::MAX: | |||||
mode = miopenPoolingMax; | |||||
break; | |||||
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
mode = miopenPoolingAverage; | |||||
break; | |||||
case param::Pooling::Mode::AVERAGE: | |||||
mode = miopenPoolingAverageInclusive; | |||||
break; | |||||
default: | |||||
megdnn_throw("Unsupported pooling mode for miopen"); | |||||
} | |||||
miopen_check(miopenSet2dPoolingDescriptor( | |||||
desc, mode, param.window_h, param.window_w, param.pad_h, | |||||
param.pad_w, param.stride_h, param.stride_w)); | |||||
} | |||||
LRNDesc::LRNDesc() { | LRNDesc::LRNDesc() { | ||||
miopen_check(miopenCreateLRNDescriptor(&desc)); | miopen_check(miopenCreateLRNDescriptor(&desc)); | ||||
} | } | ||||
@@ -38,14 +38,6 @@ public: | |||||
miopenConvolutionDescriptor_t desc; | miopenConvolutionDescriptor_t desc; | ||||
}; | }; | ||||
class PoolingDesc { | |||||
public: | |||||
PoolingDesc(); | |||||
void set(const param::Pooling& param); | |||||
~PoolingDesc(); | |||||
miopenPoolingDescriptor_t desc; | |||||
}; | |||||
class LRNDesc { | class LRNDesc { | ||||
public: | public: | ||||
LRNDesc(); | LRNDesc(); | ||||
@@ -0,0 +1,209 @@ | |||||
/** | |||||
* \file dnn/src/rocm/pooling/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "./algo.h" | |||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "src/rocm/utils.h" | |||||
using namespace megdnn; | |||||
using namespace rocm; | |||||
PoolingForwardImpl::AlgoPack::AlgoPack() { | |||||
all_algos.push_back(&algo_miopen); | |||||
for (auto&& algo : all_algos) { | |||||
m_all_algos_map.emplace(algo->info().desc, algo); | |||||
} | |||||
} | |||||
PoolingForwardImpl::AlgoPack PoolingForwardImpl::sm_algo_pack; | |||||
MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingForwardImpl) | |||||
PoolingForwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingForwardImpl* o, | |||||
const TensorLayout& src, | |||||
const TensorLayout& dst) | |||||
: handle{concrete_handle(o->handle())}, | |||||
opr{o}, | |||||
layout_src{&src}, | |||||
layout_dst{&dst} {} | |||||
PoolingForwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingForwardImpl* opr, | |||||
_megdnn_tensor_in src, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) | |||||
: SizeArgs(opr, src.layout, dst.layout), | |||||
src_tensor{&src}, | |||||
dst_tensor{&dst}, | |||||
workspace{workspace} {} | |||||
std::string PoolingForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
return ssprintf("src=%s, dst=%s", layout_src->to_string().c_str(), | |||||
layout_dst->to_string().c_str()); | |||||
} | |||||
//! MIOpen handles every pooling configuration this backend exposes, so the
//! algorithm is unconditionally available.
bool PoolingForwardImpl::AlgoMIOpen::is_available(const SizeArgs&) const {
    return true;
}
void PoolingForwardImpl::AlgoMIOpen::init_mode( | |||||
const ExecArgs& args, miopenPoolingMode_t& mode) const { | |||||
switch (args.opr->param().mode) { | |||||
case param::Pooling::Mode::MAX: | |||||
mode = miopenPoolingMax; | |||||
break; | |||||
case param::Pooling::Mode::AVERAGE: | |||||
mode = miopenPoolingAverage; | |||||
break; | |||||
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
mode = miopenPoolingAverageInclusive; | |||||
break; | |||||
default: | |||||
megdnn_throw(ssprintf("Unspport pooling mode : {%d}", | |||||
static_cast<int>(args.opr->param().mode))); | |||||
} | |||||
} | |||||
//! Forward pooling needs no scratch memory: indices are only saved when the
//! do_backward flag of miopenPoolingForward is set, which exec() never does.
size_t PoolingForwardImpl::AlgoMIOpen::get_workspace_in_bytes(
        const SizeArgs& /* args */) const {
    return 0;
}
void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const { | |||||
auto handle = miopen_handle(args.handle); | |||||
TensorDesc src_desc, dst_desc; | |||||
args.init_desc(src_desc, dst_desc); | |||||
miopenPoolingMode_t mode; | |||||
init_mode(args, mode); | |||||
miopenPoolingDescriptor_t miopen_desc; | |||||
miopen_check(miopenCreatePoolingDescriptor(&miopen_desc)); | |||||
miopen_check(miopenSet2dPoolingDescriptor( | |||||
miopen_desc, mode, args.opr->param().window_h, | |||||
args.opr->param().window_w, args.opr->param().pad_h, | |||||
args.opr->param().pad_w, args.opr->param().stride_h, | |||||
args.opr->param().stride_w)); | |||||
dt_float32 alpha = 1.0f, beta = 0.0f; | |||||
miopen_check(miopenPoolingForward( | |||||
handle, miopen_desc, &alpha, src_desc.desc, | |||||
args.src_tensor->raw_ptr, &beta, dst_desc.desc, | |||||
args.src_tensor->raw_ptr, false, nullptr, 0_z)); | |||||
miopen_check(miopenDestroyPoolingDescriptor(miopen_desc)); | |||||
} | |||||
PoolingBackwardImpl::AlgoPack::AlgoPack() { | |||||
all_algos.push_back(&algo_miopen); | |||||
for (auto&& algo : all_algos) { | |||||
m_all_algos_map.emplace(algo->info().desc, algo); | |||||
} | |||||
} | |||||
PoolingBackwardImpl::AlgoPack PoolingBackwardImpl::sm_algo_pack; | |||||
MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingBackwardImpl) | |||||
PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingBackwardImpl* o, | |||||
const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) | |||||
: handle{concrete_handle(o->handle())}, | |||||
opr{o}, | |||||
layout_src{&src}, | |||||
layout_dst{&dst}, | |||||
layout_diff{&diff}, | |||||
layout_grad{&grad} {} | |||||
PoolingBackwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingBackwardImpl* opr, | |||||
_megdnn_tensor_in src, | |||||
_megdnn_tensor_in dst, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) | |||||
: SizeArgs(opr, src.layout, dst.layout, diff.layout, grad.layout), | |||||
src_tensor{&src}, | |||||
dst_tensor{&dst}, | |||||
diff_tensor{&diff}, | |||||
grad_tensor{&grad}, | |||||
workspace{workspace} {} | |||||
std::string PoolingBackwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
return ssprintf( | |||||
"src=%s, dst=%s, diff=%s, grad=%s", layout_src->to_string().c_str(), | |||||
layout_dst->to_string().c_str(), layout_diff->to_string().c_str(), | |||||
layout_grad->to_string().c_str()); | |||||
} | |||||
//! The MIOpen backward algorithm accepts every configuration this backend
//! exposes, so it is unconditionally available.
bool PoolingBackwardImpl::AlgoMIOpen::is_available(const SizeArgs&) const {
    return true;
}
//! Ask MIOpen how much scratch memory the backward pass needs; the workspace
//! stores the max-pooling indices recomputed by exec().
//! Fix: the return status of miopenPoolingGetWorkSpaceSize was silently
//! dropped — every other MIOpen call in this file goes through miopen_check.
size_t PoolingBackwardImpl::AlgoMIOpen::get_workspace_in_bytes(
        const SizeArgs& args) const {
    TensorDesc dst_desc;
    dst_desc.set(*args.layout_dst);
    size_t ws_size = 0_z;
    miopen_check(miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size));
    return ws_size;
}
void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args, | |||||
miopenPoolingMode_t& mode) const { | |||||
switch (args.opr->param().mode) { | |||||
case param::Pooling::Mode::MAX: | |||||
mode = miopenPoolingMax; | |||||
break; | |||||
case param::Pooling::Mode::AVERAGE: | |||||
mode = miopenPoolingAverage; | |||||
break; | |||||
case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
mode = miopenPoolingAverageInclusive; | |||||
break; | |||||
default: | |||||
megdnn_throw(ssprintf("Unspport pooling mode : {%d}", | |||||
static_cast<int>(args.opr->param().mode))); | |||||
} | |||||
} | |||||
void PoolingBackwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const { | |||||
auto handle = miopen_handle(args.handle); | |||||
TensorDesc src_desc, dst_desc, diff_desc, grad_desc; | |||||
args.init_desc(src_desc, dst_desc, diff_desc, grad_desc); | |||||
miopenPoolingMode_t mode; | |||||
init_mode(args, mode); | |||||
miopenPoolingDescriptor_t miopen_desc; | |||||
miopen_check(miopenCreatePoolingDescriptor(&miopen_desc)); | |||||
miopen_check(miopenSet2dPoolingDescriptor( | |||||
miopen_desc, mode, args.opr->param().window_h, | |||||
args.opr->param().window_w, args.opr->param().pad_h, | |||||
args.opr->param().pad_w, args.opr->param().stride_h, | |||||
args.opr->param().stride_w)); | |||||
float alpha = 1.0f, beta = 0.0f; | |||||
if (args.opr->param().mode == param::Pooling::Mode::MAX) { | |||||
//! FIXME: when using max pooling opr, the backward opr need the indices | |||||
//! of the forward opr which stored in workspace. We have to recompute | |||||
//! the indices by calling miopenPoolingForward again. | |||||
miopen_check(miopenPoolingForward( | |||||
handle, miopen_desc, &alpha, src_desc.desc, | |||||
args.src_tensor->raw_ptr, &beta, dst_desc.desc, | |||||
args.dst_tensor->raw_ptr, true, args.workspace.raw_ptr, | |||||
args.workspace.size)); | |||||
} | |||||
miopen_check(miopenPoolingBackward( | |||||
handle, miopen_desc, &alpha, dst_desc.desc, | |||||
args.dst_tensor->raw_ptr, diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
src_desc.desc, args.src_tensor->raw_ptr, &beta, grad_desc.desc, | |||||
args.grad_tensor->raw_ptr, args.workspace.raw_ptr)); | |||||
} |
@@ -0,0 +1,195 @@ | |||||
/** | |||||
* \file dnn/src/rocm/pooling/algo.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <unordered_map> | |||||
#include "src/common/algo_base.h" | |||||
#include "src/common/metahelper.h" | |||||
#include "src/rocm/miopen_wrapper.h" | |||||
#include "src/rocm/pooling/opr_impl.h" | |||||
#include "src/rocm/handle.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
//! Common interface of all forward pooling algorithms on ROCm.
class PoolingForwardImpl::AlgoBase : public Algorithm {
public:
    enum class AlgoType : uint32_t { ROCM_MIOPEN };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }

    //! Size-only arguments: enough for availability / workspace queries.
    struct SizeArgs {
        HandleImpl* handle;
        PoolingForwardImpl* opr;
        const TensorLayout *layout_src, *layout_dst;

        std::string to_string() const;
        //! Fill the MIOpen tensor descriptors for src and dst, honoring the
        //! operator's layout format.
        void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc) const {
            src_desc.set(*layout_src, opr->param().format);
            dst_desc.set(*layout_dst, opr->param().format);
        }
        SizeArgs(PoolingForwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& dst);
    };
    //! SizeArgs plus the concrete tensors and workspace used at exec time.
    struct ExecArgs : public SizeArgs {
        const TensorND *src_tensor, *dst_tensor;
        Workspace workspace;
        ExecArgs(PoolingForwardImpl* opr, _megdnn_tensor_in src,
                 _megdnn_tensor_out dst, _megdnn_workspace workspace);
    };

    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    //! Availability combined with the required/forbidden attribute filter
    //! used by the heuristic algorithm chooser.
    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
        bool attr_ok = contain_attribute_all(positive_attr) &&
                       !contain_attribute_any(negative_attr);
        return attr_ok && is_available(args);
    }

protected:
    //! Algorithms are statics owned by AlgoPack; never deleted through base.
    ~AlgoBase() = default;
};
//! Forward pooling implemented via the MIOpen library.
class PoolingForwardImpl::AlgoMIOpen final : public AlgoBase {
    std::string m_algo_name;
    AlgoAttribute m_algo_attribute;

public:
    AlgoMIOpen(AlgoAttribute attr)
            : m_algo_name("MIOpenPoolingForward"), m_algo_attribute(attr) {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    //! Map the operator's pooling mode onto the MIOpen mode enum.
    void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_algo_name.c_str(); }
    AlgoAttribute attribute() const override { return m_algo_attribute; }
    MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)

    //! Serialized parameters used to build this algo's AlgorithmDesc.
    std::string param() const override {
        std::string serialized;
        serialize_write_pod(m_algo_attribute, serialized);
        return serialized;
    }
};
//! Owns the forward algorithm instances and the desc -> algo lookup table.
class PoolingForwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE};

    //! All algorithms, in heuristic preference order.
    std::vector<AlgoBase*> all_algos;

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
//! Common interface of all backward pooling algorithms on ROCm.
class PoolingBackwardImpl::AlgoBase : public Algorithm {
public:
    enum class AlgoType : uint32_t { ROCM_MIOPEN };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }

    //! Size-only arguments for availability / workspace queries.
    struct SizeArgs {
        HandleImpl* handle;
        PoolingBackwardImpl* opr;
        const TensorLayout *layout_src, *layout_dst, *layout_diff, *layout_grad;

        std::string to_string() const;
        //! Fill the MIOpen tensor descriptors for all four operands.
        void init_desc(TensorDesc& src_desc, TensorDesc& dst_desc,
                       TensorDesc& diff_desc, TensorDesc& grad_desc) const {
            src_desc.set(*layout_src);
            dst_desc.set(*layout_dst);
            diff_desc.set(*layout_diff);
            grad_desc.set(*layout_grad);
        }
        SizeArgs(PoolingBackwardImpl* opr, const TensorLayout& src,
                 const TensorLayout& dst, const TensorLayout& diff,
                 const TensorLayout& grad);
    };
    //! SizeArgs plus the concrete tensors and workspace used at exec time.
    struct ExecArgs : public SizeArgs {
        const TensorND *src_tensor, *dst_tensor, *diff_tensor, *grad_tensor;
        Workspace workspace;
        ExecArgs(PoolingBackwardImpl* opr, _megdnn_tensor_in src,
                 _megdnn_tensor_in dst, _megdnn_tensor_in diff,
                 _megdnn_tensor_out grad, _megdnn_workspace workspace);
    };

    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    //! Availability combined with the required/forbidden attribute filter
    //! used by the heuristic algorithm chooser.
    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
        bool attr_ok = contain_attribute_all(positive_attr) &&
                       !contain_attribute_any(negative_attr);
        return attr_ok && is_available(args);
    }

protected:
    //! Algorithms are statics owned by AlgoPack; never deleted through base.
    ~AlgoBase() = default;
};
//! Backward pooling implemented via the MIOpen library.
class PoolingBackwardImpl::AlgoMIOpen final : public AlgoBase {
    std::string m_algo_name;
    AlgoAttribute m_algo_attribute;

public:
    AlgoMIOpen(AlgoAttribute attr)
            : m_algo_name("MIOpenPoolingBackward"), m_algo_attribute(attr) {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    //! Map the operator's pooling mode onto the MIOpen mode enum.
    void init_mode(const ExecArgs& args, miopenPoolingMode_t& mode) const;
    void exec(const ExecArgs& args) const override;

    const char* name() const override { return m_algo_name.c_str(); }
    AlgoAttribute attribute() const override { return m_algo_attribute; }
    MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)

    //! Serialized parameters used to build this algo's AlgorithmDesc.
    std::string param() const override {
        std::string serialized;
        serialize_write_pod(m_algo_attribute, serialized);
        return serialized;
    }
};
//! Owns the backward algorithm instances and the desc -> algo lookup table.
class PoolingBackwardImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack();

    AlgoMIOpen algo_miopen{AlgoAttribute::REPRODUCIBLE};

    //! All algorithms, in heuristic preference order.
    std::vector<AlgoBase*> all_algos;

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} // namespace rocm | |||||
} // namespace megdnn |
@@ -10,18 +10,47 @@ | |||||
*/ | */ | ||||
#include "hcc_detail/hcc_defs_prologue.h" | #include "hcc_detail/hcc_defs_prologue.h" | ||||
#include "src/rocm/pooling/opr_impl.h" | #include "src/rocm/pooling/opr_impl.h" | ||||
#include "src/rocm/utils.h" | #include "src/rocm/utils.h" | ||||
#include "./algo.h" | |||||
#include "src/common/algo_chooser.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace rocm { | namespace rocm { | ||||
void PoolingForwardImpl::setup_descs(const TensorLayout &src, | |||||
const TensorLayout &dst) | |||||
{ | |||||
src_desc.set(src, param().format); | |||||
dst_desc.set(dst, param().format); | |||||
pooling_desc.set(this->param()); | |||||
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) { | |||||
AlgoBase::SizeArgs args(this, src, dst); | |||||
return get_algorithm(this, src, dst)->get_workspace_in_bytes(args); | |||||
} | |||||
const char* PoolingForwardImpl::get_algorithm_set_name() const { | |||||
return "ROCM_POOLING_FORWARD"; | |||||
} | |||||
std::vector<PoolingForwardImpl::Algorithm*> | |||||
PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | |||||
} | |||||
PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
const AlgoAttribute& negative_attr) { | |||||
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||||
AlgoBase::SizeArgs args(this, src, dst); | |||||
for (auto&& iter : sm_algo_pack.all_algos) { | |||||
if (iter->is_available_attribute(args, positive_attr, negative_attr)) { | |||||
return iter; | |||||
} | |||||
} | |||||
megdnn_throw( | |||||
ssprintf("require algorithm with attribute(%s) and without " | |||||
"attribute(%s), but can't get suitable algo.\n", | |||||
Algorithm::attribute_str(positive_attr).c_str(), | |||||
Algorithm::attribute_str(negative_attr).c_str())); | |||||
return nullptr; | |||||
} | } | ||||
void PoolingForwardImpl::exec(_megdnn_tensor_in src, | void PoolingForwardImpl::exec(_megdnn_tensor_in src, | ||||
@@ -29,24 +58,52 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, | |||||
_megdnn_workspace workspace) | _megdnn_workspace workspace) | ||||
{ | { | ||||
check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
auto handle = miopen_handle(this->handle()); | |||||
setup_descs(src.layout, dst.layout); | |||||
dt_float32 alpha = 1.0f, beta = 0.0f; | |||||
miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha, | |||||
src_desc.desc, src.raw_ptr, &beta, | |||||
dst_desc.desc, dst.raw_ptr, false, | |||||
nullptr, 0_z)); | |||||
{ | |||||
AlgoBase::ExecArgs args(this, src, dst, workspace); | |||||
auto algo = get_algorithm(this, src.layout, dst.layout); | |||||
algo->exec(args); | |||||
} | |||||
} | } | ||||
void PoolingBackwardImpl::setup_descs(const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
src_desc.set(src); | |||||
dst_desc.set(dst); | |||||
diff_desc.set(diff); | |||||
grad_desc.set(grad); | |||||
pooling_desc.set(this->param()); | |||||
size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
AlgoBase::SizeArgs args(this, src, dst, diff, grad); | |||||
return get_algorithm(this, src, dst, diff, grad) | |||||
->get_workspace_in_bytes(args); | |||||
}; | |||||
const char* PoolingBackwardImpl::get_algorithm_set_name() const { | |||||
return "ROCM_POOLING_BACKWARD"; | |||||
} | |||||
std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms<PoolingBackwardImpl>( | |||||
{this, src, dst, diff, grad}); | |||||
} | |||||
Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
const AlgoAttribute& negative_attr) { | |||||
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||||
AlgoBase::SizeArgs args(this, src, dst, diff, grad); | |||||
for (auto iter : sm_algo_pack.all_algos) { | |||||
if (iter->is_available_attribute(args, positive_attr, negative_attr)) { | |||||
return iter; | |||||
} | |||||
} | |||||
megdnn_throw( | |||||
ssprintf("require algorithm with attribute(%s) and without " | |||||
"attribute(%s), but can't get suitable algo.\n", | |||||
Algorithm::attribute_str(positive_attr).c_str(), | |||||
Algorithm::attribute_str(negative_attr).c_str())); | |||||
return nullptr; | |||||
} | } | ||||
void PoolingBackwardImpl::exec(_megdnn_tensor_in src, | void PoolingBackwardImpl::exec(_megdnn_tensor_in src, | ||||
@@ -55,35 +112,16 @@ void PoolingBackwardImpl::exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_out grad, | _megdnn_tensor_out grad, | ||||
_megdnn_workspace workspace) | _megdnn_workspace workspace) | ||||
{ | { | ||||
check_exec(src.layout, dst.layout, diff.layout, grad.layout, workspace.size); | |||||
auto handle = miopen_handle(this->handle()); | |||||
setup_descs(src.layout, dst.layout, diff.layout, grad.layout); | |||||
float alpha = 1.0f, beta = 0.0f; | |||||
if (param().mode == param::Pooling::Mode::MAX) { | |||||
//! FIXME: when using max pooling opr, the backward opr need the indices | |||||
//! of the forward opr which stored in workspace. We have to recompute | |||||
//! the indices by calling miopenPoolingForward again. | |||||
miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha, | |||||
src_desc.desc, src.raw_ptr, &beta, | |||||
dst_desc.desc, dst.raw_ptr, true, | |||||
workspace.raw_ptr, workspace.size)); | |||||
check_exec(src.layout, dst.layout, diff.layout, grad.layout, | |||||
workspace.size); | |||||
{ | |||||
AlgoBase::ExecArgs args(this, src, dst, diff, grad, workspace); | |||||
auto algo = get_algorithm(this, src.layout, dst.layout, diff.layout, | |||||
grad.layout); | |||||
algo->exec(args); | |||||
} | } | ||||
miopen_check(miopenPoolingBackward( | |||||
handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr, | |||||
diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta, | |||||
grad_desc.desc, grad.raw_ptr, workspace.raw_ptr)); | |||||
} | } | ||||
size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
setup_descs(src, dst, diff, grad); | |||||
size_t ws_size = 0_z; | |||||
miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size); | |||||
return ws_size; | |||||
}; | |||||
} // namespace rocm | } // namespace rocm | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -22,13 +22,37 @@ class PoolingForwardImpl final: public PoolingForward { | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
size_t get_workspace_in_bytes(const TensorLayout &, | size_t get_workspace_in_bytes(const TensorLayout &, | ||||
const TensorLayout &) override { | |||||
return 0; | |||||
const TensorLayout &) override; | |||||
const char* get_algorithm_set_name() const override; | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
AlgorithmInfo get_algorithm_info_heuristic( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_limit_in_bytes, | |||||
const AlgoAttribute& positive_attr, | |||||
const AlgoAttribute& negative_attr) { | |||||
return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes, | |||||
positive_attr, negative_attr) | |||||
->info(); | |||||
} | } | ||||
class AlgoBase; | |||||
class AlgoMIOpen; | |||||
class AlgoPack; | |||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& src, const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
const AlgoAttribute& negative_attr) override; | |||||
private: | private: | ||||
TensorDesc src_desc, dst_desc; | |||||
PoolingDesc pooling_desc; | |||||
void setup_descs(const TensorLayout &src, const TensorLayout &dst); | |||||
static AlgoPack sm_algo_pack; | |||||
}; | }; | ||||
class PoolingBackwardImpl final: public PoolingBackward { | class PoolingBackwardImpl final: public PoolingBackward { | ||||
@@ -43,14 +67,41 @@ class PoolingBackwardImpl final: public PoolingBackward { | |||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
private: | |||||
TensorDesc src_desc, dst_desc, diff_desc, grad_desc; | |||||
PoolingDesc pooling_desc; | |||||
void setup_descs(const TensorLayout &src, | |||||
const TensorLayout &dst, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad); | |||||
const char* get_algorithm_set_name() const override; | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
AlgorithmInfo get_algorithm_info_heuristic( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
const AlgoAttribute& positive_attr, | |||||
const AlgoAttribute& negative_attr) { | |||||
return get_algorithm_heuristic(src, dst, diff, grad, | |||||
workspace_limit_in_bytes, | |||||
positive_attr, negative_attr) | |||||
->info(); | |||||
} | |||||
class AlgoBase; | |||||
class AlgoMIOpen; | |||||
class AlgoPack; | |||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
const AlgoAttribute& positive_attr, | |||||
const AlgoAttribute& negative_attr) override; | |||||
private: | |||||
static AlgoPack sm_algo_pack; | |||||
}; | }; | ||||
} // namespace rocm | } // namespace rocm | ||||