diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 06a78e73..905a7ac5 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1579,3 +1579,9 @@ def batched_nms( indices = indices[0][: count.item()] keep_inds = sorted_idx[indices] return keep_inds + + + + +from .loss import * # isort:skip +from .quantized import conv_bias_activation # isort:skip diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index f3187ec4..3dfce3fa 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -551,3 +551,5 @@ def test_nms_is_same(): assert op1 != op3 assert op1 != op4 assert op3 != op4 + + diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 3f738d4e..d4c1bb0b 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() { {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); } + /* ================= CondTake ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); diff --git a/src/opr/impl/misc.oprdecl b/src/opr/impl/misc.oprdecl index b2444c65..d76f473d 100644 --- a/src/opr/impl/misc.oprdecl +++ b/src/opr/impl/misc.oprdecl @@ -63,4 +63,5 @@ decl_opr('TopK', inputs=['data', 'k'], params='TopK', desc='Select the top k values from sorted result.') + # vim: ft=python diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h index 7c5e7ea6..b8562ee5 100644 --- a/src/opr/impl/misc.sereg.h +++ b/src/opr/impl/misc.sereg.h @@ -70,6 +70,7 @@ namespace opr { using CumsumV1 = opr::Cumsum; MGB_SEREG_OPR(CumsumV1, 1); + } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index e6285a41..314adefc 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT< void init_output_static_infer_desc() override; }; + namespace intl { using CondTakeBase = cg::SingleCNOperatorNode