GitOrigin-RevId: 56918db014
tags/v1.0.0-rc1
@@ -11,6 +11,7 @@ import itertools | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
import megengine.core.ops.builtin as builtin | |||||
import megengine.core.tensor.dtype as dtype | import megengine.core.tensor.dtype as dtype | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import Buffer, Parameter, is_cuda_available, tensor | from megengine import Buffer, Parameter, is_cuda_available, tensor | ||||
@@ -631,3 +632,20 @@ def test_condtake(): | |||||
val, idx = F.cond_take(yy, xx) | val, idx = F.cond_take(yy, xx) | ||||
np.testing.assert_equal(val.numpy(), x[y]) | np.testing.assert_equal(val.numpy(), x[y]) | ||||
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | ||||
def test_condtake_is_same(): | |||||
op1 = builtin.CondTake() | |||||
op2 = builtin.CondTake() | |||||
assert op1 == op2 | |||||
def test_nms_is_same(): | |||||
op1 = builtin.NMSKeep(0.7, 100) | |||||
op2 = builtin.NMSKeep(0.7, 100) | |||||
op3 = builtin.NMSKeep(0.8, 100) | |||||
op4 = builtin.NMSKeep(0.7, 200) | |||||
assert op1 == op2 | |||||
assert op1 != op3 | |||||
assert op1 != op4 | |||||
assert op3 != op4 |
@@ -19,6 +19,15 @@ class CondTake : public OpDefImplBase<CondTake> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | MGB_DYN_TYPE_OBJ_FINAL_DECL; | ||||
public: | public: | ||||
CondTake() = default; | CondTake() = default; | ||||
size_t hash() const override { | |||||
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); | |||||
} | |||||
bool is_same_st(const Hashable& rhs) const override { | |||||
return rhs.dyn_typeinfo() == dyn_typeinfo(); | |||||
} | |||||
}; | }; | ||||
} // namespace mgb::imperative | } // namespace mgb::imperative |
@@ -23,6 +23,20 @@ public: | |||||
NMSKeep() = default; | NMSKeep() = default; | ||||
NMSKeep(float iou_thresh_, uint32_t max_output_): | NMSKeep(float iou_thresh_, uint32_t max_output_): | ||||
iou_thresh(iou_thresh_), max_output(max_output_) {} | iou_thresh(iou_thresh_), max_output(max_output_) {} | ||||
size_t hash() const override { | |||||
return hash_pair_combine( | |||||
hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)), | |||||
reinterpret_cast<std::uintptr_t>(dyn_typeinfo())); | |||||
} | |||||
bool is_same_st(const Hashable& rhs_) const override { | |||||
auto&& rhs = static_cast<const NMSKeep&>(rhs_); | |||||
return rhs.dyn_typeinfo() == dyn_typeinfo() | |||||
&& rhs.iou_thresh == iou_thresh | |||||
&& rhs.max_output == max_output; | |||||
} | |||||
}; | }; | ||||
} // namespace mgb::imperative | } // namespace mgb::imperative |