@@ -94,6 +94,11 @@ DefaultConvolution3DBackwardFilterAlgorithm | |||||
HandleImpl::m_default_conv3d_bwd_filter_algo; | HandleImpl::m_default_conv3d_bwd_filter_algo; | ||||
DefaultBatchConvBiasForwardAlgorithm | DefaultBatchConvBiasForwardAlgorithm | ||||
HandleImpl::m_default_batch_conv_bias_fwd_algo; | HandleImpl::m_default_batch_conv_bias_fwd_algo; | ||||
DefaultLocalShareForwardAlgorithm HandleImpl::m_default_local_share_fwd_algo; | |||||
DefaultLocalShareBackwardDataAlgorithm | |||||
HandleImpl::m_default_local_share_bwd_data_algo; | |||||
DefaultLocalShareBackwardFilterAlgorithm | |||||
HandleImpl::m_default_local_share_bwd_filter_algo; | |||||
HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | ||||
HandleType type) | HandleType type) | ||||
@@ -13,6 +13,7 @@ | |||||
#include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
#include "src/naive/convolution/algorithms.h" | #include "src/naive/convolution/algorithms.h" | ||||
#include "src/naive/local_share/algorithms.h" | |||||
#include "src/naive/convolution3d/algorithms.h" | #include "src/naive/convolution3d/algorithms.h" | ||||
#include <functional> | #include <functional> | ||||
@@ -39,6 +40,11 @@ class HandleImpl : public HandleImplHelper { | |||||
m_default_conv3d_bwd_filter_algo; | m_default_conv3d_bwd_filter_algo; | ||||
static DefaultBatchConvBiasForwardAlgorithm | static DefaultBatchConvBiasForwardAlgorithm | ||||
m_default_batch_conv_bias_fwd_algo; | m_default_batch_conv_bias_fwd_algo; | ||||
static DefaultLocalShareForwardAlgorithm m_default_local_share_fwd_algo; | |||||
static DefaultLocalShareBackwardDataAlgorithm | |||||
m_default_local_share_bwd_data_algo; | |||||
static DefaultLocalShareBackwardFilterAlgorithm | |||||
m_default_local_share_bwd_filter_algo; | |||||
//! move KernFunc to alloc_kern()->func, destruct func, and call dispatch | //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch | ||||
template <typename T> | template <typename T> | ||||
@@ -91,6 +97,18 @@ public: | |||||
return &m_default_batch_conv_bias_fwd_algo; | return &m_default_batch_conv_bias_fwd_algo; | ||||
} | } | ||||
LocalShareForward::Algorithm* default_local_share_fwd_algo() { | |||||
return &m_default_local_share_fwd_algo; | |||||
} | |||||
LocalShareBackwardData::Algorithm* default_local_share_bwd_data_algo() { | |||||
return &m_default_local_share_bwd_data_algo; | |||||
} | |||||
LocalShareBackwardFilter::Algorithm* default_local_share_bwd_filter_algo() { | |||||
return &m_default_local_share_bwd_filter_algo; | |||||
} | |||||
Relayout* relayout_opr() override { | Relayout* relayout_opr() override { | ||||
return get_helper_opr<Relayout, 2>(this); | return get_helper_opr<Relayout, 2>(this); | ||||
} | } | ||||
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file dnn/src/naive/local_share/algorithms.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class DefaultLocalShareForwardAlgorithm final: | |||||
public megdnn::LocalShareForward::Algorithm { | |||||
bool is_reproducible() const override | |||||
{ return true; } | |||||
const char* name() const override | |||||
{ return "DEFAULT"; } | |||||
}; | |||||
class DefaultLocalShareBackwardDataAlgorithm final: | |||||
public megdnn::LocalShareBackwardData::Algorithm { | |||||
bool is_reproducible() const override | |||||
{ return true; } | |||||
const char* name() const override | |||||
{ return "DEFAULT"; } | |||||
}; | |||||
class DefaultLocalShareBackwardFilterAlgorithm final: | |||||
public megdnn::LocalShareBackwardFilter::Algorithm { | |||||
bool is_reproducible() const override | |||||
{ return true; } | |||||
const char* name() const override | |||||
{ return "DEFAULT"; } | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -152,4 +152,77 @@ void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||||
StrategyBwdFlt>(src, grad, diff, param()));); | StrategyBwdFlt>(src, grad, diff, param()));); | ||||
} | } | ||||
std::vector<LocalShareForward::Algorithm*> | |||||
LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | |||||
} | |||||
//! Pick an algorithm heuristically; the naive backend always returns the
//! default algorithm. If \p reproducible is set, assert that the chosen
//! algorithm is reproducible (the default one reports that it is).
//! NOTE(review): parameter comments fixed — the forward opr takes
//! (src, filter, dst), not (src, diff, grad); see the class declaration.
LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* filter */,
        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
        bool reproducible) {
    auto algo =
            static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
    if (reproducible) {
        // reject the heuristic choice if the caller demands reproducibility
        // and the algorithm cannot guarantee it
        megdnn_assert(algo->is_reproducible(),
                      "require reproducible algorithm, but heuristic "
                      "algorithm(%s) is not "
                      "reproducible",
                      algo->name());
    }
    return algo;
}
std::vector<LocalShareBackwardData::Algorithm*> | |||||
LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_local_share_bwd_data_algo()}; | |||||
} | |||||
//! Pick an algorithm heuristically; the naive backend always returns the
//! default algorithm. If \p reproducible is set, assert that the chosen
//! algorithm is reproducible (the default one reports that it is).
LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& /* filter */, const TensorLayout& /* diff */,
        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
        bool reproducible) {
    auto algo = static_cast<HandleImpl*>(handle())
                        ->default_local_share_bwd_data_algo();
    if (reproducible) {
        // reject the heuristic choice if the caller demands reproducibility
        // and the algorithm cannot guarantee it
        megdnn_assert(algo->is_reproducible(),
                      "require reproducible algorithm, but heuristic "
                      "algorithm(%s) is not "
                      "reproducible",
                      algo->name());
    }
    return algo;
}
std::vector<LocalShareBackwardFilter::Algorithm*> | |||||
LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_local_share_bwd_filter_algo()}; | |||||
} | |||||
//! Pick an algorithm heuristically; the naive backend always returns the
//! default algorithm. If \p reproducible is set, assert that the chosen
//! algorithm is reproducible (the default one reports that it is).
LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* diff */,
        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
        bool reproducible) {
    auto algo = static_cast<HandleImpl*>(handle())
                        ->default_local_share_bwd_filter_algo();
    if (reproducible) {
        // reject the heuristic choice if the caller demands reproducibility
        // and the algorithm cannot guarantee it
        megdnn_assert(algo->is_reproducible(),
                      "require reproducible algorithm, but heuristic "
                      "algorithm(%s) is not "
                      "reproducible",
                      algo->name());
    }
    return algo;
}
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -27,17 +27,13 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | ||||
const TensorLayout& /*dst*/) override { | |||||
return {}; | |||||
} | |||||
const TensorLayout& /*dst*/) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, | ||||
const TensorLayout& /*filter*/, | const TensorLayout& /*filter*/, | ||||
const TensorLayout& /*dst*/, | const TensorLayout& /*dst*/, | ||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override { | |||||
return nullptr; | |||||
} | |||||
bool /*reproducible*/) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
}; | }; | ||||
@@ -55,17 +51,13 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/) override { | |||||
return {}; | |||||
} | |||||
const TensorLayout& /*grad*/) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/, | ||||
const TensorLayout& /*diff*/, | const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/, | const TensorLayout& /*grad*/, | ||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override { | |||||
return nullptr; | |||||
} | |||||
bool /*reproducible*/) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
}; | }; | ||||
@@ -83,17 +75,13 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/) override { | |||||
return {}; | |||||
} | |||||
const TensorLayout& /*grad*/) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, | ||||
const TensorLayout& /*diff*/, | const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/, | const TensorLayout& /*grad*/, | ||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override { | |||||
return nullptr; | |||||
} | |||||
bool /*reproducible*/) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
}; | }; | ||||
@@ -14,6 +14,7 @@ | |||||
#include "megbrain/gopt/basic_arith.h" | #include "megbrain/gopt/basic_arith.h" | ||||
#include "megbrain/graph/event.h" | #include "megbrain/graph/event.h" | ||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "megbrain/opr/dnn/local.h" | |||||
#include "megbrain/utils/shared_set.h" | #include "megbrain/utils/shared_set.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
#include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
@@ -1358,23 +1359,28 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
return new_pooling_opr.node()->owner_opr(); | return new_pooling_opr.node()->owner_opr(); | ||||
}; | }; | ||||
auto relayout_inp_to_chw = [](OperatorNodeBase* opr, | |||||
auto var_to_chw = [](VarNode* inp, VarNode* new_inp) { | |||||
if (!inp->shape().eq_shape(new_inp->shape())) { | |||||
mgb_assert(inp->shape().ndim == 4 && | |||||
inp->format().type() != | |||||
TensorFormat::Type::IMAGE2D_PACK4); | |||||
mgb_assert(new_inp->shape().ndim == 5 && | |||||
new_inp->format().type() == | |||||
TensorFormat::Type::IMAGE2D_PACK4); | |||||
auto param = megdnn::param::RelayoutFormat(); | |||||
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; | |||||
auto rf = opr::RelayoutFormat::make(new_inp, param); | |||||
return rf.node(); | |||||
} | |||||
return new_inp; | |||||
}; | |||||
auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
VarNodeArray t_inp = new_inp; | VarNodeArray t_inp = new_inp; | ||||
for (size_t i = 0; i < opr->input().size(); i++) { | for (size_t i = 0; i < opr->input().size(); i++) { | ||||
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { | |||||
mgb_assert(opr->input(i)->shape().ndim == 4 && | |||||
opr->input(i)->format().type() != | |||||
TensorFormat::Type::IMAGE2D_PACK4); | |||||
mgb_assert(new_inp[i]->shape().ndim == 5 && | |||||
new_inp[i]->format().type() == | |||||
TensorFormat::Type::IMAGE2D_PACK4); | |||||
auto param = megdnn::param::RelayoutFormat(); | |||||
param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; | |||||
auto rf = opr::RelayoutFormat::make(new_inp[i], param); | |||||
t_inp[i] = rf.node(); | |||||
} | |||||
t_inp[i] = var_to_chw(opr->input(i), new_inp[i]); | |||||
} | } | ||||
auto new_opr = | auto new_opr = | ||||
serialization::copy_opr_shallow(*opr, t_inp, opr->config()); | serialization::copy_opr_shallow(*opr, t_inp, opr->config()); | ||||
@@ -1415,6 +1421,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
} | } | ||||
}; | }; | ||||
    /* Helper that relayouts only the first input (input 0) back to NCHW
     * and passes the remaining inputs through unchanged; used for
     * operators that do not support the NHWCD4 format on their data input.
     */
    auto relayout_first_inp_to_chw =
            [var_to_chw](OperatorNodeBase* opr,
                         const VarNodeArray& new_inp) -> OperatorNodeBase* {
        mgb_assert(opr->input().size() == new_inp.size());
        VarNodeArray t_inp = new_inp;
        // only input 0 is converted; var_to_chw is a no-op when the shapes
        // already match (i.e. the input was never converted to NHWCD4)
        t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
        return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
    };
