GitOrigin-RevId: aeb2770401
release-1.10
@@ -7,8 +7,10 @@ namespace megdnn { | |||
void ROIAlignBase::deduce_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, | |||
TensorLayout& index) { | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert_contiguous(rois); | |||
if (!src.is_empty()) | |||
megdnn_assert_contiguous(src); | |||
if (!rois.is_empty()) | |||
megdnn_assert_contiguous(rois); | |||
megdnn_assert_contiguous(dst); | |||
megdnn_assert_contiguous(index); | |||
auto errmsg = [&]() { | |||
@@ -16,14 +16,14 @@ from .tensor import broadcast_to, concat, expand_dims, reshape, transpose | |||
__all__ = [ | |||
"correlation", | |||
"cvt_color", | |||
"roi_pooling", | |||
"roi_align", | |||
"interpolate", | |||
"nms", | |||
"nvof", | |||
"remap", | |||
"roi_align", | |||
"roi_pooling", | |||
"warp_affine", | |||
"warp_perspective", | |||
"interpolate", | |||
"nvof", | |||
] | |||
@@ -95,9 +95,9 @@ def roi_pooling( | |||
Args: | |||
inp: tensor that represents the input feature, `(N, C, H, W)` images. | |||
rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. | |||
output_shape: height, width)` of output rois feature. | |||
mode: max" or "average", use max/average align just like max/average pooling. Default: "max" | |||
rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. | |||
output_shape: `(height, width)` of output rois feature. | |||
mode: "max" or "average", use max/average align just like max/average pooling. Default: "max" | |||
scale: scale the input boxes by this number. Default: 1.0 | |||
Returns: | |||
@@ -176,9 +176,9 @@ def roi_align( | |||
Args: | |||
inp: tensor that represents the input feature, shape is `(N, C, H, W)`. | |||
rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``. | |||
output_shape: height, width)` shape of output rois feature. | |||
mode: max" or "average", use max/average align just like max/average pooling. Default: "average" | |||
rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``. | |||
output_shape: `(height, width)` shape of output rois feature. | |||
mode: "max" or "average", use max/average align just like max/average pooling. Default: "average" | |||
spatial_scale: scale the input boxes by this number. Default: 1.0 | |||
sample_points: number of inputs samples to take for each output sample. | |||
0 to take samples densely. Default: 2 | |||
@@ -345,7 +345,7 @@ def warp_affine( | |||
Args: | |||
inp: input image. | |||
mat: batch, 2, 3)` transformation matrix. | |||
mat: `(batch, 2, 3)` transformation matrix. | |||
out_shape: output tensor shape. | |||
border_mode: pixel extrapolation method. | |||
Default: "wrap". Currently "constant", "reflect", | |||
@@ -289,6 +289,37 @@ def test_roi_align(): | |||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
@pytest.mark.parametrize("shapes", [((2, 0, 26, 26), (4, 5)), ((2, 3, 26, 26), (0, 5))])
@pytest.mark.parametrize("is_tracing", [False, True])
def test_roi_align_empty(shapes, is_tracing):
    """Check the output shape of ``roi_align`` when one of its inputs is empty.

    ``shapes`` is ``(feature_shape, rois_shape)``; each parametrized case has a
    zero-sized dimension in exactly one of the two. The function is run both
    eagerly and under ``jit.trace``.
    """
    inp_feat = tensor(np.random.randn(*(shapes[0])))
    rois = tensor(np.random.random(shapes[1]))
    output_shape = (7, 7)

    def func(inp, rois):
        # Use the `inp` argument (not the enclosing `inp_feat`) so the feature
        # tensor actually flows through the function's inputs.
        out_feat = F.vision.roi_align(
            inp,
            rois,
            output_shape=output_shape,
            mode="average",
            spatial_scale=1.0 / 4,
            sample_points=2,
            aligned=True,
        )
        return out_feat

    if is_tracing:
        func = jit.trace(func)

    # Call several times so that, when tracing, later calls reuse the trace.
    for _ in range(3):
        out_feat = func(inp_feat, rois)
        assert make_shape_tuple(out_feat.shape) == (
            rois.shape[0],
            inp_feat.shape[1],
            *output_shape,
        )
def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): | |||
if random: | |||
inp_feat1 = np.random.randn( | |||
@@ -442,21 +442,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fall | |||
} // namespace | |||
namespace { | |||
namespace roi_align { | |||
/// Creates a graph-level ROIAlign operator from the two input vars and
/// returns outputs 0 and 1 of the operator that owns the resulting var.
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& roi_align_def = static_cast<const ROIAlign&>(def);
    // The operator is named after the op definition.
    OperatorNodeConfig config{roi_align_def.make_name()};
    auto made = opr::ROIAlign::make(
            inputs[0], inputs[1], roi_align_def.param(), config);
    auto* owner = made.node()->owner_opr();
    return {owner->output(0), owner->output(1)};
}
OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace roi_align | |||
} // namespace | |||
namespace { | |||
namespace correlation { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Correlation&>(def); | |||
@@ -523,22 +508,6 @@ OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace | |||
namespace { | |||
namespace roi_pooling { | |||
/// Creates a graph-level ROIPooling operator from the three input vars and
/// returns outputs 0 and 1 of the operator that owns the resulting var.
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 3);
    const auto& roi_pooling_def = static_cast<const ROIPooling&>(def);
    // The operator is named after the op definition.
    OperatorNodeConfig config{roi_pooling_def.make_name()};
    auto made = opr::ROIPooling::make(
            inputs[0], inputs[1], inputs[2], roi_pooling_def.param(), config);
    auto* owner = made.node()->owner_opr();
    return {owner->output(0), owner->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace roi_pooling | |||
} // namespace | |||
namespace { | |||
namespace remap { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Remap&>(def); | |||
@@ -1,8 +1,11 @@ | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/imgproc.h" | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -15,5 +18,119 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
} | |||
OP_TRAIT_REG(CvtColor, CvtColor).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace | |||
namespace { | |||
namespace roi_align { | |||
/// Creates a graph-level ROIAlign operator from the two input vars and
/// returns outputs 0 and 1 of the operator that owns the resulting var.
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& roi_align_def = static_cast<const ROIAlign&>(def);
    // The operator is named after the op definition.
    OperatorNodeConfig config{roi_align_def.make_name()};
    auto made = opr::ROIAlign::make(
            inputs[0], inputs[1], roi_align_def.param(), config);
    auto* owner = made.node()->owner_opr();
    return {owner->output(0), owner->output(1)};
}
/// Infers the descriptors of the two ROIAlign outputs from the input descriptors.
///
/// Returns `(descs, validated)`:
///  - if either input layout is empty, the descs carry only a dtype and a comp
///    node (input 0's dtype/comp node for output 0, Int32 and input 1's comp
///    node for output 1) and `validated` is false;
///  - otherwise both outputs get the contiguous shape
///    `(inputs[1].layout[0], inputs[0].layout[1], pooled_height, pooled_width)`
///    on input 0's comp node, output 0 with input 0's dtype and output 1 with
///    Int32, and `validated` is true.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    const auto& feat = inputs[0];
    const auto& boxes = inputs[1];

    // NOTE(review): is_empty() is the only guard before indexing the layouts
    // below; confirm it also covers inputs whose shape is not known yet.
    if (feat.layout.is_empty() || boxes.layout.is_empty()) {
        return {{{TensorLayout(feat.layout.dtype), feat.comp_node},
                 {TensorLayout(dtype::Int32()), boxes.comp_node}},
                false};
    }

    auto&& op = static_cast<const ROIAlign&>(def);
    const size_t n = boxes.layout[0];
    const size_t c = feat.layout[1];

    TensorLayout value_layout(
            {n, c, op.pooled_height, op.pooled_width}, feat.layout.dtype);
    value_layout.init_contiguous_stride();
    TensorLayout index_layout(
            {n, c, op.pooled_height, op.pooled_width}, dtype::Int32());
    index_layout.init_contiguous_stride();

    SmallVector<LogicalTensorDesc> descs(2u);
    descs[0].layout = value_layout;
    descs[0].comp_node = feat.comp_node;
    descs[1].layout = index_layout;
    descs[1].comp_node = descs[0].comp_node;
    return {descs, true};
}
/// Runs ROIAlign eagerly on physical tensors and returns `{output, index}`.
///
/// When `validated` is false the layouts in `output_descs` are not used;
/// both output layouts are recomputed from the actual input layouts as
/// `(inputs[1] dim 0, inputs[0] dim 1, pooled_height, pooled_width)`.
/// If either output layout is empty the megdnn kernel is not invoked and the
/// freshly allocated tensors are returned as they are.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = static_cast<const ROIAlign&>(def);
    CompNode cn = inputs[0]->comp_node();

    TensorLayout out_layout = output_descs[0].layout;
    TensorLayout ind_layout = output_descs[1].layout;
    if (!validated) {
        const size_t n = inputs[1]->layout()[0];
        const size_t c = inputs[0]->layout()[1];
        out_layout = TensorLayout(
                {n, c, op.pooled_height, op.pooled_width},
                inputs[0]->layout().dtype);
        out_layout.init_contiguous_stride();
        ind_layout = TensorLayout(
                {n, c, op.pooled_height, op.pooled_width}, dtype::Int32());
        ind_layout.init_contiguous_stride();
    }

    // Both outputs are allocated up front, including for the empty case.
    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);
    DeviceTensorND inds =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, ind_layout);

    const bool nothing_to_compute = out_layout.is_empty() || ind_layout.is_empty();
    if (nothing_to_compute) {
        return {Tensor::make(out), Tensor::make(inds)};
    }

    DnnOprCaller<megdnn::ROIAlign> caller(cn);
    caller.op->param() = op.param();

    // Ask the kernel how much scratch space it needs and allocate it as bytes.
    const size_t workspace_bytes = caller.op->get_workspace_in_bytes(
            inputs[0]->layout(), inputs[1]->layout(), out_layout, ind_layout);
    TensorLayout workspace_layout({workspace_bytes}, dtype::Byte());
    auto workspace = caller.create_workspace(workspace_layout);

    caller.op->exec(
            inputs[0]->dnn_tensor(), inputs[1]->dnn_tensor(), out.as_megdnn(),
            inds.as_megdnn(), workspace);
    return {Tensor::make(out), Tensor::make(inds)};
}
/// Returns one layout-constraint callback per input; the callbacks for
/// inputs 0 and 1 accept a layout only if it is contiguous.
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> checkers(inputs.size());
    const auto require_contiguous = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    checkers[1] = require_contiguous;
    checkers[0] = checkers[1];
    return checkers;
}
// Registers the ROIAlign op trait with the four hooks defined in this
// namespace: graph lowering (apply_on_var_node), eager execution
// (apply_on_physical_tensor), output inference (infer_output_attrs_fallible)
// and input layout constraints (get_input_layout_constraint).
// NOTE(review): .fallback() presumably fills in defaults for any hooks not
// listed here -- confirm against the OpTrait implementation.
OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
.apply_on_var_node(apply_on_var_node) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
} // namespace roi_align | |||
} // namespace | |||
namespace { | |||
namespace roi_pooling { | |||
/// Creates a graph-level ROIPooling operator from the three input vars and
/// returns outputs 0 and 1 of the operator that owns the resulting var.
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 3);
    const auto& roi_pooling_def = static_cast<const ROIPooling&>(def);
    // The operator is named after the op definition.
    OperatorNodeConfig config{roi_pooling_def.make_name()};
    auto made = opr::ROIPooling::make(
            inputs[0], inputs[1], inputs[2], roi_pooling_def.param(), config);
    auto* owner = made.node()->owner_opr();
    return {owner->output(0), owner->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace roi_pooling | |||
} // namespace | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -20,6 +20,8 @@ ROIAlignForward::ROIAlignForward( | |||
add_input({src, rois}); | |||
output(0)->dtype(dtype::Float32()); | |||
output(1)->dtype(dtype::Int32()); | |||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||
} | |||
SymbolVar ROIAlignForward::make( | |||
@@ -29,6 +31,35 @@ SymbolVar ROIAlignForward::make( | |||
src.node(), rois.node(), param, config); | |||
} | |||
/// Extends the base node properties so that both inputs (src and rois) are
/// value dependencies that are allowed to be empty.
ROIAlignForward::NodeProp* ROIAlignForward::do_make_node_prop() const {
    auto* prop = Super::do_make_node_prop();
    for (size_t idx = 0; idx < 2; ++idx) {
        prop->add_dep_type_existing_var(
                input(idx), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
void ROIAlignForward::scn_do_execute() { | |||
auto src = input(0)->dev_tensor().as_megdnn(), | |||
rois = input(1)->dev_tensor().as_megdnn(), | |||
dst = output(0)->dev_tensor().as_megdnn(), | |||
index = output(1)->dev_tensor().as_megdnn(); | |||
if ((src.layout.is_empty() || rois.layout.is_empty())) { | |||
return; | |||
} | |||
megdnn_opr()->exec( | |||
src, rois, dst, index, intl::get_megdnn_workspace_from_var(output(2))); | |||
} | |||
/// Returns the workspace size, in bytes, that the megdnn ROIAlign kernel
/// reports for the given input/output shapes.
///
/// @param inp_shapes shapes of the inputs: [0] = src, [1] = rois
/// @param out_shapes shapes of the outputs: [0] = dst, [1] = index
size_t ROIAlignForward::get_workspace_size_bytes(
        const TensorShapeArray& inp_shapes, const TensorShapeArray& out_shapes) const {
    // Rebuild full layouts from the shapes plus each var's dtype and format.
    TensorLayout inp{inp_shapes[0], input(0)->dtype(), input(0)->format()},
            rois{inp_shapes[1], input(1)->dtype(), input(1)->format()},
            out{out_shapes[0], output(0)->dtype(), output(0)->format()},
            index{out_shapes[1], output(1)->dtype(), output(1)->format()};
    // Argument order is (src, rois, dst, index), the same order exec() takes;
    // dst and index were previously passed swapped.
    return megdnn_opr()->get_workspace_in_bytes(inp, rois, out, index);
}
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(ROIAlignForward) { | |||
if (wrt_idx == 0) { | |||
@@ -16,6 +16,13 @@ public: | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
SymbolVar src, SymbolVar rois, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
private: | |||
void scn_do_execute() override; | |||
NodeProp* do_make_node_prop() const override; | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const override; | |||
}; | |||
using ROIAlign = ROIAlignForward; | |||