GitOrigin-RevId: 51e025973f
release-1.2
@@ -43,6 +43,12 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
     Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
     Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
+    Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
+        'output tensor is nchw layout'),
+    Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
+        'output tensor is nchw4 layout, padding c=4'),
+    Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
+        'output tensor is nchw4 layout, padding c=4'),
     Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
         'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'))
 )
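
The three new enum values describe fused operations: input and output layouts differ, so one WarpPerspective pass performs the warp plus a layout conversion, and for the IC_SMALL variants also pads the channel dimension to 4. A standalone sketch (illustrative only, not MegDNN code) of the NCHW -> NCHW4 repacking with zero padding that these formats imply:

#include <cstddef>
#include <vector>

// Hypothetical sketch: repack an NCHW tensor with C <= 4 into NCHW4
// (shape N x 1 x H x W x 4), zero-padding the channel dim to 4, which is
// what the *_NCHW4_IC_SMALL formats describe for C in {1, 3}.
std::vector<float> nchw_to_nchw4_ic_small(const std::vector<float>& src,
                                          size_t N, size_t C, size_t H,
                                          size_t W) {
    std::vector<float> dst(N * H * W * 4, 0.f);  // padded channels stay 0
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w)
                    dst[((n * H + h) * W + w) * 4 + c] =
                            src[((n * C + c) * H + h) * W + w];
    return dst;
}
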
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */

 #include "megdnn/oprs.h"
@@ -14,20 +15,17 @@
 namespace megdnn {

-void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
-                                           const TensorLayout &mat,
-                                           const TensorLayout &mat_idx,
-                                           const TensorLayout &dst)
-{
+void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
+                                           const TensorLayout& mat,
+                                           const TensorLayout& mat_idx,
+                                           const TensorLayout& dst) {
     megdnn_assert_contiguous(mat);
     megdnn_assert_contiguous(src);
     megdnn_assert_contiguous(dst);
     auto errmsg = [&]() {
-        return megdnn_layout_msg(src) + ", " +
-               megdnn_layout_msg(mat) + ", " +
-               megdnn_layout_msg(mat_idx) + ", " +
-               megdnn_layout_msg(dst) + ", " +
-               param_msg();
+        return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " +
+               megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) +
+               ", " + param_msg();
     };
     MEGDNN_MARK_USED_VAR(errmsg);
     if (param().format == param::WarpPerspective::Format::NHWCD4 ||
@@ -35,9 +33,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
         megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
         megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
+    } else if (param().format ==
+                       param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
+               param().format ==
+                       param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
+        megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
+        megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
     } else {
         megdnn_assert(param().format == param::WarpPerspective::Format::NHWC ||
-                      param().format == param::WarpPerspective::Format::NCHW);
+                      param().format == param::WarpPerspective::Format::NCHW ||
+                      param().format ==
+                              param::WarpPerspective::Format::NHWC_NCHW);
         megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
         megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
     }
@@ -45,7 +51,7 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
     if (mat_idx.ndim) {
         megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1,
-                "%s", errmsg().c_str());
+                      "%s", errmsg().c_str());
         megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
         megdnn_assert_contiguous(mat_idx);
     } else {
@@ -54,35 +60,103 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
     megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());

-    if (param().format == param::WarpPerspective::Format::NCHW) {
-        megdnn_assert(
-                src.dtype.enumv() == DTypeEnum::Float32 ||
-                        MEGDNN_FLOAT16_SELECT(
-                                (src.dtype.enumv() == DTypeEnum::Float16 ||
-                                 src.dtype.enumv() == DTypeEnum::BFloat16),
-                                false) ||
-                        src.dtype.enumv() == DTypeEnum::Int8 ||
-                        src.dtype.enumv() == DTypeEnum::Uint8 ||
-                        (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                         src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
-                "WarpPerspective NCHW input dtype should be "
-                "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
-                        "/Float16/BFloat16", "") ".");
-        megdnn_assert(
-                (src.dtype.category() == DTypeCategory::FLOAT &&
-                 (src.dtype == mat.dtype ||
-                  mat.dtype.enumv() == DTypeEnum::Float32)) ||
-                        ((src.dtype.category() == DTypeCategory::INT ||
-                          src.dtype.category() == DTypeCategory::QUANTIZED) &&
-                         mat.dtype.enumv() == DTypeEnum::Float32),
-                "The input to WarpPerspective is in NCHW format, in this "
-                "case, if the input dtype is floating point, the "
-                "transformation matrix should have same dtype as the "
-                "input, otherwise, it should be in Float32, %s given.",
-                mat.dtype.name());
-        megdnn_assert(dst.dtype == src.dtype);
-        megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
+    if (src.format == dst.format && dst.dtype == src.dtype) {
+        if (param().format == param::WarpPerspective::Format::NCHW) {
+            megdnn_assert(
+                    src.dtype.enumv() == DTypeEnum::Float32 ||
+                            MEGDNN_FLOAT16_SELECT(
+                                    (src.dtype.enumv() == DTypeEnum::Float16 ||
+                                     src.dtype.enumv() == DTypeEnum::BFloat16),
+                                    false) ||
+                            src.dtype.enumv() == DTypeEnum::Int8 ||
+                            src.dtype.enumv() == DTypeEnum::Uint8 ||
+                            (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                             src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
+                    "WarpPerspective NCHW input dtype should be "
+                    "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
+                            "/Float16/BFloat16", "") ".");
+            megdnn_assert(
+                    (src.dtype.category() == DTypeCategory::FLOAT &&
+                     (src.dtype == mat.dtype ||
+                      mat.dtype.enumv() == DTypeEnum::Float32)) ||
+                            ((src.dtype.category() == DTypeCategory::INT ||
+                              src.dtype.category() ==
+                                      DTypeCategory::QUANTIZED) &&
+                             mat.dtype.enumv() == DTypeEnum::Float32),
+                    "The input to WarpPerspective is in NCHW format, in this "
+                    "case, if the input dtype is floating point, the "
+                    "transformation matrix should have same dtype as the "
+                    "input, otherwise, it should be in Float32, %s given.",
+                    mat.dtype.name());
+            megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
+            megdnn_assert(param().imode ==
+                          param::WarpPerspective::InterpolationMode::LINEAR);
+            megdnn_assert(param().bmode !=
+                          param::WarpPerspective::BorderMode::TRANSPARENT);
+            megdnn_assert(param().bmode !=
+                          param::WarpPerspective::BorderMode::ISOLATED);
+        } else if (param().format == param::WarpPerspective::Format::NHWC) {
+            megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
+        } else if (param().format == param::WarpPerspective::Format::NCHW4) {
+            megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
+                          "src expected QuantizedS8, but got %s",
+                          src.dtype.name());
+            megdnn_assert(mat.dtype == dtype::Float32(),
+                          "matrix dtype expected float, got %s",
+                          mat.dtype.name());
+            megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
+            megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
+            megdnn_assert(param().imode ==
+                          param::WarpPerspective::InterpolationMode::LINEAR);
+            megdnn_assert(param().bmode !=
+                          param::WarpPerspective::BorderMode::TRANSPARENT);
+            megdnn_assert(param().bmode !=
+                          param::WarpPerspective::BorderMode::ISOLATED);
+        } else {
+            megdnn_assert(param().format ==
+                          param::WarpPerspective::Format::NHWCD4);
+            megdnn_assert(
+                    src.dtype == dtype::Float32() ||
+                            MEGDNN_FLOAT16_SELECT(
+                                    (src.dtype == dtype::Float16() ||
+                                     src.dtype == dtype::BFloat16()),
+                                    false) ||
+                            src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                            src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
+                    "WarpPerspective NHWCD4 input dtype should be "
+                    "Float32" MEGDNN_FLOAT16_SELECT(
+                            "/Float16/BFloat16",
+                            "") ",QunatizedS8, Quantized8Asymm.");
+            megdnn_assert(
+                    (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
+                    "The input to WarpPerspective is in NHWCD4 format, in this "
+                    "case, if the input dtype is floating point, the "
+                    "transformation matrix should have same dtype as the "
+                    "input, %s given.",
+                    mat.dtype.name());
+            //! number of channels is same
+            megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
+            megdnn_assert(param().imode ==
+                          param::WarpPerspective::InterpolationMode::LINEAR);
+            megdnn_assert(param().bmode !=
+                          param::WarpPerspective::BorderMode::TRANSPARENT);
+            megdnn_assert(param().bmode !=
+                          param::WarpPerspective::BorderMode::ISOLATED);
+        }
+    } else if (param().format ==
+                       param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
+               param().format ==
+                       param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
+        megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
+                       src.dtype.enumv() == DTypeEnum::Uint8),
+                      "src expected Quantized8Asymm or Uint8, but got %s",
+                      src.dtype.name());
+        megdnn_assert(mat.dtype == dtype::Float32(),
+                      "matrix dtype expected float, got %s", mat.dtype.name());
+        megdnn_assert(dst.shape[4] == 4);
         megdnn_assert(param().imode ==
                       param::WarpPerspective::InterpolationMode::LINEAR);
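
The restructured check splits into two regimes: when src and dst share format and dtype, the original per-format checks apply unchanged (just nested one level deeper); otherwise the format must be one of the fused dimshuffle/typecvt variants, which take Uint8 or Quantized8Asymm input and emit QuantizedS8 (NCHW4 output) or Float32 (NCHW output). A hedged standalone sketch of that acceptance table (enum and function names are illustrative, not MegDNN API):

// Illustrative sketch of the dtype pairs the fused formats accept.
enum class Fused { NHWC_NCHW, NHWC_NCHW4_IC_SMALL, NCHW_NCHW4_IC_SMALL };
enum class DT { Uint8, Quantized8Asymm, QuantizedS8, Float32 };

bool fused_pair_ok(Fused f, DT src, DT dst) {
    bool src_ok = src == DT::Uint8 || src == DT::Quantized8Asymm;
    // *_NCHW4_IC_SMALL pack into a quantized NCHW4 tensor; NHWC_NCHW
    // dequantizes straight to Float32.
    bool dst_ok = (f == Fused::NHWC_NCHW) ? dst == DT::Float32
                                          : dst == DT::QuantizedS8;
    return src_ok && dst_ok;
}
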
@@ -90,16 +164,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
         megdnn_assert(param().bmode !=
                       param::WarpPerspective::BorderMode::TRANSPARENT);
         megdnn_assert(param().bmode !=
                       param::WarpPerspective::BorderMode::ISOLATED);
-    } else if (param().format == param::WarpPerspective::Format::NHWC) {
-        megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
-    } else if (param().format == param::WarpPerspective::Format::NCHW4) {
-        megdnn_assert(dst.dtype == src.dtype);
-        megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
-                      "src expected QuantizedS8, but got %s", src.dtype.name());
+    } else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) {
+        megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
+                       src.dtype.enumv() == DTypeEnum::Uint8),
+                      "src expected Quantized8Asymm or Uint8, but got %s",
+                      src.dtype.name());
         megdnn_assert(mat.dtype == dtype::Float32(),
                       "matrix dtype expected float, got %s", mat.dtype.name());
-        megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
-        megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
+        megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str());
         megdnn_assert(param().imode ==
                       param::WarpPerspective::InterpolationMode::LINEAR);
@@ -108,40 +180,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
         megdnn_assert(param().bmode !=
                       param::WarpPerspective::BorderMode::ISOLATED);
     } else {
-        megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
-        megdnn_assert(
-                src.dtype == dtype::Float32() ||
-                        MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
-                                               src.dtype == dtype::BFloat16()),
-                                              false) ||
-                        src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                        src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
-                "WarpPerspective NHWCD4 input dtype should be "
-                "Float32" MEGDNN_FLOAT16_SELECT(
-                        "/Float16/BFloat16",
-                        "") ",QunatizedS8, Quantized8Asymm.");
-        megdnn_assert(
-                (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
-                "The input to WarpPerspective is in NHWCD4 format, in this "
-                "case, if the input dtype is floating point, the "
-                "transformation matrix should have same dtype as the "
-                "input, %s given.",
-                mat.dtype.name());
-        megdnn_assert(dst.dtype == src.dtype);
-        //! number of channels is same
-        megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
-        megdnn_assert(param().imode ==
-                      param::WarpPerspective::InterpolationMode::LINEAR);
-        megdnn_assert(param().bmode !=
-                      param::WarpPerspective::BorderMode::TRANSPARENT);
-        megdnn_assert(param().bmode !=
-                      param::WarpPerspective::BorderMode::ISOLATED);
+        megdnn_assert(param().format == param::WarpPerspective::Format::NCHW);
+        megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
+                       src.dtype.enumv() == DTypeEnum::Uint8) &&
+                      dst.dtype.enumv() == DTypeEnum::Float32);
     }
+    megdnn_assert(src.format == dst.format);
 }

-std::string WarpPerspectiveBase::param_msg() const
-{
+std::string WarpPerspectiveBase::param_msg() const {
     std::string res;
     res.append(megdnn_mangle("imode="));
     switch (param().imode) {
@@ -191,31 +237,25 @@ std::string WarpPerspectiveBase::param_msg() const
     return res;
 }

-int WarpPerspectiveBase::get_real_coord(int p, int len)
-{
+int WarpPerspectiveBase::get_real_coord(int p, int len) {
     auto bmode = param().bmode;
-    if( (unsigned)p < (unsigned)len )
+    if ((unsigned)p < (unsigned)len)
         ;
-    else if( bmode == BorderMode::REPLICATE )
+    else if (bmode == BorderMode::REPLICATE)
         p = p < 0 ? 0 : len - 1;
-    else if( bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101 )
-    {
+    else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) {
         int delta = (bmode == BorderMode::REFLECT_101);
-        if( len == 1 )
+        if (len == 1)
             return 0;
-        do
-        {
-            if( p < 0 )
+        do {
+            if (p < 0)
                 p = -p - 1 + delta;
             else
                 p = len - 1 - (p - len) - delta;
-        }
-        while( (unsigned)p >= (unsigned)len );
-    }
-    else if( bmode == BorderMode::WRAP )
-    {
-        if( p < 0 )
-            p -= ((p-len+1)/len)*len;
+        } while ((unsigned)p >= (unsigned)len);
+    } else if (bmode == BorderMode::WRAP) {
+        if (p < 0)
+            p -= ((p - len + 1) / len) * len;
         /*
         if( p >= len )
             p %= len;
@@ -223,18 +263,16 @@ int WarpPerspectiveBase::get_real_coord(int p, int len)
         while (p >= len) {
             p -= len;
         }
-    }
-    else if( bmode == BorderMode::CONSTANT )
+    } else if (bmode == BorderMode::CONSTANT)
         p = -1;
     return p;
 }
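
For intuition: get_real_coord folds an out-of-range coordinate back into [0, len) according to the border mode, and REFLECT_101 (delta == 1 above) mirrors about the edge samples without repeating them. A standalone sketch reproducing just that folding, with a usage example (hypothetical free function, not the MegDNN API):

#include <cassert>

// Sketch of the REFLECT_101 case above: mirror about the first/last sample
// without repeating the edge. With delta = 1, p < 0 maps to -p and
// p >= len maps to 2 * len - 2 - p, repeated until p lands in range.
int reflect_101(int p, int len) {
    if (len == 1)
        return 0;
    while ((unsigned)p >= (unsigned)len) {
        if (p < 0)
            p = -p;                // -1 -> 1, -2 -> 2, ...
        else
            p = 2 * len - 2 - p;   // len -> len - 2, len + 1 -> len - 3, ...
    }
    return p;
}

int main() {
    assert(reflect_101(-2, 5) == 2);  // ... 2 1 | 0 1 2 3 4
    assert(reflect_101(5, 5) == 3);   // 0 1 2 3 4 | 3 2 ...
    assert(reflect_101(6, 5) == 2);
}
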
-void WarpPerspectiveForward::check_exec(const TensorLayout &src,
-                                        const TensorLayout &mat,
-                                        const TensorLayout &mat_idx,
-                                        const TensorLayout &dst,
-                                        size_t workspace_in_bytes)
-{
+void WarpPerspectiveForward::check_exec(const TensorLayout& src,
+                                        const TensorLayout& mat,
+                                        const TensorLayout& mat_idx,
+                                        const TensorLayout& dst,
+                                        size_t workspace_in_bytes) {
     check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
 }
@@ -248,7 +286,10 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
     if (param().format != Param::Format::NHWC &&
         param().format != Param::Format::NCHW &&
-        param().format != Param::Format::NCHW4) {
+        param().format != Param::Format::NCHW4 &&
+        param().format != Param::Format::NHWC_NCHW &&
+        param().format != Param::Format::NHWC_NCHW4_IC_SMALL &&
+        param().format != Param::Format::NCHW_NCHW4_IC_SMALL) {
         megdnn_assert(!mat_idx.ndim,
                       "mat_idx not supported for current format");
     }
@@ -263,7 +304,8 @@ void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
     megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
                           || grad.dtype == dtype::BFloat16()),
                   "Backward WarpPerspective only supports Float32/BFloat16.");
-    auto required_workspace_in_bytes = get_workspace_in_bytes(mat, mat_idx, diff, grad);
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(mat, mat_idx, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
@@ -283,6 +325,6 @@ void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }

-} // namespace megdnn
+}  // namespace megdnn

 // vim: syntax=cpp.doxygen
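
Note that check_exec_allow_nhwc_mat_idx now lets the fused formats carry a mat_idx tensor, which remaps each transformation matrix to a source batch entry instead of pairing them one-to-one. A hedged sketch of what mat_idx does (illustrative pseudo-kernel, not the MegDNN implementation):

#include <cstddef>

// With N matrices and N_SRC source images, batch n samples from
// src[mat_idx[n]] rather than src[n]. check_layout_fwd guarantees mat_idx
// is contiguous Int32 with mat.shape[0] entries.
void pick_src_batch(const int* mat_idx, size_t n, size_t n_src,
                    const float* src, size_t batch_stride,
                    const float** out_batch) {
    size_t idx = mat_idx ? (size_t)mat_idx[n] : n;
    *out_batch = (idx < n_src) ? src + idx * batch_stride : nullptr;
}
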
@@ -12,6 +12,7 @@
 #pragma once

 #include <cuda_runtime_api.h>
 #include "src/common/cv/enums.h"
+#include "src/cuda/utils.cuh"
 #include "megcore_cdefs.h"

 namespace megdnn {
@@ -34,6 +35,22 @@ void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx,
                          megcore::AsyncErrorInfo* error_info,
                          void* error_tracker, cudaStream_t stream);

+template <typename src_dtype, typename src_ctype, typename dst_ctype>
+void forward_proxy_quint8_dimshuffle_typecvt_nchw4(
+        bool is_nhwc, const src_ctype* src, const float* mat,
+        const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH,
+        int IW, int OH, int OW, src_ctype bval, DTypeParamImpl<src_dtype> param,
+        BorderMode bmode, megcore::AsyncErrorInfo* error_info,
+        void* error_tracker, cudaStream_t stream);
+
+template <typename src_dtype, typename src_ctype, typename dst_ctype>
+void forward_proxy_quint8_dimshuffle_typecvt_nchw(
+        bool is_nhwc, const src_ctype* src, const float* mat,
+        const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH,
+        int IW, int OH, int OW, src_ctype bval, DTypeParamImpl<src_dtype> param,
+        BorderMode bmode, megcore::AsyncErrorInfo* error_info,
+        void* error_tracker, cudaStream_t stream);
+
 void backward_data_proxy(const float* mat, const int* midx, const float* diff,
                          float* grad, float* workspace, int N, int N_SRC, int C,
                          int IH, int IW, int OH, int OW, float bval,
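
Both new proxies take a DTypeParamImpl<src_dtype> so the kernel can dequantize Quantized8Asymm input on the fly. For asymmetric uint8 quantization the real value is scale * (q - zero_point); a sketch of that mapping (a hypothetical stand-in type, not the real DTypeParamImpl):

#include <cstdint>

// Sketch of the asymmetric-uint8 dequantization the typecvt kernels apply.
struct QuantParamU8 {  // stand-in for DTypeParamImpl<dt_quint8>
    float scale;
    uint8_t zero_point;
    float dequantize(uint8_t q) const {
        return scale * (static_cast<float>(q) - zero_point);
    }
};
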
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */

 #include "src/cuda/warp_perspective/opr_impl.h"
 #include "src/cuda/warp_perspective/warp_perspective_cv.cuh"
@@ -166,6 +167,30 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
             IW = src.layout.shape[3];
             OH = dst.layout.shape[2];
             OW = dst.layout.shape[3];
+        } else if (param().format == Param::Format::NHWC_NCHW) {
+            C = src.layout.shape[3];
+            IH = src.layout.shape[1];
+            IW = src.layout.shape[2];
+            OH = dst.layout.shape[2];
+            OW = dst.layout.shape[3];
+        } else if (param().format == Param::Format::NHWC_NCHW4_IC_SMALL) {
+            C = src.layout.shape[3];
+            IH = src.layout.shape[1];
+            IW = src.layout.shape[2];
+            OH = dst.layout.shape[2];
+            OW = dst.layout.shape[3];
+            megdnn_assert(
+                    (C == 1) || (C == 3),
+                    "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
+        } else if (param().format == Param::Format::NCHW_NCHW4_IC_SMALL) {
+            C = src.layout.shape[1];
+            IH = src.layout.shape[2];
+            IW = src.layout.shape[3];
+            OH = dst.layout.shape[2];
+            OW = dst.layout.shape[3];
+            megdnn_assert(
+                    (C == 1) || (C == 3),
+                    "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
         } else {
             megdnn_assert(
                     param().format == param::WarpPerspective::Format::NCHW,
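
All the fused formats read H/W from the source's own layout but take OH/OW from positions 2/3 of the NCHW(4)-style destination, and C counts logical input channels (restricted to 1 or 3); the padding to 4 happens inside the kernel. A tiny hedged sketch of the resulting destination shape (hypothetical helper):

#include <array>
#include <cstddef>

// For a *_NCHW4_IC_SMALL input of logical shape (N, C, IH, IW) with
// C in {1, 3}, the destination layout is (N, 1, OH, OW, 4) regardless of C.
std::array<std::size_t, 5> ic_small_dst_shape(std::size_t N, std::size_t OH,
                                              std::size_t OW) {
    return {N, 1, OH, OW, 4};
}
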
@@ -180,55 +205,123 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
                     "unsupported interpolation mode for NCHW format");
             auto bval = param().border_val;
             auto bmode = warp_perspective::get_bmode(param().bmode);

-            if (src.layout.dtype == dtype::Float32{}) {
-                warp_perspective::forward_proxy(
-                        is_nhwc, src.ptr<dt_float32>(), mat.ptr<dt_float32>(),
-                        mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
-                        dst.ptr<dt_float32>(), src.layout[0], mat.layout[0], C,
-                        IH, IW, OH, OW, bval, bmode, async_error_info(handle()),
-                        m_error_tracker, stream);
-            } else if (MEGDNN_FLOAT16_SELECT(
-                               src.layout.dtype == dtype::Float16(), false)) {
+            if (src.layout.dtype == dst.layout.dtype) {
+                if (src.layout.dtype == dtype::Float32{}) {
+                    warp_perspective::forward_proxy(
+                            is_nhwc, src.ptr<dt_float32>(),
+                            mat.ptr<dt_float32>(),
+                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
+                            dst.ptr<dt_float32>(), src.layout[0], mat.layout[0],
+                            C, IH, IW, OH, OW, bval, bmode,
+                            async_error_info(handle()), m_error_tracker,
+                            stream);
+                } else if (MEGDNN_FLOAT16_SELECT(
+                                   src.layout.dtype == dtype::Float16(),
+                                   false)) {
 #ifndef MEGDNN_DISABLE_FLOAT16
-                warp_perspective::forward_proxy(
-                        is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(),
-                        mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
-                        dst.ptr<dt_float16>(), src.layout[0], mat.layout[0], C,
-                        IH, IW, OH, OW, static_cast<dt_float16>(bval), bmode,
-                        async_error_info(handle()), m_error_tracker, stream);
+                    warp_perspective::forward_proxy(
+                            is_nhwc, src.ptr<dt_float16>(),
+                            mat.ptr<dt_float32>(),
+                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
+                            dst.ptr<dt_float16>(), src.layout[0], mat.layout[0],
+                            C, IH, IW, OH, OW, static_cast<dt_float16>(bval),
+                            bmode, async_error_info(handle()), m_error_tracker,
+                            stream);
 #endif
-            } else if (src.layout.dtype == dtype::Uint8()) {
-                warp_perspective::forward_proxy<dt_uint8>(
-                        is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
-                        mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
-                        dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0], C,
-                        IH, IW, OH, OW, bval, bmode, async_error_info(handle()),
-                        m_error_tracker, stream);
-            } else if (src.layout.dtype == dtype::Int8()) {
-                megdnn_assert(
-                        !is_nhwc,
-                        "WarpPerspective on CUDA does not support NHWC + Int8");
-                warp_perspective::forward_proxy<dt_int8>(
-                        false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
-                        mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
-                        dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C, IH,
-                        IW, OH, OW,
-                        bval /* implicit float -> int8 conversion, should be
-                                safe */
-                        ,
-                        bmode, async_error_info(handle()), m_error_tracker,
-                        stream);
-            } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
-                megdnn_assert(param().format == Param::Format::NCHW4,
-                              "WarpPerspective on CUDA supports NCHW4 + "
-                              "QuantizedS8 only");
-                warp_perspective::forward_proxy_nchw4<dt_int8>(
-                        src.compatible_ptr<dt_int8>(), mat.ptr<dt_float32>(),
-                        mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
-                        dst.compatible_ptr<dt_int8>(), src.layout[0],
-                        mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
-                        async_error_info(handle()), m_error_tracker, stream);
+                } else if (src.layout.dtype == dtype::Uint8()) {
+                    warp_perspective::forward_proxy<dt_uint8>(
+                            is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
+                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
+                            dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0],
+                            C, IH, IW, OH, OW, bval, bmode,
+                            async_error_info(handle()), m_error_tracker,
+                            stream);
+                } else if (src.layout.dtype == dtype::Int8()) {
+                    megdnn_assert(!is_nhwc,
+                                  "WarpPerspective on CUDA does not support "
+                                  "NHWC + Int8");
+                    warp_perspective::forward_proxy<dt_int8>(
+                            false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
+                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
+                            dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C,
+                            IH, IW, OH, OW,
+                            bval /* implicit float -> int8 conversion,
+                                    should be safe */
+                            ,
+                            bmode, async_error_info(handle()), m_error_tracker,
+                            stream);
+                } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
+                    megdnn_assert(param().format == Param::Format::NCHW4,
+                                  "WarpPerspective on CUDA supports NCHW4 + "
+                                  "QuantizedS8 only");
+                    warp_perspective::forward_proxy_nchw4<dt_int8>(
+                            src.compatible_ptr<dt_int8>(),
+                            mat.ptr<dt_float32>(),
+                            mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
+                            dst.compatible_ptr<dt_int8>(), src.layout[0],
+                            mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
+                            async_error_info(handle()), m_error_tracker,
+                            stream);
+                }
+            } else if ((src.layout.dtype.enumv() ==
+                                DTypeEnum::Quantized8Asymm ||
+                        src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
+                uint8_t zero_point = 0;
+                float scale = 1.f;
+                if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
+                    zero_point =
+                            src.layout.dtype.param<dtype::Quantized8Asymm>()
+                                    .zero_point;
+                    scale = src.layout.dtype.param<dtype::Quantized8Asymm>()
+                                    .scale;
+                } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
+                           dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
+                    zero_point = 128;
+                    scale = 1.f;
+                }
+                DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);
+
+                if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                     dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
+                             scale) &&
+                    ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
+                     (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
+                    bool is_nhwc_ic_small =
+                            (param().format ==
+                             Param::Format::NHWC_NCHW4_IC_SMALL);
+                    warp_perspective::
+                            forward_proxy_quint8_dimshuffle_typecvt_nchw4<
+                                    dt_quint8, dt_uint8, dt_int8>(
+                                    is_nhwc_ic_small,
+                                    src.compatible_ptr<dt_uint8>(),
+                                    mat.ptr<dt_float32>(),
+                                    mat_idx.raw_ptr ? mat_idx.ptr<int>()
+                                                    : nullptr,
+                                    dst.compatible_ptr<dt_int8>(),
+                                    src.layout[0], mat.layout[0], C, IH, IW, OH,
+                                    OW, bval, src_dtype_param, bmode,
+                                    async_error_info(handle()), m_error_tracker,
+                                    stream);
+                } else {
+                    megdnn_assert(
+                            ((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
+                             ((param().format == Param::Format::NCHW) ||
+                              (param().format == Param::Format::NHWC_NCHW))),
+                            "invalid format for Quantized8Asymm input");
+                    bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);
+                    warp_perspective::
+                            forward_proxy_quint8_dimshuffle_typecvt_nchw<
+                                    dt_quint8, dt_uint8, dt_float32>(
+                                    is_nhwc, src.compatible_ptr<dt_uint8>(),
+                                    mat.ptr<dt_float32>(),
+                                    mat_idx.raw_ptr ? mat_idx.ptr<int>()
+                                                    : nullptr,
+                                    dst.compatible_ptr<dt_float32>(),
+                                    src.layout[0], mat.layout[0], C, IH, IW, OH,
+                                    OW, bval, src_dtype_param, bmode,
+                                    async_error_info(handle()), m_error_tracker,
+                                    stream);
+                }
             } else {
                 megdnn_throw(ssprintf("unsupported dtype: %s",
                                       src.layout.dtype.name()));
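
When plain Uint8 input is paired with a QuantizedS8 output, the dispatch above treats it as Quantized8Asymm with zero_point 128 and scale 1, so the conversion is a pure shift of the representable range. A sketch of that mapping:

#include <cstdint>

// With zero_point = 128 and scale = 1, converting Uint8 to a QuantizedS8
// value of the same scale is just a shift by 128: 0..255 -> -128..127.
int8_t u8_to_qs8_shift128(uint8_t v) {
    return static_cast<int8_t>(static_cast<int>(v) - 128);
}
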
@@ -249,6 +249,162 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4(
     MIDOUT_END();
 }

+template <typename ctype, typename dst_ctype, typename mtype>
+void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt(
+        const KernParam<ctype, mtype>& kern_param, size_t task_id) {
+    MEGDNN_MARK_USED_VAR(kern_param);
+    MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(2)) {
+        UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
+        MEGDNN_MARK_USED_VAR(N_MAT);
+        //! strides of C, H, W on src and dst
+        size_t sstrd[3], dstrd[3];
+        auto set_sstrd = [&](size_t s0, size_t s1, size_t s2) {
+            sstrd[0] = s0;
+            sstrd[1] = s1;
+            sstrd[2] = s2;
+        };
+        auto set_dstrd = [&](size_t s0, size_t s1, size_t s2) {
+            dstrd[0] = s0;
+            dstrd[1] = s1;
+            dstrd[2] = s2;
+        };
+        switch (kern_param.format) {
+            case Format::NCHW:
+            case Format::NCHW_NCHW4_IC_SMALL:
+                set_sstrd(IH * IW, IW, 1);
+                set_dstrd(OH * OW, OW, 1);
+                break;
+            case Format::NHWC_NCHW:
+            case Format::NHWC_NCHW4_IC_SMALL:
+                set_sstrd(1, IW * C, C);
+                set_dstrd(OH * OW, OW, 1);
+                break;
+            default:
+                megdnn_throw("bad format");
+        }
+
+        uint8_t zero_point = 0;
+        float scale = 1.f;
+        bool is_dst_float = kern_param.dst_dtype.enumv() == DTypeEnum::Float32;
+
+        if (kern_param.src_dtype.enumv() ==
+            DTypeTrait<dtype::Quantized8Asymm>::enumv) {
+            auto dtype_param =
+                    kern_param.src_dtype
+                            .template param<dtype::Quantized8Asymm>();
+            zero_point = dtype_param.zero_point;
+            scale = dtype_param.scale;
+        } else if (kern_param.src_dtype.enumv() == DTypeEnum::Uint8) {
+            zero_point =
+                    (kern_param.dst_dtype.enumv() == DTypeEnum::QuantizedS8)
+                            ? 128
+                            : 0;
+            scale = 1.f;
+        }
+
+        dst_ctype* dst_ptr = reinterpret_cast<dst_ctype*>(dptr);
+        bool is_dst_nchw4 =
+                (kern_param.format == Format::NCHW_NCHW4_IC_SMALL) ||
+                (kern_param.format == Format::NHWC_NCHW4_IC_SMALL);
+
+        auto visit_src = [&sptr, sstrd](size_t c, int h, int w) -> float {
+            return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
+        };
+        auto visit_src_bd = [&sptr, sstrd, border_val](size_t c, int h,
+                                                       int w) -> float {
+            if (h != -1 && w != -1) {
+                return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
+            } else
+                return border_val;
+        };
+        auto visit_dst = [&dst_ptr, dstrd, is_dst_nchw4](size_t c, int h,
+                                                         int w) -> dst_ctype& {
+            if (!is_dst_nchw4)
+                return dst_ptr[dstrd[0] * c + dstrd[1] * h + dstrd[2] * w];
+            else
+                return dst_ptr[((dstrd[0] * (c >> 2) + dstrd[1] * h +
+                                 dstrd[2] * w)
+                                << 2) +
+                               (c & 0b11)];
+        };
+
+        rounding::RoundingConverter<dst_ctype> output_converter;
+        auto orig_sptr = sptr;
+        size_t n = task_id / OH;
+        size_t oh = task_id % OH;
+        mptr = mptr + n * 3 * 3;
+        dst_ptr = is_dst_nchw4 ? (dst_ptr + n * OH * OW * 4)
+                               : (dst_ptr + n * C * OH * OW);
+        if (midx_ptr) {
+            size_t idx = midx_ptr[n];
+            megdnn_assert(
+                    idx < N_SRC,
+                    "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", n,
+                    idx, N_SRC);
+            sptr = orig_sptr + idx * (C * IH * IW);
+        } else if (n) {
+            sptr += n * C * IH * IW;
+        }
+        rep(ow, OW) {
+            float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2];
+            float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5];
+            float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8];
+            float alphaw = numeratorw / denominator;
+            float alphah = numeratorh / denominator;
+
+            int iw0 = get_real_coord(std::floor(alphaw) + 0, IW);
+            int iw1 = get_real_coord(std::floor(alphaw) + 1, IW);
+            int ih0 = get_real_coord(std::floor(alphah) + 0, IH);
+            int ih1 = get_real_coord(std::floor(alphah) + 1, IH);
+
+            alphaw -= floor(alphaw);
+            alphah -= floor(alphah);
+            if (bmode != BorderMode::CONSTANT) {
+                rep(c, C) {
+                    auto val =
+                            visit_src(c, ih0, iw0) * (1.0f - alphaw) *
+                                    (1.0f - alphah) +
+                            visit_src(c, ih0, iw1) * alphaw * (1.0f - alphah) +
+                            visit_src(c, ih1, iw0) * (1.0f - alphaw) * alphah +
+                            visit_src(c, ih1, iw1) * alphaw * alphah;
+                    val = is_dst_float ? (val - zero_point) * scale
+                                       : val - zero_point;
+                    visit_dst(c, oh, ow) = output_converter(val);
+                }
+            } else {
+                rep(c, C) {
+                    auto val = visit_src_bd(c, ih0, iw0) * (1.0f - alphaw) *
+                                       (1.0f - alphah) +
+                               visit_src_bd(c, ih0, iw1) * alphaw *
+                                       (1.0f - alphah) +
+                               visit_src_bd(c, ih1, iw0) * (1.0f - alphaw) *
+                                       alphah +
+                               visit_src_bd(c, ih1, iw1) * alphaw * alphah;
+                    val = std::isfinite(val) ? val : border_val;
+                    val = is_dst_float ? (val - zero_point) * scale
+                                       : val - zero_point;
+                    visit_dst(c, oh, ow) = output_converter(val);
+                }
+            }
+            if (is_dst_nchw4) {
+                for (auto c = C; c < 4; ++c) {
+                    visit_dst(c, oh, ow) = 0;
+                }
+            }
+        }
+    }
+    MIDOUT_END();
+}
+
+#define INST(ctype, drc_ctype, mtype)                                        \
+    template void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt< \
+            ctype, drc_ctype, mtype>(const KernParam<ctype, mtype>&, size_t);
+
+INST(uint8_t, int8_t, float);
+INST(uint8_t, float, float);
+
+#undef INST
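
Each output pixel (ow, oh) is mapped through the 3x3 homography; the homogeneous divide gives fractional source coordinates, and the four neighboring pixels are blended bilinearly. A standalone sketch of that per-pixel coordinate math (hypothetical helper mirroring the kernel above, before the border fixup that get_real_coord applies):

#include <cmath>

// Apply the homography to (ow, oh), then derive the 4-tap bilinear weights.
struct Sample {
    int ih0, ih1, iw0, iw1;    // integer neighbors (pre border fixup)
    float w00, w01, w10, w11;  // bilinear weights, sum to 1
};

Sample project(const float m[9], int ow, int oh) {
    float x = m[0] * ow + m[1] * oh + m[2];
    float y = m[3] * ow + m[4] * oh + m[5];
    float d = m[6] * ow + m[7] * oh + m[8];
    float aw = x / d, ah = y / d;  // fractional source coordinates
    Sample s;
    s.iw0 = (int)std::floor(aw);
    s.iw1 = s.iw0 + 1;
    s.ih0 = (int)std::floor(ah);
    s.ih1 = s.ih0 + 1;
    aw -= std::floor(aw);
    ah -= std::floor(ah);
    s.w00 = (1 - aw) * (1 - ah);  // weight for (ih0, iw0)
    s.w01 = aw * (1 - ah);        // (ih0, iw1)
    s.w10 = (1 - aw) * ah;        // (ih1, iw0)
    s.w11 = aw * ah;              // (ih1, iw1)
    return s;
}
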
 void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
                                       _megdnn_tensor_in mat,
                                       _megdnn_tensor_in mat_idx,
@@ -320,6 +476,65 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
                                   src.layout.dtype.name())
                                  .c_str());
     }
+
+    bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv();
+    bool is_u8_or_qu8_in =
+            src.layout.dtype.enumv() == DTypeTrait<dtype::Uint8>::enumv ||
+            src.layout.dtype.enumv() ==
+                    DTypeTrait<dtype::Quantized8Asymm>::enumv;
+    if (is_fusion_dtype && is_u8_or_qu8_in &&
+        ((param().format == Format::NCHW_NCHW4_IC_SMALL) ||
+         (param().format == Format::NHWC_NCHW4_IC_SMALL) ||
+         (param().format == Format::NHWC_NCHW) ||
+         (param().format == Format::NCHW))) {
+        if (src.layout.dtype.enumv() ==
+                    DTypeTrait<dtype::Quantized8Asymm>::enumv ||
+            src.layout.dtype.enumv() == DTypeTrait<dtype::Uint8>::enumv) {
+            float scale = 1.f;
+
+            if (src.layout.dtype.enumv() ==
+                DTypeTrait<dtype::Quantized8Asymm>::enumv) {
+                scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale;
+            }
+
+            auto kparam = KernParam<uint8_t, float>::from_tensors(
+                    param().format, param().bmode, param().border_val, src, mat,
+                    mat_idx, dst, workspace);
+
+            if (dst.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) {
+                auto run = [kparam, this](size_t index, size_t) {
+                    kern_naive_dimshuffle_typecvt<uint8_t, float, float>(kparam,
+                                                                         index);
+                };
+                MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run,
+                                                          kparam.oh * batch);
+                return;
+            } else if ((dst.layout.dtype.enumv() ==
+                        DTypeTrait<dtype::QuantizedS8>::enumv) &&
+                       (dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
+                        scale)) {
+                auto run = [kparam, this](size_t index, size_t) {
+                    kern_naive_dimshuffle_typecvt<uint8_t, int8_t, float>(
+                            kparam, index);
+                };
+                MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run,
+                                                          kparam.oh * batch);
+                return;
+            } else {
+                megdnn_throw(ssprintf("Unsupported DType in "
+                                      "WarpPerspective Dimshuffle Typecvt: %s",
+                                      src.layout.dtype.name())
+                                     .c_str());
+            }
+        }
+        megdnn_throw(ssprintf("Unsupported input DType in "
+                              "WarpPerspective: %s",
+                              src.layout.dtype.name())
+                             .c_str());
+    }
+
     if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
                               param().format)) {
         MIDOUT_BEGIN(megdnn_naive_warpperspective, void) {
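
The naive backend splits the fused kernel into one task per output row: kparam.oh * batch tasks in total, with task_id decomposed inside the kernel as n = task_id / OH and oh = task_id % OH. A hedged sketch of that decomposition (illustrative serial loop, not the real dispatcher):

#include <cstddef>

// One task per (batch, output row) pair, matching the n/oh recovery in
// kern_naive_dimshuffle_typecvt.
void for_each_task(std::size_t batch, std::size_t OH,
                   void (*kern)(std::size_t n, std::size_t oh)) {
    for (std::size_t task_id = 0; task_id < batch * OH; ++task_id)
        kern(task_id / OH, task_id % OH);  // n, oh
}
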
@@ -331,12 +546,12 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
     megdnn_assert(warp::is_dnn_available(src.layout, mat.layout, dst.layout,
                                          param().imode, param().format));
     /*!
-     * We currently use floating point for all WarpPerspective computation,
-     * so even if the input ctype is one of the integer type, mtype should
-     * always be float32.
+     * We currently use floating point for all WarpPerspective
+     * computation, so even if the input ctype is one of the integer
+     * type, mtype should always be float32.
      *
-     * \warning It's different with \c WarpAffine, with mtype be float16 if
-     * input type is float16.
+     * \warning It's different with \c WarpAffine, with mtype be float16
+     * if input type is float16.
      */
     DISPATCH_ST(dtype::Float32, float, float, KERN);
@@ -26,6 +26,7 @@ protected:
         float border_val;
         size_t n_src, n_mat, c, ih, iw, oh, ow;
         ctype *sptr, *dptr;
+        DType src_dtype, dst_dtype;
        mtype* mptr;
        int* midx_ptr;  //!< can be null
        Workspace workspace;
@@ -41,6 +42,8 @@ protected:
             ret.bmode = bmode;
             ret.border_val = border_val;
             ret.n_src = src.layout.shape[0];
+            ret.src_dtype = src.layout.dtype;
+            ret.dst_dtype = dst.layout.dtype;
             if (mat_idx.raw_ptr) {
                 megdnn_assert(mat_idx.layout.ndim == 1);
                 ret.n_mat = mat_idx.layout.shape[0];
@@ -50,7 +53,8 @@ protected:
                 ret.n_mat = ret.n_src;
                 ret.midx_ptr = nullptr;
             }
-            if (format == Format::NCHW) {
+            if (format == Format::NCHW ||
+                format == Format::NCHW_NCHW4_IC_SMALL) {
                 ret.c = src.layout.shape[1];
                 ret.ih = src.layout.shape[2];
                 ret.iw = src.layout.shape[3];
@@ -62,6 +66,13 @@ protected:
                 ret.iw = src.layout.shape[2];
                 ret.oh = dst.layout.shape[1];
                 ret.ow = dst.layout.shape[2];
+            } else if (format == Format::NHWC_NCHW ||
+                       format == Format::NHWC_NCHW4_IC_SMALL) {
+                ret.c = src.layout.shape[3];
+                ret.ih = src.layout.shape[1];
+                ret.iw = src.layout.shape[2];
+                ret.oh = dst.layout.shape[2];
+                ret.ow = dst.layout.shape[3];
             } else if (format == Format::NCHW4) {
                 ret.c = src.layout.shape[1] * 4;
                 ret.ih = src.layout.shape[2];
@@ -76,15 +87,16 @@ protected:
                 ret.oh = dst.layout.shape[1];
                 ret.ow = dst.layout.shape[3];
             }
-            if (src.layout.dtype.enumv() == DTypeEnum::Float32 ||
-                MEGDNN_FLOAT16_SELECT(
-                        (src.layout.dtype.enumv() == DTypeEnum::Float16 ||
-                         src.layout.dtype.enumv() == DTypeEnum::BFloat16),
-                        false) ||
-                src.layout.dtype.enumv() == DTypeEnum::Int8 ||
-                src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
-                src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
+            if ((src.layout.dtype.enumv() == DTypeEnum::Float32 ||
+                 MEGDNN_FLOAT16_SELECT(
+                         (src.layout.dtype.enumv() == DTypeEnum::Float16 ||
+                          src.layout.dtype.enumv() == DTypeEnum::BFloat16),
+                         false) ||
+                 src.layout.dtype.enumv() == DTypeEnum::Int8 ||
+                 src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
+                 src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                 src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) &&
+                (src.layout.dtype == dst.layout.dtype)) {
                 ret.sptr = src.compatible_ptr<ctype>();
                 ret.mptr = mat.ptr<mtype>();
                 ret.dptr = dst.compatible_ptr<ctype>();
@@ -92,6 +104,13 @@ protected:
                 ret.sptr = src.compatible_ptr<ctype>();
                 ret.mptr = mat.ptr<mtype>();
                 ret.dptr = dst.compatible_ptr<ctype>();
+            } else if ((src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
+                        src.layout.dtype.enumv() ==
+                                DTypeEnum::Quantized8Asymm) &&
+                       src.layout.dtype.enumv() != dst.layout.dtype.enumv()) {
+                ret.sptr = src.compatible_ptr<ctype>();
+                ret.mptr = mat.ptr<mtype>();
+                ret.dptr = reinterpret_cast<ctype*>(dst.raw_ptr);
             } else {
                 ret.sptr = nullptr;
                 ret.mptr = nullptr;
@@ -122,6 +141,9 @@ private:
     template <typename ctype, typename mtype>
     void kern_naive_nhwcd4(const KernParam<ctype, mtype>& kern_param,
                            size_t task_id);
+    template <typename ctype, typename dst_ctype, typename mtype>
+    void kern_naive_dimshuffle_typecvt(
+            const KernParam<ctype, mtype>& kern_param, size_t task_id);
 };

 class WarpPerspectiveBackwardDataImpl : public WarpPerspectiveBackwardData {
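
KernParam now records the source and destination DType so the kernel can pick zero_point/scale at run time, and when the dtypes differ, dptr is deliberately stored through a reinterpret_cast because ctype (the source element type) no longer matches the destination element type. A sketch of that idea under those assumptions (hypothetical names, uint8 source and int8 destination):

#include <cstddef>
#include <cstdint>

// A param block whose dptr is only meaningful once re-cast to the real
// destination element type inside the kernel.
struct ParamSketch {
    uint8_t* sptr;  // source elements, matches ctype
    void* dptr;     // type-erased destination, cast inside the kernel
};

void kernel(const ParamSketch& p, std::size_t i) {
    auto* dst = static_cast<int8_t*>(p.dptr);  // dst_ctype chosen by caller
    dst[i] = static_cast<int8_t>(static_cast<int>(p.sptr[i]) - 128);
}
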
@@ -23,8 +23,7 @@ using namespace megdnn;
 using namespace test;

 class NanMatRNG : public RNG {
-    void gen(const TensorND& tensor_) override
-    {
+    void gen(const TensorND& tensor_) override {
         auto& gen = RandomState::generator();
         std::uniform_real_distribution<dt_float32> pdist3(1.9f, 2.1f);
         std::uniform_real_distribution<dt_float32> pdist(0.9f, 1.1f);
@@ -335,6 +334,144 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW4) {
     }
 }

+TEST_F(CUDA, WARP_PERSPECTIVE_NCHW_NCHW4_IC_SMALL) {
+    using Param = WarpPerspective::Param;
+    WarpPerspective::Param param;
+    Checker<WarpPerspectiveForward> checker(handle_cuda());
+    WarpPerspectiveMatRNG rng;
+    param.format = Param::Format::NCHW_NCHW4_IC_SMALL;
+    checker.set_rng(1, &rng);
+    checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128));
+    checker.set_dtype(2, dtype::QuantizedS8(0.1f));
+    for (auto bmode : {WarpPerspective::BorderMode::WRAP,
+                       WarpPerspective::BorderMode::REFLECT,
+                       WarpPerspective::BorderMode::REPLICATE,
+                       WarpPerspective::BorderMode::CONSTANT}) {
+        param.border_val = 0.3f;
+        param.bmode = bmode;
+        param.imode = Param::InterpolationMode::LINEAR;
+
+        checker.set_param(param);
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{2, 3, 10, 11}, {2, 3, 3}, {2, 1, 11, 12, 4}});
+        checker.execs({{1, 3, 25, 510}, {1, 3, 3}, {1, 1, 25, 25, 4}});
+        checker.execs({{1, 3, 25, 25}, {1, 3, 3}, {1, 1, 51, 51, 4}});
+        checker.execs({{1, 3, 51, 51}, {1, 3, 3}, {1, 1, 25, 25, 4}});
+    }
+    {
+        Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
+                handle_cuda());
+        constexpr int N_SRC = 5;
+        UniformIntRNG mat_idx_rng{0, N_SRC - 1};
+        checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128));
+        checker.set_rng(1, &rng);
+        checker.set_dtype(2, dtype::Int32());
+        checker.set_rng(2, &mat_idx_rng);
+        checker.set_dtype(3, dtype::QuantizedS8(0.1f));
+        param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
+        param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
+        checker.set_param(param);
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{N_SRC, 3, 10, 11}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}});
+        checker.execs(
+                {{N_SRC, 3, 17, 13}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}});
+    }
+}
+
+TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW4_IC_SMALL) {
+    using Param = WarpPerspective::Param;
+    WarpPerspective::Param param;
+    Checker<WarpPerspectiveForward> checker(handle_cuda());
+    WarpPerspectiveMatRNG rng;
+    param.format = Param::Format::NHWC_NCHW4_IC_SMALL;
+    checker.set_rng(1, &rng);
+    checker.set_dtype(0, dtype::Uint8());
+    checker.set_dtype(2, dtype::QuantizedS8(1.f));
+    for (auto bmode : {WarpPerspective::BorderMode::WRAP,
+                       WarpPerspective::BorderMode::REFLECT,
+                       WarpPerspective::BorderMode::REPLICATE,
+                       WarpPerspective::BorderMode::CONSTANT}) {
+        param.border_val = 0.3f;
+        param.bmode = bmode;
+        param.imode = Param::InterpolationMode::LINEAR;
+
+        checker.set_param(param);
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 1, 11, 12, 4}});
+        checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}});
+        checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 1, 51, 51, 4}});
+        checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}});
+    }
+    {
+        Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
+                handle_cuda());
+        constexpr int N_SRC = 5;
+        UniformIntRNG mat_idx_rng{0, N_SRC - 1};
+        checker.set_dtype(0, dtype::Uint8());
+        checker.set_rng(1, &rng);
+        checker.set_dtype(2, dtype::Int32());
+        checker.set_rng(2, &mat_idx_rng);
+        checker.set_dtype(3, dtype::QuantizedS8(1.f));
+        param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
+        param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
+        checker.set_param(param);
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}});
+        checker.execs(
+                {{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}});
+    }
+}
+
+TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW) {
+    using Param = WarpPerspective::Param;
+    WarpPerspective::Param param;
+    Checker<WarpPerspectiveForward> checker(handle_cuda());
+    WarpPerspectiveMatRNG rng;
+    param.format = Param::Format::NHWC_NCHW;
+    checker.set_rng(1, &rng);
+    checker.set_dtype(0, dtype::Uint8());
+    checker.set_dtype(2, dtype::Float32());
+    for (auto bmode : {WarpPerspective::BorderMode::WRAP,
+                       WarpPerspective::BorderMode::REFLECT,
+                       WarpPerspective::BorderMode::REPLICATE,
+                       WarpPerspective::BorderMode::CONSTANT}) {
+        param.border_val = 0.3f;
+        param.bmode = bmode;
+        param.imode = Param::InterpolationMode::LINEAR;
+
+        checker.set_param(param);
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 3, 11, 12}});
+        checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 3, 25, 25}});
+        checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 3, 51, 51}});
+        checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 3, 25, 25}});
+    }
+    {
+        Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
+                handle_cuda());
+        constexpr int N_SRC = 5;
+        UniformIntRNG mat_idx_rng{0, N_SRC - 1};
+        checker.set_dtype(0, dtype::Uint8());
+        checker.set_rng(1, &rng);
+        checker.set_dtype(2, dtype::Int32());
+        checker.set_rng(2, &mat_idx_rng);
+        checker.set_dtype(3, dtype::Float32());
+        param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
+        param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
+        checker.set_param(param);
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 3, 11, 12}});
+        checker.execs(
+                {{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 3, 16, 15}});
+    }
+}
+
 TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_NCHW_INT8) {
     warp_perspective::run_int8_test(handle_cuda());
 }
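
The tests run with set_epsilon(1 + 1e-3), which reads as tolerating a one-step difference in the quantized result (rounding at a bilinear tap boundary can move an 8-bit value by a full quantum) plus a little float slack. A sketch of that comparison rule, assuming epsilon is the maximum allowed absolute error:

#include <cmath>

// Pass if |got - expected| <= eps; eps = 1 + 1e-3 admits an off-by-one
// quantized value.
bool quantized_close(float got, float expected, float eps = 1.f + 1e-3f) {
    return std::fabs(got - expected) <= eps;
}
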
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */

 #include "megbrain/gopt/framework.h"
@@ -35,13 +36,13 @@ using namespace gopt;

 /* ================ SubGraph ================ */

 OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
-        OperatorNodeBase *opr) {
-    auto &&new_inp = m_opr_new_inp_cache;
+        OperatorNodeBase* opr) {
+    auto&& new_inp = m_opr_new_inp_cache;
     new_inp.clear();
     new_inp.reserve(opr->input().size());
     bool has_replaced_inp = false;
-    for (auto i: opr->input()) {
+    for (auto i : opr->input()) {
         auto new_var = get_var(i);
         if (new_var != i) {
             has_replaced_inp = true;
@@ -52,14 +53,14 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
     }

     if (has_replaced_inp) {
-        auto new_opr = serialization::copy_opr_shallow(
-                *opr, new_inp, opr->config());
+        auto new_opr =
+                serialization::copy_opr_shallow(*opr, new_inp, opr->config());
         auto &&out0 = opr->output(), &&out1 = new_opr->output();
         size_t i = 0;
         auto err_msg = [opr, new_opr] {
-            return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}",
-                            opr->cname(), opr->dyn_typeinfo()->name,
-                            new_opr->cname(), new_opr->dyn_typeinfo()->name);
+            return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(),
+                            opr->dyn_typeinfo()->name, new_opr->cname(),
+                            new_opr->dyn_typeinfo()->name);
         };
         MGB_MARK_USED_VAR(err_msg);
         // opr output size mismatch may be caused by:
@@ -67,33 +68,33 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs( | |||||
// 1) other post-insert optimization (e.g. const folding) | // 1) other post-insert optimization (e.g. const folding) | ||||
// we can't handle only usable_output here, since some output var with | // we can't handle only usable_output here, since some output var with | ||||
// volatile flag could be the graph's endpoint (e.g. RemoteSend) | // volatile flag could be the graph's endpoint (e.g. RemoteSend) | ||||
for (; i < std::min(out0.size(), out1.size()); ++ i) { | |||||
for (; i < std::min(out0.size(), out1.size()); ++i) { | |||||
bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), | bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), | ||||
v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT); | v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT); | ||||
mgb_assert(v0 == v1, "%s", err_msg().c_str()); | mgb_assert(v0 == v1, "%s", err_msg().c_str()); | ||||
auto &&ins = m_varmap.insert({out0[i], {true, nullptr}}); | |||||
auto&& ins = m_varmap.insert({out0[i], {true, nullptr}}); | |||||
mgb_assert(ins.second || ins.first->second.first, | mgb_assert(ins.second || ins.first->second.first, | ||||
"opr output already replaced"); | "opr output already replaced"); | ||||
// handle repeated call on the same opr | // handle repeated call on the same opr | ||||
ins.first->second.second = out1[i]; | ins.first->second.second = out1[i]; | ||||
on_var_replaced(out0[i], out1[i], nullptr); | on_var_replaced(out0[i], out1[i], nullptr); | ||||
} | } | ||||
for (; i < out0.size(); ++ i) { | |||||
for (; i < out0.size(); ++i) { | |||||
mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), | mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), | ||||
"%s", err_msg().c_str()); | |||||
"%s", err_msg().c_str()); | |||||
} | } | ||||
for (; i < out1.size(); ++ i) { | |||||
for (; i < out1.size(); ++i) { | |||||
mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), | mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), | ||||
"%s", err_msg().c_str()); | |||||
"%s", err_msg().c_str()); | |||||
} | } | ||||
return new_opr; | return new_opr; | ||||
} | } | ||||
return opr; | return opr; | ||||
} | } | ||||
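auto_replace_outputs() is the fallback half of a rewrite pass: pattern-specific code calls replace_var(), and every other opr is funneled through it so already-replaced inputs keep propagating. A minimal sketch of that calling pattern (MyPass and try_rewrite are hypothetical names; the FuseWarpPerspectiveDimshufflePass later in this diff has the same shape):

// Sketch of a Pass::apply() built on the Rewriter API above.
void MyPass::apply(OptState& opt) const {
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&](OperatorNodeBase* opr) {
        // pattern-specific rewrite first; otherwise let
        // auto_replace_outputs() propagate already-replaced inputs
        if (!try_rewrite(opr, rewriter))
            rewriter.auto_replace_outputs(opr);
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}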
void SubGraph::Rewriter::replace_var( | |||||
VarNode *src, VarNode *dst, const char *msg) { | |||||
void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst, | |||||
const char* msg) { | |||||
if (src == dst) | if (src == dst) | ||||
return; | return; | ||||
@@ -103,19 +104,19 @@ void SubGraph::Rewriter::replace_var( | |||||
"dst %s maps back to src %s in SubGraph::Rewriter::replace_var", | "dst %s maps back to src %s in SubGraph::Rewriter::replace_var", | ||||
dst->cname(), src->cname()); | dst->cname(), src->cname()); | ||||
auto &&ins = m_varmap.insert({src, {false, dst}}); | |||||
auto&& ins = m_varmap.insert({src, {false, dst}}); | |||||
if (!ins.second) { | if (!ins.second) { | ||||
auto &&old_rep = ins.first->second; | |||||
auto&& old_rep = ins.first->second; | |||||
mgb_assert(old_rep.first || old_rep.second == dst, | mgb_assert(old_rep.first || old_rep.second == dst, | ||||
"can not replace a var twice"); | |||||
"can not replace a var twice"); | |||||
old_rep.first = false; | old_rep.first = false; | ||||
old_rep.second = dst; | old_rep.second = dst; | ||||
} | } | ||||
on_var_replaced(src, dst, msg); | on_var_replaced(src, dst, msg); | ||||
} | } | ||||
void SubGraph::Rewriter::on_var_replaced( | |||||
VarNode* src, VarNode* dst, const char* msg) { | |||||
void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst, | |||||
const char* msg) { | |||||
if (auto state = m_owner_graph->owner_opt_state()) { | if (auto state = m_owner_graph->owner_opt_state()) { | ||||
state->on_var_replaced(src, dst, msg); | state->on_var_replaced(src, dst, msg); | ||||
} | } | ||||
@@ -124,7 +125,7 @@ void SubGraph::Rewriter::on_var_replaced( | |||||
void SubGraph::Rewriter::apply_inplace() const { | void SubGraph::Rewriter::apply_inplace() const { | ||||
m_owner_graph->m_endpoint_oprs.clear(); | m_owner_graph->m_endpoint_oprs.clear(); | ||||
m_owner_graph->m_endpoint_vars_set.clear(); | m_owner_graph->m_endpoint_vars_set.clear(); | ||||
for (auto &&var: m_owner_graph->m_endpoint_vars) { | |||||
for (auto&& var : m_owner_graph->m_endpoint_vars) { | |||||
var = get_var(var.node()); | var = get_var(var.node()); | ||||
m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr()); | m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr()); | ||||
m_owner_graph->m_endpoint_vars_set.insert(var.node()); | m_owner_graph->m_endpoint_vars_set.insert(var.node()); | ||||
@@ -150,33 +151,30 @@ std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) { | |||||
return it->second = {it_next->second.first & it->second.first, next.second}; | return it->second = {it_next->second.first & it->second.first, next.second}; | ||||
} | } | ||||
SubGraph::SubGraph(const SymbolVarArray &endpoint_vars): | |||||
m_endpoint_vars(endpoint_vars) | |||||
{ | |||||
SubGraph::SubGraph(const SymbolVarArray& endpoint_vars) | |||||
: m_endpoint_vars(endpoint_vars) { | |||||
mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty"); | mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty"); | ||||
m_comp_graph = endpoint_vars[0].node()->owner_graph(); | m_comp_graph = endpoint_vars[0].node()->owner_graph(); | ||||
for (auto i: endpoint_vars) { | |||||
for (auto i : endpoint_vars) { | |||||
m_endpoint_oprs.insert(i.node()->owner_opr()); | m_endpoint_oprs.insert(i.node()->owner_opr()); | ||||
m_endpoint_vars_set.insert(i.node()); | m_endpoint_vars_set.insert(i.node()); | ||||
mgb_assert(m_comp_graph == i.node()->owner_graph(), | mgb_assert(m_comp_graph == i.node()->owner_graph(), | ||||
"endpoints belong to different computing graphs"); | |||||
"endpoints belong to different computing graphs"); | |||||
} | } | ||||
} | } | ||||
void SubGraph::iter( | |||||
const Callback& cb, | |||||
std::shared_ptr<ExtraDep> extra_dep) const { | |||||
void SubGraph::iter(const Callback& cb, | |||||
std::shared_ptr<ExtraDep> extra_dep) const { | |||||
Callback on_opr; | Callback on_opr; | ||||
if (m_owner_opt_state) { | if (m_owner_opt_state) { | ||||
on_opr = [state=m_owner_opt_state, &cb](OperatorNodeBase *opr) { | |||||
on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) { | |||||
state->m_opr_property_flag = OprPropertyFlag::ALL; | state->m_opr_property_flag = OprPropertyFlag::ALL; | ||||
state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr); | state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr); | ||||
state->m_cur_iter_opr_priority = | state->m_cur_iter_opr_priority = | ||||
opr->node_prop().attribute().priority; | |||||
opr->node_prop().attribute().priority; | |||||
state->m_cur_iter_opr_stream_prop_type = | state->m_cur_iter_opr_stream_prop_type = | ||||
state->m_comp_node_opt.stream_prop_type( | |||||
opr->output(0)); | |||||
state->m_comp_node_opt.stream_prop_type(opr->output(0)); | |||||
mgb_assert(state->m_oprs_inserted.empty()); | mgb_assert(state->m_oprs_inserted.empty()); | ||||
cb(opr); | cb(opr); | ||||
state->m_opr_property_flag = OprPropertyFlag::NONE; | state->m_opr_property_flag = OprPropertyFlag::NONE; | ||||
@@ -188,19 +186,19 @@ void SubGraph::iter( | |||||
} | } | ||||
cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)}; | cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)}; | ||||
for (auto i: m_endpoint_oprs) | |||||
for (auto i : m_endpoint_oprs) | |||||
dep_iter.add(i); | dep_iter.add(i); | ||||
} | } | ||||
ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const { | ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const { | ||||
ThinHashMap<VarNode*, size_t> ret; | ThinHashMap<VarNode*, size_t> ret; | ||||
auto cb = [&](OperatorNodeBase *opr) { | |||||
for (auto &&i: opr->node_prop().dep_map()) { | |||||
auto cb = [&](OperatorNodeBase* opr) { | |||||
for (auto&& i : opr->node_prop().dep_map()) { | |||||
if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) { | if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) { | ||||
++ ret.at(i.first); | |||||
++ret.at(i.first); | |||||
} | } | ||||
} | } | ||||
for (auto i: opr->output()) { | |||||
for (auto i : opr->output()) { | |||||
if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | ||||
auto ins = ret.insert({i, 0}); | auto ins = ret.insert({i, 0}); | ||||
mgb_assert(ins.second); | mgb_assert(ins.second); | ||||
@@ -208,13 +206,13 @@ ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const { | |||||
} | } | ||||
}; | }; | ||||
iter(cb); | iter(cb); | ||||
for (auto i: m_endpoint_vars_set) { | |||||
for (auto i : m_endpoint_vars_set) { | |||||
auto iter = ret.find(i); | auto iter = ret.find(i); | ||||
if (iter == ret.end()) { | if (iter == ret.end()) { | ||||
mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); | mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); | ||||
ret[i] = 1; | ret[i] = 1; | ||||
} else { | } else { | ||||
++ ret.at(i); | |||||
++ret.at(i); | |||||
} | } | ||||
} | } | ||||
return ret; | return ret; | ||||
@@ -222,10 +220,8 @@ ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const { | |||||
/* ================ UniqReaderCheck ================ */ | /* ================ UniqReaderCheck ================ */ | ||||
UniqReaderCheck::UniqReaderCheck(const SubGraph &graph): | |||||
m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} | |||||
{ | |||||
} | |||||
UniqReaderCheck::UniqReaderCheck(const SubGraph& graph) | |||||
: m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {} | |||||
void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr, | void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr, | ||||
OperatorNodeBase* repl_opr) { | OperatorNodeBase* repl_opr) { | ||||
@@ -253,32 +249,30 @@ void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr, | |||||
/* ================ OptState ================ */ | /* ================ OptState ================ */ | ||||
OptState::OptState( | |||||
const GraphOptimizer *owner_optimizer, const SubGraph& graph): | |||||
m_owner_optimizer{owner_optimizer}, | |||||
m_var_replace_map{ | |||||
const_cast<ThinHashMap<VarNode*, VarNode*>*>( | |||||
&GraphOptimizer::var_replace_map(*graph.comp_graph()))}, | |||||
m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()}, | |||||
m_graph{graph} | |||||
{ | |||||
OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph) | |||||
: m_owner_optimizer{owner_optimizer}, | |||||
m_var_replace_map{const_cast<ThinHashMap<VarNode*, VarNode*>*>( | |||||
&GraphOptimizer::var_replace_map(*graph.comp_graph()))}, | |||||
m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()}, | |||||
m_graph{graph} { | |||||
mgb_assert(!m_graph.m_owner_opt_state); | mgb_assert(!m_graph.m_owner_opt_state); | ||||
m_var_replace_map->clear(); | m_var_replace_map->clear(); | ||||
m_graph.m_owner_opt_state = this; | m_graph.m_owner_opt_state = this; | ||||
m_oprs_inserted.clear(); | m_oprs_inserted.clear(); | ||||
auto on_opr_insert = [this](const cg::event::OprInserted &ev) { | |||||
auto on_opr_insert = [this](const cg::event::OprInserted& ev) { | |||||
auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR, | auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR, | ||||
need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY; | need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY; | ||||
if (need_src_opr) | if (need_src_opr) | ||||
mgb_assert(m_cur_iter_src_opr, "opr %s{%s} created outside from " | |||||
"SubGraph::iter", | |||||
ev.opr->cname(), ev.opr->dyn_typeinfo()->name); | |||||
mgb_assert(m_cur_iter_src_opr, | |||||
"opr %s{%s} created outside from " | |||||
"SubGraph::iter", | |||||
ev.opr->cname(), ev.opr->dyn_typeinfo()->name); | |||||
if (ev.exc || ev.is_dedup) | if (ev.exc || ev.is_dedup) | ||||
return; | return; | ||||
auto &&new_attr = ev.opr->node_prop().attribute(); | |||||
auto &&ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE}); | |||||
auto&& new_attr = ev.opr->node_prop().attribute(); | |||||
auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE}); | |||||
mgb_assert(ins.second); | mgb_assert(ins.second); | ||||
if (need_src_opr && !new_attr.src_opr) { | if (need_src_opr && !new_attr.src_opr) { | ||||
@@ -296,20 +290,22 @@ OptState::OptState( | |||||
auto csp = m_cur_iter_opr_stream_prop_type; | auto csp = m_cur_iter_opr_stream_prop_type; | ||||
if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) { | if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) { | ||||
for (auto i: ev.opr->output()) | |||||
for (auto i : ev.opr->output()) | |||||
m_comp_node_opt.register_stream_var(i, csp); | m_comp_node_opt.register_stream_var(i, csp); | ||||
} | } | ||||
}; | }; | ||||
m_on_opr_insert_handler = graph.comp_graph()->event().register_receiver< | |||||
cg::event::OprInserted>(on_opr_insert); | |||||
m_on_opr_insert_handler = | |||||
graph.comp_graph() | |||||
->event() | |||||
.register_receiver<cg::event::OprInserted>(on_opr_insert); | |||||
} | } | ||||
void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) { | |||||
void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) { | |||||
if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | ||||
// this can only happen in auto_replace_outputs() | // this can only happen in auto_replace_outputs() | ||||
mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) && | mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) && | ||||
src->owner_opr()->dyn_typeinfo() == | |||||
dst->owner_opr()->dyn_typeinfo()); | |||||
src->owner_opr()->dyn_typeinfo() == | |||||
dst->owner_opr()->dyn_typeinfo()); | |||||
mgb_assert(!msg); | mgb_assert(!msg); | ||||
return; | return; | ||||
} | } | ||||
@@ -362,7 +358,7 @@ void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) { | |||||
return f & (InferType::RT_STATIC | InferType::CONST); | return f & (InferType::RT_STATIC | InferType::CONST); | ||||
}; | }; | ||||
if (!(norm(it0.shape) == norm(it1.shape) && | if (!(norm(it0.shape) == norm(it1.shape) && | ||||
norm(it0.value) <= norm(it1.value))) { | |||||
norm(it0.value) <= norm(it1.value))) { | |||||
suc = false; | suc = false; | ||||
fail_chks.push_back("infer-type"); | fail_chks.push_back("infer-type"); | ||||
} | } | ||||
@@ -407,22 +403,21 @@ void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) { | |||||
#if MGB_ENABLE_LOGGING | #if MGB_ENABLE_LOGGING | ||||
if (msg && m_owner_optimizer->verbosity()) { | if (msg && m_owner_optimizer->verbosity()) { | ||||
m_log_msg. | |||||
append("\n "). | |||||
append(std::to_string(m_log_nr_item)). | |||||
append(": "). | |||||
append(src->owner_opr()->cname()). | |||||
append(" => "). | |||||
append(dst->owner_opr()->cname()). | |||||
append(" ("). | |||||
append(msg). | |||||
append(")"); | |||||
} | |||||
++ m_log_nr_item; | |||||
m_log_msg.append("\n ") | |||||
.append(std::to_string(m_log_nr_item)) | |||||
.append(": ") | |||||
.append(src->owner_opr()->cname()) | |||||
.append(" => ") | |||||
.append(dst->owner_opr()->cname()) | |||||
.append(" (") | |||||
.append(msg) | |||||
.append(")"); | |||||
} | |||||
++m_log_nr_item; | |||||
#endif | #endif | ||||
} | } | ||||
size_t OptState::flush_log(const char *title) { | |||||
size_t OptState::flush_log(const char* title) { | |||||
if (m_owner_optimizer->verbosity() >= 2) { | if (m_owner_optimizer->verbosity() >= 2) { | ||||
if (m_log_msg.empty()) { | if (m_log_msg.empty()) { | ||||
m_log_msg = mgb_cstr_log(" no var replacement logged"); | m_log_msg = mgb_cstr_log(" no var replacement logged"); | ||||
@@ -435,42 +430,40 @@ size_t OptState::flush_log(const char *title) { | |||||
return ret; | return ret; | ||||
} | } | ||||
void OptState::call_with_opr(OperatorNodeBase *opr, thin_function<void(void)> func, | |||||
void OptState::call_with_opr(OperatorNodeBase* opr, | |||||
thin_function<void(void)> func, | |||||
OprPropertyFlag opr_property_flag) { | OprPropertyFlag opr_property_flag) { | ||||
auto src_opr = cg::get_opr_root_source_opr(opr); | auto src_opr = cg::get_opr_root_source_opr(opr); | ||||
auto opr_priority = opr->node_prop().attribute().priority; | auto opr_priority = opr->node_prop().attribute().priority; | ||||
auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0)); | auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0)); | ||||
ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted; | ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted; | ||||
auto swap_properties = [&, | |||||
need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR, | |||||
need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] { | |||||
if (need_src_opr) { | |||||
std::swap(m_cur_iter_src_opr, src_opr); | |||||
} | |||||
if (need_priority) { | |||||
std::swap(m_cur_iter_opr_priority, opr_priority); | |||||
} | |||||
std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type); | |||||
std::swap(m_opr_property_flag, opr_property_flag); | |||||
std::swap(m_oprs_inserted, oprs_inserted); | |||||
}; | |||||
auto swap_properties = | |||||
[&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR, | |||||
need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] { | |||||
if (need_src_opr) { | |||||
std::swap(m_cur_iter_src_opr, src_opr); | |||||
} | |||||
if (need_priority) { | |||||
std::swap(m_cur_iter_opr_priority, opr_priority); | |||||
} | |||||
std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type); | |||||
std::swap(m_opr_property_flag, opr_property_flag); | |||||
std::swap(m_oprs_inserted, oprs_inserted); | |||||
}; | |||||
MGB_TRY { | MGB_TRY { | ||||
swap_properties(); | swap_properties(); | ||||
func(); | func(); | ||||
} MGB_FINALLY({ | |||||
swap_properties(); | |||||
}); | |||||
} | |||||
MGB_FINALLY({ swap_properties(); }); | |||||
} | } | ||||
/* ================ RecursiveSubGraphRewriteHelper ================ */ | /* ================ RecursiveSubGraphRewriteHelper ================ */ | ||||
RecursiveSubGraphRewriteHelper:: | |||||
~RecursiveSubGraphRewriteHelper() noexcept = default; | |||||
RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept = | |||||
default; | |||||
RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState &state): | |||||
m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} | |||||
{ | |||||
} | |||||
RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state) | |||||
: m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {} | |||||
void RecursiveSubGraphRewriteHelper::apply() { | void RecursiveSubGraphRewriteHelper::apply() { | ||||
using namespace std::placeholders; | using namespace std::placeholders; | ||||
@@ -479,8 +472,8 @@ void RecursiveSubGraphRewriteHelper::apply() { | |||||
m_rewriter.apply_inplace(); | m_rewriter.apply_inplace(); | ||||
} | } | ||||
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { | |||||
auto on_new_opr = [this](OperatorNodeBase *opr) { | |||||
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) { | |||||
auto on_new_opr = [this](OperatorNodeBase* opr) { | |||||
auto repl_opr = m_rewriter.auto_replace_outputs(opr); | auto repl_opr = m_rewriter.auto_replace_outputs(opr); | ||||
return on_new_opr_check_should_process(opr, repl_opr); | return on_new_opr_check_should_process(opr, repl_opr); | ||||
}; | }; | ||||
@@ -493,8 +486,8 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { | |||||
return; | return; | ||||
mgb_assert(m_opr_stack.empty()); | mgb_assert(m_opr_stack.empty()); | ||||
m_opr_stack.push_back({ | |||||
orig_out, m_rewriter.get_var(orig_out)->owner_opr()}); | |||||
m_opr_stack.push_back( | |||||
{orig_out, m_rewriter.get_var(orig_out)->owner_opr()}); | |||||
bool first = true; | bool first = true; | ||||
while (!m_opr_stack.empty()) { | while (!m_opr_stack.empty()) { | ||||
@@ -515,9 +508,9 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { | |||||
if (should_process) { | if (should_process) { | ||||
auto trans = process_opr(cur_out); | auto trans = process_opr(cur_out); | ||||
if (trans.valid()) { | if (trans.valid()) { | ||||
m_opr_stack.push_back({ | |||||
cur_frame.orig_var, trans->result->owner_opr()}); | |||||
for (auto i: reverse_adaptor(trans->internal)) { | |||||
m_opr_stack.push_back( | |||||
{cur_frame.orig_var, trans->result->owner_opr()}); | |||||
for (auto i : reverse_adaptor(trans->internal)) { | |||||
if (i) | if (i) | ||||
m_opr_stack.push_back({i, i->owner_opr()}); | m_opr_stack.push_back({i, i->owner_opr()}); | ||||
} | } | ||||
@@ -532,7 +525,7 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { | |||||
auto src = cur_frame.orig_var; | auto src = cur_frame.orig_var; | ||||
if (m_rewriter.get_var(src) != cur_out) { | if (m_rewriter.get_var(src) != cur_out) { | ||||
const char *msg = nullptr; | |||||
const char* msg = nullptr; | |||||
if (m_opr_stack.empty()) { | if (m_opr_stack.empty()) { | ||||
msg = m_log_msg.c_str(); | msg = m_log_msg.c_str(); | ||||
} | } | ||||
@@ -550,11 +543,12 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { | |||||
GraphOptimizer::~GraphOptimizer() noexcept = default; | GraphOptimizer::~GraphOptimizer() noexcept = default; | ||||
class GraphOptimizer::VarReplaceMapStorage :public UserDataContainer::UserData { | |||||
class GraphOptimizer::VarReplaceMapStorage | |||||
: public UserDataContainer::UserData { | |||||
MGB_TYPEINFO_OBJ_DECL; | MGB_TYPEINFO_OBJ_DECL; | ||||
public: | |||||
ThinHashMap<VarNode*, VarNode*> map; | |||||
public: | |||||
ThinHashMap<VarNode*, VarNode*> map; | |||||
}; | }; | ||||
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage); | MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage); | ||||
@@ -565,7 +559,7 @@ GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) { | |||||
return *this; | return *this; | ||||
} | } | ||||
SubGraph GraphOptimizer::apply(const SubGraph &graph) const { | |||||
SubGraph GraphOptimizer::apply(const SubGraph& graph) const { | |||||
RealTimer timer; | RealTimer timer; | ||||
OptState state{this, graph}; | OptState state{this, graph}; | ||||
@@ -574,38 +568,38 @@ SubGraph GraphOptimizer::apply(const SubGraph &graph) const { | |||||
// first update output var shapes of all oprs | // first update output var shapes of all oprs | ||||
state.graph().iter(cg::update_output_var_shapes); | state.graph().iter(cg::update_output_var_shapes); | ||||
auto &&opt = graph.comp_graph()->options(); | |||||
auto&& opt = graph.comp_graph()->options(); | |||||
auto orig_setting = opt.graph_opt_level; | auto orig_setting = opt.graph_opt_level; | ||||
Pass *cur_pass = nullptr; | |||||
Pass* cur_pass = nullptr; | |||||
MGB_MARK_USED_VAR(cur_pass); | MGB_MARK_USED_VAR(cur_pass); | ||||
MGB_TRY { | MGB_TRY { | ||||
for (auto &&i: m_passes) { | |||||
for (auto&& i : m_passes) { | |||||
state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL); | state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL); | ||||
cur_pass = i.get(); | cur_pass = i.get(); | ||||
opt.graph_opt_level = 1; | opt.graph_opt_level = 1; | ||||
i->apply(state); | i->apply(state); | ||||
tot_nr_replace += state.flush_log( | tot_nr_replace += state.flush_log( | ||||
mgb_ssprintf_log( | |||||
"apply optimization pass %s:", i->name()).c_str()); | |||||
mgb_ssprintf_log("apply optimization pass %s:", i->name()) | |||||
.c_str()); | |||||
} | } | ||||
} MGB_CATCH(std::exception &exc, { | |||||
} | |||||
MGB_CATCH(std::exception & exc, { | |||||
mgb_log_error("error while applying optimization pass %s: %s", | mgb_log_error("error while applying optimization pass %s: %s", | ||||
cur_pass->name(), exc.what()); | |||||
cur_pass->name(), exc.what()); | |||||
opt.graph_opt_level = orig_setting; | opt.graph_opt_level = orig_setting; | ||||
throw; | throw; | ||||
}) | }) | ||||
MGB_FINALLY( | |||||
opt.graph_opt_level = orig_setting | |||||
); | |||||
MGB_FINALLY(opt.graph_opt_level = orig_setting); | |||||
if (verbosity() >= 1) { | if (verbosity() >= 1) { | ||||
mgb_log_debug("graph optimization: applied %zu passes, " | |||||
mgb_log_debug( | |||||
"graph optimization: applied %zu passes, " | |||||
"total %zu var(s) replaced; time=%.2fms", | "total %zu var(s) replaced; time=%.2fms", | ||||
m_passes.size(), tot_nr_replace, timer.get_msecs()); | m_passes.size(), tot_nr_replace, timer.get_msecs()); | ||||
} | } | ||||
return state.graph(); | return state.graph(); | ||||
} | } | ||||
const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const { | |||||
const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const { | |||||
if (m_passes.empty()) { | if (m_passes.empty()) { | ||||
// this check is necessary, since OptState would clear | // this check is necessary, since OptState would clear | ||||
// var_replace_map() | // var_replace_map() | ||||
@@ -613,7 +607,7 @@ const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const { | |||||
} | } | ||||
auto g = apply({{vars.begin(), vars.end()}}); | auto g = apply({{vars.begin(), vars.end()}}); | ||||
for (size_t i = 0; i < vars.size(); ++ i) { | |||||
for (size_t i = 0; i < vars.size(); ++i) { | |||||
vars[i] = g.endpoint_vars()[i].node(); | vars[i] = g.endpoint_vars()[i].node(); | ||||
} | } | ||||
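For context, the apply()/apply_inplace() pair above is what pass authors drive; a minimal sketch, assuming vars already holds the graph's endpoint VarNodes:

// Sketch: rewriting graph endpoints in place with the API above.
gopt::GraphOptimizer optimizer;
optimizer.add_pass<gopt::ShuffleShuffleRemovePass>();  // any Pass subclass
optimizer.apply_inplace(vars);  // each vars[i] now points at its replacement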
return *this; | return *this; | ||||
@@ -653,7 +647,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( | |||||
#if MGB_JIT | #if MGB_JIT | ||||
bool need_jit = false; | bool need_jit = false; | ||||
if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 || | if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 || | ||||
comp_graph_opt->graph_opt.jit)) { | |||||
comp_graph_opt->graph_opt.jit)) { | |||||
need_jit = true; | need_jit = true; | ||||
} | } | ||||
if (need_jit && after_grad) { | if (need_jit && after_grad) { | ||||
@@ -679,7 +673,6 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( | |||||
add_passes_for_optimize_options(*inference_opt); | add_passes_for_optimize_options(*inference_opt); | ||||
} | } | ||||
if (inference_opt) { | if (inference_opt) { | ||||
// merge params to reduce loading time and graph overhead | // merge params to reduce loading time and graph overhead | ||||
add_pass<ParamMergePass>(); | add_pass<ParamMergePass>(); | ||||
@@ -689,15 +682,16 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( | |||||
} | } | ||||
const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map( | const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map( | ||||
ComputingGraph &graph) { | |||||
auto storage = graph.options().user_data.get_user_data_or_create< | |||||
VarReplaceMapStorage>(); | |||||
ComputingGraph& graph) { | |||||
auto storage = | |||||
graph.options() | |||||
.user_data.get_user_data_or_create<VarReplaceMapStorage>(); | |||||
return storage->map; | return storage->map; | ||||
} | } | ||||
VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { | |||||
auto &&map = var_replace_map(*(var->owner_graph())); | |||||
for (; ; ) { | |||||
VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) { | |||||
auto&& map = var_replace_map(*(var->owner_graph())); | |||||
for (;;) { | |||||
auto iter = map.find(var); | auto iter = map.find(var); | ||||
if (iter == map.end()) | if (iter == map.end()) | ||||
return var; | return var; | ||||
@@ -705,7 +699,6 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { | |||||
} | } | ||||
} | } | ||||
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | ||||
const cg::GraphCommonOptimizeOptions& options) { | const cg::GraphCommonOptimizeOptions& options) { | ||||
return add_passes_for_optimize_options( | return add_passes_for_optimize_options( | ||||
@@ -723,12 +716,14 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
options.disable_##_option(); \ | options.disable_##_option(); \ | ||||
} \ | } \ | ||||
} | } | ||||
cb(fuse_preprocess, {add_pass(FuseNCHW4Int8Preprocess::make());}); | |||||
cb(fuse_preprocess, { | |||||
add_pass(FuseNCHW4Int8Preprocess::make()); | |||||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||||
}); | |||||
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); | cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); | ||||
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); | cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); | ||||
cb(nchw4, { | cb(nchw4, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
@@ -763,6 +758,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
add_pass(FuseNCHW4Int8Preprocess::make()); | add_pass(FuseNCHW4Int8Preprocess::make()); | ||||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||||
}); | }); | ||||
cb(chwn4, { | cb(chwn4, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
@@ -790,9 +786,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
/* ================ ConstVarPropogateBase ================ */ | /* ================ ConstVarPropogateBase ================ */ | ||||
ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( | ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( | ||||
OperatorNodeBase *opr) { | |||||
OperatorNodeBase* opr) { | |||||
using ProfFlag = OperatorNodeBase::NodeProp::Flag; | using ProfFlag = OperatorNodeBase::NodeProp::Flag; | ||||
auto &&info = m_oprinfo[opr]; | |||||
auto&& info = m_oprinfo[opr]; | |||||
if (info.processed) | if (info.processed) | ||||
return info.result; | return info.result; | ||||
info.processed = true; | info.processed = true; | ||||
@@ -819,15 +815,14 @@ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( | |||||
if (opr->input().empty()) | if (opr->input().empty()) | ||||
return make_ret(); | return make_ret(); | ||||
if (opr->node_prop().contain( | |||||
ProfFlag::FORCE_UPDATE_INPUT_VAR | | |||||
ProfFlag::IMPURE_FUNC)) { | |||||
if (opr->node_prop().contain(ProfFlag::FORCE_UPDATE_INPUT_VAR | | |||||
ProfFlag::IMPURE_FUNC)) { | |||||
return make_ret(); | return make_ret(); | ||||
} | } | ||||
size_t max_input_size = 0; | size_t max_input_size = 0; | ||||
ret.all_const_inp = true; | ret.all_const_inp = true; | ||||
for (auto i: opr->input()) { | |||||
for (auto i : opr->input()) { | |||||
auto io = i->owner_opr(); | auto io = i->owner_opr(); | ||||
auto iter = m_oprinfo.find(io); | auto iter = m_oprinfo.find(io); | ||||
if (iter == m_oprinfo.end()) { | if (iter == m_oprinfo.end()) { | ||||
@@ -835,7 +830,7 @@ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( | |||||
iter = m_oprinfo.find(io); | iter = m_oprinfo.find(io); | ||||
mgb_assert(iter != m_oprinfo.end()); | mgb_assert(iter != m_oprinfo.end()); | ||||
} | } | ||||
auto &&src = iter->second; | |||||
auto&& src = iter->second; | |||||
if (src.is_const) { | if (src.is_const) { | ||||
update_max(max_input_size, src.max_size); | update_max(max_input_size, src.max_size); | ||||
ret.has_const_inp = true; | ret.has_const_inp = true; | ||||
@@ -19,6 +19,7 @@ | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/opr/imgproc.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace gopt; | using namespace gopt; | ||||
@@ -443,4 +444,244 @@ void FuseNCHW4Int8Preprocess::apply(OptState& state) const { | |||||
}; | }; | ||||
state.graph().iter(on_opr); | state.graph().iter(on_opr); | ||||
rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
} | |||||
/* ==================== FuseWarpPerspectiveDimshufflePass ================= */ | |||||
const char* FuseWarpPerspectiveDimshufflePass::name() const { | |||||
return mgb_cstr_log("Fuse warp perspective dimshuffle pass"); | |||||
} | |||||
void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const { | |||||
auto rewriter = opt.graph().make_rewriter(); | |||||
auto uniq_reader_check = UniqReaderCheck{opt.graph()}; | |||||
auto make_new_warp = [&rewriter](opr::WarpPerspective* warp, | |||||
opr::WarpPerspective::Param new_param, | |||||
megdnn::DType dst_dtype, | |||||
SymbolVar& new_warp) { | |||||
OperatorNodeConfig new_config(dst_dtype); | |||||
if (warp->input().size() == 3) { | |||||
auto src = rewriter.get_var(warp->input(0)), | |||||
mat = rewriter.get_var(warp->input(1)), | |||||
out_shape = rewriter.get_var(warp->input(2)); | |||||
new_warp = opr::WarpPerspective::make(src, mat, out_shape, | |||||
new_param, new_config); | |||||
} else { | |||||
mgb_assert(warp->input().size() == 4); | |||||
auto src = rewriter.get_var(warp->input(0)), | |||||
mat = rewriter.get_var(warp->input(1)), | |||||
mat_idx = rewriter.get_var(warp->input(2)), | |||||
out_shape = rewriter.get_var(warp->input(3)); | |||||
new_warp = opr::WarpPerspective::make(src, mat, mat_idx, out_shape, | |||||
new_param, new_config); | |||||
} | |||||
}; | |||||
auto is_warp_nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr, | |||||
OperatorNodeBase*& top_opr) { | |||||
// check warp | |||||
auto warp = try_cast_as_op<opr::WarpPerspective>(bottom_opr); | |||||
if (warp == nullptr) | |||||
return false; | |||||
auto inp_dtype = warp->input(0)->dtype(); | |||||
bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm || | |||||
inp_dtype.enumv() == DTypeEnum::Uint8; | |||||
bool is_nchw = warp->param().format == | |||||
megdnn::param::WarpPerspective::Format::NCHW; | |||||
if (!(is_u8_or_qu8 && is_nchw)) | |||||
return false; | |||||
if (!uniq_reader_check(warp->input(0))) | |||||
return false; | |||||
top_opr = warp; | |||||
return true; | |||||
}; | |||||
auto is_warp_nhwc2nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr, | |||||
OperatorNodeBase*& top_opr) { | |||||
// check shuffle | |||||
auto shuffle = try_cast_as_op<opr::Dimshuffle>(bottom_opr); | |||||
if (shuffle == nullptr) | |||||
return false; | |||||
auto&& shuffle_param = shuffle->param(); | |||||
if (shuffle_param.pattern_len != 4) | |||||
return false; | |||||
bool is_nhwc2nchw = shuffle_param.pattern[0] == 0 && | |||||
shuffle_param.pattern[1] == 3 && | |||||
shuffle_param.pattern[2] == 1 && | |||||
shuffle_param.pattern[3] == 2; | |||||
if (!is_nhwc2nchw) | |||||
return false; | |||||
if (!uniq_reader_check(shuffle->input(0))) | |||||
return false; | |||||
// check warp | |||||
auto warp = try_cast_as_op<opr::WarpPerspective>( | |||||
shuffle->input(0)->owner_opr()); | |||||
if (warp == nullptr) | |||||
return false; | |||||
auto inp_dtype = warp->input(0)->dtype(); | |||||
bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm || | |||||
inp_dtype.enumv() == DTypeEnum::Uint8; | |||||
bool is_nhwc = warp->param().format == | |||||
megdnn::param::WarpPerspective::Format::NHWC; | |||||
if (!(is_u8_or_qu8 && is_nhwc)) | |||||
return false; | |||||
top_opr = warp; | |||||
return true; | |||||
}; | |||||
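The two matchers above anchor all four rewrites in this pass: is_warp_nchw() accepts a u8/qu8 warp in NCHW, and is_warp_nhwc2nchw() a u8/qu8 NHWC warp followed by an NHWC-to-NCHW Dimshuffle. The latter recognizes fragments built like this (a sketch mirroring the WarpAndPreProcessCase test later in this diff; x, mat, h, w and cn are assumed to exist):

// Fragment matched by is_warp_nhwc2nchw() (sketch):
opr::WarpPerspective::Param warp_param;
warp_param.format = opr::WarpPerspective::Param::Format::NHWC;
auto warped = opr::WarpPerspective::make(x, mat, TensorShape{h, w},
                                         warp_param);              // NHWC warp
auto as_nchw = opr::Dimshuffle::make(warped, {0, 3, 1, 2}, 4, cn); // -> NCHW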
auto try_warp_nchw_typecvt = [&rewriter, &uniq_reader_check, &is_warp_nchw, | |||||
&make_new_warp](OperatorNodeBase* opr) { | |||||
// check typecvt | |||||
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr); | |||||
if (typecvt == nullptr) | |||||
return false; | |||||
bool is_to_f32 = | |||||
typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
if (!is_to_f32) | |||||
return false; | |||||
if (!uniq_reader_check(typecvt->input(0))) | |||||
return false; | |||||
OperatorNodeBase* top_opr = nullptr; | |||||
if (!is_warp_nchw(typecvt->input(0)->owner_opr(), top_opr)) | |||||
return false; | |||||
auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr); | |||||
SymbolVar new_warp; | |||||
make_new_warp(warp, warp->param(), opr->output()[0]->dtype(), new_warp); | |||||
rewriter.replace_var(opr->output(0), new_warp.node(), | |||||
mgb_cstr_log("replace warp + typecvt" | |||||
"fuse warp_dimshuffle(NCHW)")); | |||||
return true; | |||||
}; | |||||
auto try_warp_nhwc2nchw_typecvt = [&rewriter, &uniq_reader_check, | |||||
&is_warp_nhwc2nchw, | |||||
&make_new_warp](OperatorNodeBase* opr) { | |||||
// check typecvt | |||||
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr); | |||||
if (typecvt == nullptr) | |||||
return false; | |||||
bool is_to_f32 = | |||||
typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
if (!is_to_f32) | |||||
return false; | |||||
if (!uniq_reader_check(typecvt->input(0))) | |||||
return false; | |||||
OperatorNodeBase* top_opr = nullptr; | |||||
if (!is_warp_nhwc2nchw(typecvt->input(0)->owner_opr(), top_opr)) | |||||
return false; | |||||
auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr); | |||||
opr::WarpPerspective::Param new_param = warp->param(); | |||||
new_param.format = megdnn::param::WarpPerspective::Format::NHWC_NCHW; | |||||
SymbolVar new_warp; | |||||
make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp); | |||||
rewriter.replace_var( | |||||
opr->output(0), new_warp.node(), | |||||
mgb_cstr_log("replace conv_bias + dimshuffle + " | |||||
"typecvt to warp_dimshuffle(NHWC_NCHW)")); | |||||
return true; | |||||
}; | |||||
auto try_warp_nhwc2nchw4_typecvt = [&rewriter, &uniq_reader_check, | |||||
&is_warp_nhwc2nchw, | |||||
&make_new_warp](OperatorNodeBase* opr) { | |||||
// check relayout | |||||
auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr); | |||||
if (relayout == nullptr) | |||||
return false; | |||||
bool is_to_q8 = | |||||
relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; | |||||
bool is_to_nchw2nchw4 = relayout->param().mode == | |||||
opr::RelayoutFormat::Param::Mode::NCHW_NCHW4; | |||||
if (!(is_to_q8 && is_to_nchw2nchw4)) | |||||
return false; | |||||
if (!uniq_reader_check(relayout->input(0))) | |||||
return false; | |||||
OperatorNodeBase* top_opr = nullptr; | |||||
if (!is_warp_nhwc2nchw(relayout->input(0)->owner_opr(), top_opr)) | |||||
return false; | |||||
auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr); | |||||
bool is_small_chn = warp->input(0)->shape()[3] < 4; | |||||
if (!is_small_chn) | |||||
return false; | |||||
opr::WarpPerspective::Param new_param = warp->param(); | |||||
new_param.format = | |||||
megdnn::param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL; | |||||
SymbolVar new_warp; | |||||
make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp); | |||||
rewriter.replace_var( | |||||
opr->output(0), new_warp.node(), | |||||
mgb_cstr_log("replace warp + dimshuffle + relayout(NCHW_NCHW4)" | |||||
"to warp_dimshuffle(NHWC_NCHW4_IC_SMALL)")); | |||||
return true; | |||||
}; | |||||
auto try_warp_nchw2nchw4_typecvt = [&rewriter, &uniq_reader_check, | |||||
&is_warp_nchw, | |||||
&make_new_warp](OperatorNodeBase* opr) { | |||||
// check relayout | |||||
auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr); | |||||
if (relayout == nullptr) | |||||
return false; | |||||
bool is_to_q8 = | |||||
relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; | |||||
bool is_to_nchw2nchw4 = relayout->param().mode == | |||||
opr::RelayoutFormat::Param::Mode::NCHW_NCHW4; | |||||
if (!(is_to_q8 && is_to_nchw2nchw4)) | |||||
return false; | |||||
if (!uniq_reader_check(relayout->input(0))) | |||||
return false; | |||||
OperatorNodeBase* top_opr = nullptr; | |||||
if (!is_warp_nchw(relayout->input(0)->owner_opr(), top_opr)) | |||||
return false; | |||||
auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr); | |||||
bool is_small_chn = warp->input(0)->shape()[1] < 4; | |||||
if (!is_small_chn) | |||||
return false; | |||||
opr::WarpPerspective::Param new_param = warp->param(); | |||||
new_param.format = | |||||
megdnn::param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL; | |||||
SymbolVar new_warp; | |||||
make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp); | |||||
rewriter.replace_var( | |||||
opr->output(0), new_warp.node(), | |||||
mgb_cstr_log("replace warp + relayout(NCHW_NCHW4)" | |||||
"to warp_dimshuffle(NCHW_NCHW4_IC_SMALL)")); | |||||
return true; | |||||
}; | |||||
auto on_opr = [&try_warp_nchw_typecvt, &try_warp_nhwc2nchw_typecvt, | |||||
&try_warp_nhwc2nchw4_typecvt, &try_warp_nchw2nchw4_typecvt, | |||||
&rewriter](OperatorNodeBase* opr) { | |||||
if (!try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr) && | |||||
!try_warp_nhwc2nchw4_typecvt(opr) && | |||||
!try_warp_nchw2nchw4_typecvt(opr)) { | |||||
rewriter.auto_replace_outputs(opr); | |||||
} | |||||
}; | |||||
opt.graph().iter(on_opr); | |||||
rewriter.apply_inplace(); | |||||
} | } |
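In summary, each opr is tried against four fusions before falling back to auto_replace_outputs(); sketched as rewrite rules:

/* Rewrites attempted above, per opr (sketch):
 *   warp(NCHW) + TypeCvt(->f32)                  => warp(NCHW, f32 output)
 *   warp(NHWC) + Dimshuffle(0,3,1,2) + TypeCvt   => warp(NHWC_NCHW)
 *   warp(NHWC) + Dimshuffle + Relayout(NCHW_NCHW4),
 *       input c < 4                              => warp(NHWC_NCHW4_IC_SMALL)
 *   warp(NCHW) + Relayout(NCHW_NCHW4), c < 4     => warp(NCHW_NCHW4_IC_SMALL)
 */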
@@ -173,6 +173,16 @@ namespace gopt { | |||||
}; | }; | ||||
/*! | /*! | ||||
* \brief fuse a warp perspective opr with the dimshuffle / typecvt / | |||||
*        relayout-format opr that follows it, converting quint8/uint8 | |||||
*        input to qint8/float output | |||||
*/ | |||||
class FuseWarpPerspectiveDimshufflePass : public Pass { | |||||
public: | |||||
const char* name() const override; | |||||
void apply(OptState& opt) const override; | |||||
}; | |||||
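A usage note for the new declaration: the pass is registered by the fuse_preprocess and nchw4 option blocks earlier in this diff, but it can also be added by hand; a minimal sketch, with vars assumed to be the graph endpoints:

// Sketch: enabling the new pass explicitly.
gopt::GraphOptimizer optimizer;
optimizer.add_pass<gopt::FuseWarpPerspectiveDimshufflePass>();
optimizer.apply_inplace(vars);  // vars: VarNodeArray of endpoint nodes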
/*! | |||||
* \brief fuse deconv and typecvt to a deconv opr | * \brief fuse deconv and typecvt to a deconv opr | ||||
*/ | */ | ||||
class FuseDeconvCvtPass : public Pass { | class FuseDeconvCvtPass : public Pass { | ||||
@@ -1172,7 +1172,8 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { | |||||
param.pad_h = param.pad_w = 1; | param.pad_h = param.pad_w = 1; | ||||
auto w2 = mkcvar("w2", {4, 4, 3, 3}), | auto w2 = mkcvar("w2", {4, 4, 3, 3}), | ||||
y = opr::Convolution::make(elem, w2, param), | y = opr::Convolution::make(elem, w2, param), | ||||
z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)}); | |||||
z = opr::AxisAddRemove::make( | |||||
y, {opr::AxisAddRemove::AxisDesc::make_add(0)}); | |||||
SymbolVar y_opt, z_opt; | SymbolVar y_opt, z_opt; | ||||
auto options = gopt::OptimizeForInferenceOptions{}; | auto options = gopt::OptimizeForInferenceOptions{}; | ||||
@@ -3722,5 +3723,65 @@ TEST(TestGoptInference, PreProcessCase1) { | |||||
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>()); | ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>()); | ||||
} | } | ||||
TEST(TestGoptInference, WarpAndPreProcessCase) { | |||||
REQUIRE_GPU(1); | |||||
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255); | |||||
auto cn = CompNode::load("gpu0"); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
size_t n = 1; | |||||
size_t c = 3; | |||||
size_t h = 16; | |||||
size_t w = 16; | |||||
auto host_x1 = gen({n, h, w, c}, cn); | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||||
auto mat_host = std::make_shared<HostTensorND>(cn, TensorShape{n, 3, 3}, | |||||
dtype::Float32()); | |||||
warp_perspective_mat_gen(*mat_host, n, h, w); | |||||
auto mat = opr::Host2DeviceCopy::make(*graph, mat_host).rename("mat"); | |||||
opr::WarpPerspective::Param warp_param; | |||||
warp_param.format = opr::WarpPerspective::Param::Format::NHWC; | |||||
auto x_warp = | |||||
opr::WarpPerspective::make(x, mat, TensorShape{h, w}, warp_param); | |||||
auto x_nchw = opr::Dimshuffle::make(x_warp, {0, 3, 1, 2}, 4, cn); | |||||
auto x_u8 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn); | |||||
auto x_s8 = x_u8 - 128; | |||||
auto zero = DTypeScalar(dtype::Float32()); | |||||
auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn); | |||||
auto pad_channel_tensor = | |||||
opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn); | |||||
    auto padded_x = opr::Concat::make({x_s8, pad_channel_tensor}, 1, cn) | |||||
                            .reshape({n, 1, 4, h, w}); | |||||
    auto nchw4_out = opr::Dimshuffle::make(padded_x, {0, 1, 3, 4, 2}, 5, cn); | |||||
auto result = opr::TypeCvt::make(nchw4_out, dtype::QuantizedS8(1.f)); | |||||
auto y = result; | |||||
SymbolVar y_opt; | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_fuse_preprocess(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::WarpPerspective>()); | |||||
ASSERT_EQ(opr::WarpPerspective::Param::Format::NHWC_NCHW4_IC_SMALL, | |||||
find_opr<opr::WarpPerspective>(y_opt).param().format); | |||||
graph->compile({{y_opt, {}}}) | |||||
->to_json() | |||||
->writeto_fpath(output_file( | |||||
"TestGoptInference.WarpAndPreProcessCase.json")); | |||||
HostTensorND host_y_opt, host_y; | |||||
auto func = graph->compile({make_callback_copy(y, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); | |||||
} | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -47,7 +47,11 @@ SymbolVar WarpPerspectiveForward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2, | |||||
} | } | ||||
void WarpPerspectiveForward::init_output_dtype() { | void WarpPerspectiveForward::init_output_dtype() { | ||||
output(0)->dtype(input(0)->dtype()); | |||||
if (config().output_dtype().valid()) { | |||||
output(0)->dtype(config().output_dtype()); | |||||
} else { | |||||
output(0)->dtype(input(0)->dtype()); | |||||
} | |||||
} | } | ||||
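This change lets callers request the output dtype through OperatorNodeConfig, which is exactly what make_new_warp() in the fusion pass relies on. A sketch of the call it enables (src, mat, out_shape and param assumed to exist):

// Sketch: a Uint8 NCHW warp producing Float32 output directly, via the
// config-dtype path added above (mirrors make_new_warp() in the pass).
OperatorNodeConfig cfg{dtype::Float32()};
auto warped = opr::WarpPerspective::make(src, mat, out_shape, param, cfg);
// warped.node()->dtype() is now Float32 even though src is Uint8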
void WarpPerspectiveForward::add_input_layout_constraint() { | void WarpPerspectiveForward::add_input_layout_constraint() { | ||||
@@ -78,23 +82,40 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape( | |||||
mat_idx_shp.to_string().c_str()); | mat_idx_shp.to_string().c_str()); | ||||
} | } | ||||
//! The index of height, e.g.,[b, h, w, c], the height_idx = 1 | |||||
size_t height_idx = 0; | |||||
if (param().format == Param::Format::NCHW || | |||||
param().format == Param::Format::NCHW4) { | |||||
height_idx = 2; | |||||
} else { | |||||
height_idx = 1; | |||||
} | |||||
dest = imgshp; | |||||
dest[0] = matshp[0]; | |||||
if (param().format == Param::Format::NHWCD4) { | |||||
dest.shape[height_idx] = oshp2d.shape[0]; | |||||
dest.shape[height_idx + 2] = oshp2d.shape[1]; | |||||
} else { | |||||
for (int i = 0; i < 2; ++i) | |||||
dest.shape[height_idx + i] = oshp2d.shape[i]; | |||||
switch (param().format) { | |||||
case Param::Format::NCHW_NCHW4_IC_SMALL: | |||||
case Param::Format::NHWC_NCHW4_IC_SMALL: | |||||
dest.ndim = 5; | |||||
dest[0] = matshp[0]; | |||||
dest.shape[1] = 1; | |||||
dest.shape[2] = oshp2d.shape[0]; | |||||
dest.shape[3] = oshp2d.shape[1]; | |||||
dest.shape[4] = 4; | |||||
break; | |||||
case Param::Format::NHWC_NCHW: | |||||
dest[0] = matshp[0]; | |||||
dest.shape[1] = imgshp.shape[3]; | |||||
dest.shape[2] = oshp2d.shape[0]; | |||||
dest.shape[3] = oshp2d.shape[1]; | |||||
break; | |||||
default: | |||||
size_t height_idx = 0; | |||||
if (param().format == Param::Format::NCHW || | |||||
param().format == Param::Format::NCHW4) { | |||||
height_idx = 2; | |||||
} else { | |||||
height_idx = 1; | |||||
} | |||||
dest = imgshp; | |||||
dest[0] = matshp[0]; | |||||
if (param().format == Param::Format::NHWCD4) { | |||||
dest.shape[height_idx] = oshp2d.shape[0]; | |||||
dest.shape[height_idx + 2] = oshp2d.shape[1]; | |||||
} else { | |||||
for (int i = 0; i < 2; ++i) | |||||
dest.shape[height_idx + i] = oshp2d.shape[i]; | |||||
} | |||||
break; | |||||
} | } | ||||
} | } | ||||
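Worked shapes for the new branches above, with B = matshp[0] and (out_h, out_w) = oshp2d; a hypothetical sketch:

#include <cstddef>
#include <vector>

// Dest shapes produced by the new cases above (hypothetical helpers):
std::vector<std::size_t> nhwc_nchw_dest(std::size_t B, std::size_t c,
                                        std::size_t out_h, std::size_t out_w) {
    return {B, c, out_h, out_w};     // NHWC_NCHW: channel from imgshp[3]
}
std::vector<std::size_t> ic_small_dest(std::size_t B, std::size_t out_h,
                                       std::size_t out_w) {
    return {B, 1, out_h, out_w, 4};  // *_NCHW4_IC_SMALL: c padded to 4
}
// e.g. WarpAndPreProcessCase: img {1, 16, 16, 3} -> dest {1, 1, 16, 16, 4}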