auto ret = std::make_unique<ConvertFormatPass>(); | auto ret = std::make_unique<ConvertFormatPass>(); | ||||
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); | ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); | ||||
auto&& replace_func = ret->m_opr_replace_func; | auto&& replace_func = ret->m_opr_replace_func; | ||||
@@ -1436,6 +1454,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | replace_func[opr::WarpPerspectiveForward::typeinfo()] = | ||||
replace_warp_perspective_opr; | replace_warp_perspective_opr; | ||||
replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; | replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; | ||||
replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw; | |||||
replace_func[opr::GroupLocalForward::typeinfo()] = | |||||
relayout_first_inp_to_chw; | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -9,6 +9,7 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megbrain/opr/dnn/local.h" | |||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
#include "megbrain/gopt/inference.h" | #include "megbrain/gopt/inference.h" | ||||
@@ -919,6 +920,69 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | ||||
} | } | ||||
// Verify that the NHWCD4 conversion pass converts convolutions but keeps
// Local / GroupLocal oprs in NCHW, and that the optimized graph stays
// numerically equivalent to the original.
TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;

    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    // helper: build a named shared (const) weight tensor on device
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                .rename(name);
    };

    auto host_x = gen({2, 8, 8, 16}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);

    opr::Convolution::Param param;
    param.pad_h = param.pad_w = 1;
    // chain: conv1 -> Local -> conv2 -> GroupLocal -> conv, so both local
    // variants sit between convolutions that the pass converts to NHWCD4
    auto w1 = mkcvar("w1", {4, 8, 3, 3}),
         conv1 = opr::Convolution::make(x, w1, param);
    auto w2 = mkcvar("w2", {8, 16, 4, 3, 3, 4}),
         local = opr::Local::make(conv1, w2, param);
    auto w3 = mkcvar("w3", {4, 4, 3, 3}),
         conv2 = opr::Convolution::make(local, w3, param);
    opr::GroupLocal::Param param_group_local;
    param_group_local.pad_h = param_group_local.pad_w = 1;
    auto w4 = mkcvar("w4", {2, 8, 16, 2, 3, 3, 2}),
         group_local = opr::GroupLocal::make(conv2, w4, param_group_local);
    auto w5 = mkcvar("w5", {4, 4, 3, 3}),
         y = opr::Convolution::make(group_local, w5, param);

    SymbolVar y_opt;
    unpack_vector(
            gopt::optimize_for_inference(
                    {y},
                    gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()),
            y_opt);

    // convolutions should have been converted to NHWCD4 ...
    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
              find_opr<opr::Convolution>(y_opt).param().format);
    // ... while Local / GroupLocal must remain NCHW
    ASSERT_EQ(opr::Local::Param::Format::NCHW,
              find_opr<opr::Local>(y_opt).param().format);
    ASSERT_EQ(opr::GroupLocal::Param::Format::NCHW,
              find_opr<opr::GroupLocal>(y_opt).param().format);

    // dump the optimized graph for offline inspection
    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.ConvertFormatNHWCD4LOCAL.json"));

    // run both graphs and compare outputs numerically
    HostTensorND host_y_opt, host_y;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) { | TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) { | ||||
// hwcd4 is only supported in naive handle | // hwcd4 is only supported in naive handle | ||||
NaiveMegDNNHandleScope naive_megdnn_handle; | NaiveMegDNNHandleScope naive_megdnn_handle; | ||||