Browse Source

feat(opr): add confidential operator

GitOrigin-RevId: 53c2d4bc45
release-1.1
Megvii Engine Team 4 years ago
parent
commit
e0da74852e
8 changed files with 15 additions and 0 deletions
  1. +6
    -0
      imperative/python/megengine/functional/nn.py
  2. +2
    -0
      imperative/python/test/unit/functional/test_functional.py
  3. +1
    -0
      src/opr/impl/misc.cpp
  4. +1
    -0
      src/opr/impl/misc.oprdecl
  5. +1
    -0
      src/opr/impl/misc.sereg.h
  6. +1
    -0
      src/opr/include/megbrain/opr/misc.h
  7. +2
    -0
      src/serialization/impl/schema.fbs
  8. +1
    -0
      tools/param_defs/mgb_opr_param_defs.py

+ 6
- 0
imperative/python/megengine/functional/nn.py View File

@@ -1579,3 +1579,9 @@ def batched_nms(
indices = indices[0][: count.item()]
keep_inds = sorted_idx[indices]
return keep_inds




from .loss import * # isort:skip
from .quantized import conv_bias_activation # isort:skip

+ 2
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -551,3 +551,5 @@ def test_nms_is_same():
assert op1 != op3
assert op1 != op4
assert op3 != op4



+ 1
- 0
src/opr/impl/misc.cpp View File

@@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}


/* ================= CondTake ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);



+ 1
- 0
src/opr/impl/misc.oprdecl View File

@@ -63,4 +63,5 @@ decl_opr('TopK',
inputs=['data', 'k'], params='TopK',
desc='Select the top k values from sorted result.')


# vim: ft=python

+ 1
- 0
src/opr/impl/misc.sereg.h View File

@@ -70,6 +70,7 @@ namespace opr {
using CumsumV1 = opr::Cumsum;
MGB_SEREG_OPR(CumsumV1, 1);


} // namespace opr
} // namespace mgb



+ 1
- 0
src/opr/include/megbrain/opr/misc.h View File

@@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT<
void init_output_static_infer_desc() override;
};


namespace intl {
using CondTakeBase =
cg::SingleCNOperatorNode<cg::OperatorNodeBase,


+ 2
- 0
src/serialization/impl/schema.fbs View File

@@ -28,6 +28,7 @@ table Blob {
}

table Reserved0 {}
table Reserved1 {}

union OperatorParam {
param.Empty = 1,
@@ -100,6 +101,7 @@ union OperatorParam {
param.Remap = 68,
param.NMSKeep = 69,
param.AdaptivePooling = 70,
Reserved1 = 71,
}

table Operator {


+ 1
- 0
tools/param_defs/mgb_opr_param_defs.py View File

@@ -143,3 +143,4 @@ pdef('PersistentOutputStorage').add_fields(
' no branch is taken')
)
)


Loading…
Cancel
Save