@@ -94,6 +94,11 @@ DefaultConvolution3DBackwardFilterAlgorithm
         HandleImpl::m_default_conv3d_bwd_filter_algo;
 DefaultBatchConvBiasForwardAlgorithm
         HandleImpl::m_default_batch_conv_bias_fwd_algo;
+DefaultLocalShareForwardAlgorithm HandleImpl::m_default_local_share_fwd_algo;
+DefaultLocalShareBackwardDataAlgorithm
+        HandleImpl::m_default_local_share_bwd_data_algo;
+DefaultLocalShareBackwardFilterAlgorithm
+        HandleImpl::m_default_local_share_bwd_filter_algo;
 
 HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle,
                        HandleType type)
@@ -13,6 +13,7 @@
 #include "src/common/handle_impl.h"
 #include "src/naive/convolution/algorithms.h"
+#include "src/naive/local_share/algorithms.h"
 #include "src/naive/convolution3d/algorithms.h"
 
 #include <functional>
@@ -39,6 +40,11 @@ class HandleImpl : public HandleImplHelper {
             m_default_conv3d_bwd_filter_algo;
     static DefaultBatchConvBiasForwardAlgorithm
             m_default_batch_conv_bias_fwd_algo;
+    static DefaultLocalShareForwardAlgorithm m_default_local_share_fwd_algo;
+    static DefaultLocalShareBackwardDataAlgorithm
+            m_default_local_share_bwd_data_algo;
+    static DefaultLocalShareBackwardFilterAlgorithm
+            m_default_local_share_bwd_filter_algo;
 
     //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch
     template <typename T>
@@ -91,6 +97,18 @@ public:
         return &m_default_batch_conv_bias_fwd_algo;
     }
+    LocalShareForward::Algorithm* default_local_share_fwd_algo() {
+        return &m_default_local_share_fwd_algo;
+    }
+    LocalShareBackwardData::Algorithm* default_local_share_bwd_data_algo() {
+        return &m_default_local_share_bwd_data_algo;
+    }
+    LocalShareBackwardFilter::Algorithm* default_local_share_bwd_filter_algo() {
+        return &m_default_local_share_bwd_filter_algo;
+    }
+
     Relayout* relayout_opr() override {
         return get_helper_opr<Relayout, 2>(this);
     }
@@ -0,0 +1,41 @@
+/**
+ * \file dnn/src/naive/local_share/algorithms.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace naive {
+
+class DefaultLocalShareForwardAlgorithm final:
+        public megdnn::LocalShareForward::Algorithm {
+    bool is_reproducible() const override
+    { return true; }
+    const char* name() const override
+    { return "DEFAULT"; }
+};
+class DefaultLocalShareBackwardDataAlgorithm final:
+        public megdnn::LocalShareBackwardData::Algorithm {
+    bool is_reproducible() const override
+    { return true; }
+    const char* name() const override
+    { return "DEFAULT"; }
+};
+class DefaultLocalShareBackwardFilterAlgorithm final:
+        public megdnn::LocalShareBackwardFilter::Algorithm {
+    bool is_reproducible() const override
+    { return true; }
+    const char* name() const override
+    { return "DEFAULT"; }
+};
+
+} // namespace naive
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
@@ -152,4 +152,77 @@ void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src,
                                     StrategyBwdFlt>(src, grad, diff, param()));); 
 }
 
+std::vector<LocalShareForward::Algorithm*>
+LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
+                                          const TensorLayout&,
+                                          const TensorLayout&) {
+    return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()};
+}
+
+LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
+        const TensorLayout& /* src */, const TensorLayout& /* filter */,
+        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
+        bool reproducible) {
+    auto algo =
+            static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
+    if (reproducible) {
+        megdnn_assert(algo->is_reproducible(),
+                      "require reproducible algorithm, but heuristic "
+                      "algorithm(%s) is not "
+                      "reproducible",
+                      algo->name());
+    }
+    return algo;
+}
+
+std::vector<LocalShareBackwardData::Algorithm*>
+LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&,
+                                               const TensorLayout&,
+                                               const TensorLayout&) {
+    return {static_cast<HandleImpl*>(handle())
+                    ->default_local_share_bwd_data_algo()};
+}
+
+LocalShareBackwardData::Algorithm*
+LocalShareBackwardDataImpl::get_algorithm_heuristic(
+        const TensorLayout& /* filter */, const TensorLayout& /* diff */,
+        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
+        bool reproducible) {
+    auto algo = static_cast<HandleImpl*>(handle())
+                        ->default_local_share_bwd_data_algo();
+    if (reproducible) {
+        megdnn_assert(algo->is_reproducible(),
+                      "require reproducible algorithm, but heuristic "
+                      "algorithm(%s) is not "
+                      "reproducible",
+                      algo->name());
+    }
+    return algo;
+}
+
+std::vector<LocalShareBackwardFilter::Algorithm*>
+LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
+                                                 const TensorLayout&,
+                                                 const TensorLayout&) {
+    return {static_cast<HandleImpl*>(handle())
+                    ->default_local_share_bwd_filter_algo()};
+}
+
+LocalShareBackwardFilter::Algorithm*
+LocalShareBackwardFilterImpl::get_algorithm_heuristic(
+        const TensorLayout& /* src */, const TensorLayout& /* diff */,
+        const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
+        bool reproducible) {
+    auto algo = static_cast<HandleImpl*>(handle())
+                        ->default_local_share_bwd_filter_algo();
+    if (reproducible) {
+        megdnn_assert(algo->is_reproducible(),
+                      "require reproducible algorithm, but heuristic "
+                      "algorithm(%s) is not "
+                      "reproducible",
+                      algo->name());
+    }
+    return algo;
+}
+
 // vim: syntax=cpp.doxygen
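For orientation, a minimal sketch (not part of the patch) of how the new entry points behave through the public megdnn operator interface: a LocalShareForward operator created on a naive handle now reports exactly one reproducible algorithm named "DEFAULT", and the heuristic hands back that same object. The helper name, the layouts and the `naive_handle` argument are placeholders supplied by the caller.

    #include <cassert>
    #include <cstring>
    #include <limits>
    #include "megdnn/handle.h"
    #include "megdnn/oprs.h"

    // Sketch: check the DEFAULT local-share algorithm on a naive megdnn handle.
    void check_default_local_share_algo(megdnn::Handle* naive_handle,
                                        const megdnn::TensorLayout& src,
                                        const megdnn::TensorLayout& filter,
                                        const megdnn::TensorLayout& dst) {
        auto opr = naive_handle->create_operator<megdnn::LocalShareForward>();
        // get_all_algorithms() returns the single shared DEFAULT object.
        auto algos = opr->get_all_algorithms(src, filter, dst);
        assert(algos.size() == 1);
        assert(algos[0]->is_reproducible());
        assert(std::strcmp(algos[0]->name(), "DEFAULT") == 0);
        // The heuristic returns the same object, even when reproducibility is
        // requested (the megdnn_assert in the patch only fires otherwise).
        auto algo = opr->get_algorithm_heuristic(
                src, filter, dst, std::numeric_limits<size_t>::max(),
                /* reproducible */ true);
        assert(algo == algos[0]);
    }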
@@ -27,17 +27,13 @@ public:
     std::vector<Algorithm*> get_all_algorithms(
             const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
-            const TensorLayout& /*dst*/) override {
-        return {};
-    }
+            const TensorLayout& /*dst*/) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/,
                                        const TensorLayout& /*filter*/,
                                        const TensorLayout& /*dst*/,
                                        size_t /*workspace_limit_in_bytes*/,
-                                       bool /*reproducible*/) override {
-        return nullptr;
-    }
+                                       bool /*reproducible*/) override;
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
 };
@@ -55,17 +51,13 @@ public:
     std::vector<Algorithm*> get_all_algorithms(
             const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
-            const TensorLayout& /*grad*/) override {
-        return {};
-    }
+            const TensorLayout& /*grad*/) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/,
                                        const TensorLayout& /*diff*/,
                                        const TensorLayout& /*grad*/,
                                        size_t /*workspace_limit_in_bytes*/,
-                                       bool /*reproducible*/) override {
-        return nullptr;
-    }
+                                       bool /*reproducible*/) override;
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
 };
@@ -83,17 +75,13 @@ public:
     std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
-            const TensorLayout& /*grad*/) override {
-        return {};
-    }
+            const TensorLayout& /*grad*/) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/,
                                        const TensorLayout& /*diff*/,
                                        const TensorLayout& /*grad*/,
                                        size_t /*workspace_limit_in_bytes*/,
-                                       bool /*reproducible*/) override {
-        return nullptr;
-    }
+                                       bool /*reproducible*/) override;
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
 };
@@ -14,6 +14,7 @@
 #include "megbrain/gopt/basic_arith.h"
 #include "megbrain/graph/event.h"
 #include "megbrain/opr/dnn/batch_norm.h"
+#include "megbrain/opr/dnn/local.h"
 #include "megbrain/utils/shared_set.h"
 #include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/opr/basic_arith.h"
@@ -1358,23 +1359,28 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         return new_pooling_opr.node()->owner_opr();
     };
 
-    auto relayout_inp_to_chw = [](OperatorNodeBase* opr,
+    auto var_to_chw = [](VarNode* inp, VarNode* new_inp) {
+        if (!inp->shape().eq_shape(new_inp->shape())) {
+            mgb_assert(inp->shape().ndim == 4 &&
+                       inp->format().type() !=
+                               TensorFormat::Type::IMAGE2D_PACK4);
+            mgb_assert(new_inp->shape().ndim == 5 &&
+                       new_inp->format().type() ==
+                               TensorFormat::Type::IMAGE2D_PACK4);
+            auto param = megdnn::param::RelayoutFormat();
+            param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
+            auto rf = opr::RelayoutFormat::make(new_inp, param);
+            return rf.node();
+        }
+        return new_inp;
+    };
+
+    auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         VarNodeArray t_inp = new_inp;
         for (size_t i = 0; i < opr->input().size(); i++) {
-            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
-                mgb_assert(opr->input(i)->shape().ndim == 4 &&
-                           opr->input(i)->format().type() !=
-                                   TensorFormat::Type::IMAGE2D_PACK4);
-                mgb_assert(new_inp[i]->shape().ndim == 5 &&
-                           new_inp[i]->format().type() ==
-                                   TensorFormat::Type::IMAGE2D_PACK4);
-                auto param = megdnn::param::RelayoutFormat();
-                param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW;
-                auto rf = opr::RelayoutFormat::make(new_inp[i], param);
-                t_inp[i] = rf.node();
-            }
+            t_inp[i] = var_to_chw(opr->input(i), new_inp[i]);
         }
         auto new_opr =
                 serialization::copy_opr_shallow(*opr, t_inp, opr->config());
@@ -1415,6 +1421,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         }
     };
 
+    /* This helper converts the first input back to the NCHW format, to
+     * handle operators that do not support the NHWCD4 format.
+     */
+    auto relayout_first_inp_to_chw =
+            [var_to_chw](OperatorNodeBase* opr,
+                         const VarNodeArray& new_inp) -> OperatorNodeBase* {
+        mgb_assert(opr->input().size() == new_inp.size());
+        VarNodeArray t_inp = new_inp;
+        t_inp[0] = var_to_chw(opr->input(0), new_inp[0]);
+        return serialization::copy_opr_shallow(*opr, t_inp, opr->config());
+    };
+
     auto ret = std::make_unique<ConvertFormatPass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto&& replace_func = ret->m_opr_replace_func;
@@ -1436,6 +1454,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
             replace_warp_perspective_opr;
     replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr;
+    replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw;
+    replace_func[opr::GroupLocalForward::typeinfo()] =
+            relayout_first_inp_to_chw;
     return ret;
 }
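The effect of the two new registrations, in a nutshell (illustrative pseudo-graph only; the variable names are invented): when an upstream convolution has already been rewritten to NHWCD4, the feature-map input of a Local/GroupLocal operator is relayouted back to NCHW on input 0 only, while the weight input and the operator itself stay in NCHW.

    // before the pass:  local = Local(conv_out /* NHWCD4, IMAGE2D_PACK4 */, weight)
    // after the pass:   x_chw = RelayoutFormat(conv_out, mode = NHWCD4I_NCHW)
    //                   local = Local(x_chw /* NCHW */, weight)  // weight untouched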
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
+#include "megbrain/opr/dnn/local.h"
 #include "megbrain/test/helper.h"
 
 #include "megbrain/gopt/inference.h"
@@ -919,6 +920,69 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
+    // hwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({2, 8, 8, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+
+    opr::Convolution::Param param;
+    param.pad_h = param.pad_w = 1;
+    auto w1 = mkcvar("w1", {4, 8, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param);
+
+    auto w2 = mkcvar("w2", {8, 16, 4, 3, 3, 4}),
+         local = opr::Local::make(conv1, w2, param);
+
+    auto w3 = mkcvar("w3", {4, 4, 3, 3}),
+         conv2 = opr::Convolution::make(local, w3, param);
+
+    opr::GroupLocal::Param param_group_local;
+    param_group_local.pad_h = param_group_local.pad_w = 1;
+    auto w4 = mkcvar("w4", {2, 8, 16, 2, 3, 3, 2}),
+         group_local = opr::GroupLocal::make(conv2, w4, param_group_local);
+
+    auto w5 = mkcvar("w5", {4, 4, 3, 3}),
+         y = opr::Convolution::make(group_local, w5, param);
+
+    SymbolVar y_opt;
+    unpack_vector(
+            gopt::optimize_for_inference(
+                    {y},
+                    gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()),
+            y_opt);
+
+    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
+              find_opr<opr::Convolution>(y_opt).param().format);
+    ASSERT_EQ(opr::Local::Param::Format::NCHW,
+              find_opr<opr::Local>(y_opt).param().format);
+    ASSERT_EQ(opr::GroupLocal::Param::Format::NCHW,
+              find_opr<opr::GroupLocal>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNHWCD4LOCAL.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) {
     // hwcd4 is only supported in naive handle
     NaiveMegDNNHandleScope naive_megdnn_handle;