@@ -1579,3 +1579,9 @@ def batched_nms( | |||||
indices = indices[0][: count.item()] | indices = indices[0][: count.item()] | ||||
keep_inds = sorted_idx[indices] | keep_inds = sorted_idx[indices] | ||||
return keep_inds | return keep_inds | ||||
from .loss import * # isort:skip | |||||
from .quantized import conv_bias_activation # isort:skip |
@@ -551,3 +551,5 @@ def test_nms_is_same(): | |||||
assert op1 != op3 | assert op1 != op3 | ||||
assert op1 != op4 | assert op1 != op4 | ||||
assert op3 != op4 | assert op3 != op4 | ||||
@@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() { | |||||
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); | {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); | ||||
} | } | ||||
/* ================= CondTake ================= */ | /* ================= CondTake ================= */ | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); | ||||
@@ -63,4 +63,5 @@ decl_opr('TopK', | |||||
inputs=['data', 'k'], params='TopK', | inputs=['data', 'k'], params='TopK', | ||||
desc='Select the top k values from sorted result.') | desc='Select the top k values from sorted result.') | ||||
# vim: ft=python | # vim: ft=python |
@@ -70,6 +70,7 @@ namespace opr { | |||||
using CumsumV1 = opr::Cumsum; | using CumsumV1 = opr::Cumsum; | ||||
MGB_SEREG_OPR(CumsumV1, 1); | MGB_SEREG_OPR(CumsumV1, 1); | ||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT< | |||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
}; | }; | ||||
namespace intl { | namespace intl { | ||||
using CondTakeBase = | using CondTakeBase = | ||||
cg::SingleCNOperatorNode<cg::OperatorNodeBase, | cg::SingleCNOperatorNode<cg::OperatorNodeBase, | ||||
@@ -28,6 +28,7 @@ table Blob { | |||||
} | } | ||||
table Reserved0 {} | table Reserved0 {} | ||||
table Reserved1 {} | |||||
union OperatorParam { | union OperatorParam { | ||||
param.Empty = 1, | param.Empty = 1, | ||||
@@ -100,6 +101,7 @@ union OperatorParam { | |||||
param.Remap = 68, | param.Remap = 68, | ||||
param.NMSKeep = 69, | param.NMSKeep = 69, | ||||
param.AdaptivePooling = 70, | param.AdaptivePooling = 70, | ||||
Reserved1 = 71, | |||||
} | } | ||||
table Operator { | table Operator { | ||||
@@ -143,3 +143,4 @@ pdef('PersistentOutputStorage').add_fields( | |||||
' no branch is taken') | ' no branch is taken') | ||||
) | ) | ||||
) | ) | ||||