GitOrigin-RevId: aeb2770401
release-1.10
@@ -7,8 +7,10 @@ namespace megdnn {
 void ROIAlignBase::deduce_layout_fwd(
         const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
         TensorLayout& index) {
-    megdnn_assert_contiguous(src);
-    megdnn_assert_contiguous(rois);
+    if (!src.is_empty())
+        megdnn_assert_contiguous(src);
+    if (!rois.is_empty())
+        megdnn_assert_contiguous(rois);
     megdnn_assert_contiguous(dst);
     megdnn_assert_contiguous(index);
     auto errmsg = [&]() {
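Note: a layout counts as empty when any axis has extent 0, so the relaxed asserts above let layout deduction proceed for zero-sized inputs. A minimal Python sketch of that emptiness rule (hypothetical mirror of TensorLayout::is_empty, not MegEngine API):

    def is_empty(shape):
        # empty as soon as any axis has zero extent
        return any(dim == 0 for dim in shape)

    assert is_empty((2, 0, 26, 26))       # zero channels: contiguity not asserted
    assert not is_empty((2, 3, 26, 26))   # all axes non-zero: assert still runs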
@@ -16,14 +16,14 @@ from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
 __all__ = [
     "correlation",
     "cvt_color",
-    "roi_pooling",
-    "roi_align",
+    "interpolate",
     "nms",
+    "nvof",
     "remap",
+    "roi_align",
+    "roi_pooling",
     "warp_affine",
     "warp_perspective",
-    "interpolate",
-    "nvof",
 ]
@@ -95,9 +95,9 @@ def roi_pooling(
     Args:
         inp: tensor that represents the input feature, `(N, C, H, W)` images.
-        rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
-        output_shape: height, width)` of output rois feature.
-        mode: max" or "average", use max/average align just like max/average pooling. Default: "max"
+        rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
+        output_shape: `(height, width)` of output rois feature.
+        mode: "max" or "average", use max/average align just like max/average pooling. Default: "max"
         scale: scale the input boxes by this number. Default: 1.0

     Returns:
@@ -176,9 +176,9 @@ def roi_align(
     Args:
         inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
-        rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
-        output_shape: height, width)` shape of output rois feature.
-        mode: max" or "average", use max/average align just like max/average pooling. Default: "average"
+        rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
+        output_shape: `(height, width)` shape of output rois feature.
+        mode: "max" or "average", use max/average align just like max/average pooling. Default: "average"
         spatial_scale: scale the input boxes by this number. Default: 1.0
         sample_points: number of inputs samples to take for each output sample.
             0 to take samples densely. Default: 2
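For reference, a call matching the shapes documented above (a sketch with arbitrary values; same call signature as the new test further down):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.random.randn(2, 3, 26, 26))  # (N, C, H, W) feature map
    rois = tensor(np.random.random((4, 5)))      # (N, 5): box index + xyxy
    out = F.vision.roi_align(
        inp, rois, output_shape=(7, 7), mode="average",
        spatial_scale=1.0 / 4, sample_points=2, aligned=True,
    )
    print(out.shape)  # (4, 3, 7, 7): one (C, height, width) feature per box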
@@ -345,7 +345,7 @@ def warp_affine(
     Args:
         inp: input image.
-        mat: batch, 2, 3)` transformation matrix.
+        mat: `(batch, 2, 3)` transformation matrix.
         out_shape: output tensor shape.
         border_mode: pixel extrapolation method.
             Default: "wrap". Currently "constant", "reflect",
@@ -289,6 +289,37 @@ def test_roi_align():
     assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


+@pytest.mark.parametrize("shapes", [((2, 0, 26, 26), (4, 5)), ((2, 3, 26, 26), (0, 5))])
+@pytest.mark.parametrize("is_tracing", [False, True])
+def test_roi_align_empty(shapes, is_tracing):
+    inp_feat = tensor(np.random.randn(*(shapes[0])))
+    rois = tensor(np.random.random(shapes[1]))
+    output_shape = (7, 7)
+
+    def func(inp, rois):
+        out_feat = F.vision.roi_align(
+            inp,
+            rois,
+            output_shape=output_shape,
+            mode="average",
+            spatial_scale=1.0 / 4,
+            sample_points=2,
+            aligned=True,
+        )
+        return out_feat
+
+    if is_tracing:
+        func = jit.trace(func)
+
+    for _ in range(3):
+        out_feat = func(inp_feat, rois)
+        assert make_shape_tuple(out_feat.shape) == (
+            rois.shape[0],
+            inp_feat.shape[1],
+            *output_shape,
+        )
+
+
 def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
     if random:
         inp_feat1 = np.random.randn(
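The parametrization above covers both empty cases: a feature map with a zero-sized channel axis and an empty box set. Run directly, the empty-rois case behaves like this (a sketch; values arbitrary):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.random.randn(2, 3, 26, 26))
    rois = tensor(np.random.random((0, 5)))  # no boxes at all
    out = F.vision.roi_align(
        inp, rois, output_shape=(7, 7), mode="average",
        spatial_scale=1.0 / 4, sample_points=2, aligned=True,
    )
    print(out.shape)  # (0, 3, 7, 7): empty along the box axis; the kernel is skipped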
@@ -442,21 +442,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fall
 } // namespace

 namespace {
-namespace roi_align {
-VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const ROIAlign&>(def);
-    mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{op.make_name()};
-    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
-                        .node()
-                        ->owner_opr();
-    return {opr->output(0), opr->output(1)};
-}
-OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback();
-} // namespace roi_align
-} // namespace
-
-namespace {
 namespace correlation {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = static_cast<const Correlation&>(def);
@@ -523,22 +508,6 @@ OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace

 namespace {
-namespace roi_pooling {
-VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const ROIPooling&>(def);
-    mgb_assert(inputs.size() == 3);
-    OperatorNodeConfig config{op.make_name()};
-    auto* opr =
-            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
-                    .node()
-                    ->owner_opr();
-    return {opr->output(0), opr->output(1)};
-}
-OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
-} // namespace roi_pooling
-} // namespace
-
-namespace {
 namespace remap {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = static_cast<const Remap&>(def);
@@ -1,8 +1,11 @@
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/dnn/roi_align.h"
+#include "megbrain/opr/dnn/roi_pooling.h"
 #include "megbrain/opr/imgproc.h"

+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
 #include "../op_trait.h"

 namespace mgb {
 namespace imperative {
@@ -15,5 +18,119 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 }
 OP_TRAIT_REG(CvtColor, CvtColor).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace
+
+namespace {
+namespace roi_align {
+VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const ROIAlign&>(def);
+    mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
+                        .node()
+                        ->owner_opr();
+    return {opr->output(0), opr->output(1)};
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = static_cast<const ROIAlign&>(def);
+    if (inputs[0].layout.is_empty() || inputs[1].layout.is_empty()) {
+        return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node},
+                 {TensorLayout(dtype::Int32()), inputs[1].comp_node}},
+                false};
+    }
+    SmallVector<LogicalTensorDesc> descs(2u);
+    size_t n = inputs[1].layout[0];
+    size_t c = inputs[0].layout[1];
+    descs[0].layout = TensorLayout(
+            {n, c, op.pooled_height, op.pooled_width}, inputs[0].layout.dtype);
+    descs[0].layout.init_contiguous_stride();
+    descs[0].comp_node = inputs[0].comp_node;
+    descs[1].layout =
+            TensorLayout({n, c, op.pooled_height, op.pooled_width}, dtype::Int32());
+    descs[1].layout.init_contiguous_stride();
+    descs[1].comp_node = descs[0].comp_node;
+    return {descs, true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = static_cast<const ROIAlign&>(def);
+    CompNode cn = inputs[0]->comp_node();
+    TensorLayout out_layout = output_descs[0].layout;
+    TensorLayout ind_layout = output_descs[1].layout;
+    if (!validated) {
+        size_t n = inputs[1]->layout()[0];
+        size_t c = inputs[0]->layout()[1];
+        out_layout = TensorLayout(
+                {n, c, op.pooled_height, op.pooled_width}, inputs[0]->layout().dtype);
+        out_layout.init_contiguous_stride();
+        ind_layout =
+                TensorLayout({n, c, op.pooled_height, op.pooled_width}, dtype::Int32());
+        ind_layout.init_contiguous_stride();
+    }
+
+    DeviceTensorND out =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);
+    DeviceTensorND inds =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, ind_layout);
+    if (out_layout.is_empty() || ind_layout.is_empty()) {
+        return {Tensor::make(out), Tensor::make(inds)};
+    }
+
+    DnnOprCaller<megdnn::ROIAlign> dnn_opr(cn);
+    dnn_opr.op->param() = op.param();
+    size_t sz = dnn_opr.op->get_workspace_in_bytes(
+            inputs[0]->layout(), inputs[1]->layout(), out_layout, ind_layout);
+    TensorLayout w_layout({sz}, dtype::Byte());
+    auto dnn_wk = dnn_opr.create_workspace(w_layout);
+    dnn_opr.op->exec(
+            inputs[0]->dnn_tensor(), inputs[1]->dnn_tensor(), out.as_megdnn(),
+            inds.as_megdnn(), dnn_wk);
+    return {Tensor::make(out), Tensor::make(inds)};
+}
+
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) {
+        return layout.is_contiguous();
+    };
+    return layout_checker;
+}
+
+OP_TRAIT_REG(ROIAlign, ROIAlign)
+        .apply_on_var_node(apply_on_var_node)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+} // namespace roi_align
+} // namespace
+
+namespace {
+namespace roi_pooling {
+VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const ROIPooling&>(def);
+    mgb_assert(inputs.size() == 3);
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr =
+            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
+                    .node()
+                    ->owner_opr();
+    return {opr->output(0), opr->output(1)};
+}
+OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace roi_pooling
+} // namespace
+
 } // namespace imperative
 } // namespace mgb
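For reference, the fallible inference registered above reduces to simple shape arithmetic once both inputs are non-empty: `n` comes from `rois` (one output slot per box), `c` from the feature map, and both outputs (values and argmax indices) share the same shape. A minimal sketch of that rule (hypothetical helper, not part of the codebase):

    def roi_align_out_shape(src_shape, rois_shape, pooled_height, pooled_width):
        # mirrors infer_output_attrs_fallible for the non-empty path
        n, c = rois_shape[0], src_shape[1]
        return (n, c, pooled_height, pooled_width)

    assert roi_align_out_shape((2, 3, 26, 26), (4, 5), 7, 7) == (4, 3, 7, 7)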
@@ -20,6 +20,8 @@ ROIAlignForward::ROIAlignForward(
     add_input({src, rois});
     output(0)->dtype(dtype::Float32());
     output(1)->dtype(dtype::Int32());
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVar ROIAlignForward::make(
@@ -29,6 +31,35 @@ SymbolVar ROIAlignForward::make(
             src.node(), rois.node(), param, config);
 }

+ROIAlignForward::NodeProp* ROIAlignForward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void ROIAlignForward::scn_do_execute() {
+    auto src = input(0)->dev_tensor().as_megdnn(),
+         rois = input(1)->dev_tensor().as_megdnn(),
+         dst = output(0)->dev_tensor().as_megdnn(),
+         index = output(1)->dev_tensor().as_megdnn();
+    if (src.layout.is_empty() || rois.layout.is_empty()) {
+        return;
+    }
+    megdnn_opr()->exec(
+            src, rois, dst, index, intl::get_megdnn_workspace_from_var(output(2)));
+}
+
+size_t ROIAlignForward::get_workspace_size_bytes(
+        const TensorShapeArray& inp_shapes, const TensorShapeArray& out_shapes) const {
+    TensorLayout inp{inp_shapes[0], input(0)->dtype(), input(0)->format()},
+            rois{inp_shapes[1], input(1)->dtype(), input(1)->format()},
+            out{out_shapes[0], output(0)->dtype(), output(0)->format()},
+            index{out_shapes[1], output(1)->dtype(), output(1)->format()};
+    return megdnn_opr()->get_workspace_in_bytes(inp, rois, out, index);
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIAlignForward) {
     if (wrt_idx == 0) {
@@ -16,6 +16,13 @@ public:
     MGE_WIN_DECLSPEC_FUC static SymbolVar make(
             SymbolVar src, SymbolVar rois, const Param& param = {},
             const OperatorNodeConfig& config = {});
+
+private:
+    void scn_do_execute() override;
+    NodeProp* do_make_node_prop() const override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
 };

 using ROIAlign = ROIAlignForward;