Browse Source

fix(sdk): add AssertEqual opr, fix dump_with_testcase_mge

GitOrigin-RevId: 6f797570b6
release-1.2
Megvii Engine Team 4 years ago
parent
commit
aba0acc797
3 changed files with 23 additions and 6 deletions
  1. +16
    -0
      imperative/src/impl/ops/specializations.cpp
  2. +6
    -6
      sdk/load-and-run/dump_with_testcase_mge.py
  3. +1
    -0
      src/core/include/megbrain/ir/ops.td

+ 16
- 0
imperative/src/impl/ops/specializations.cpp View File

@@ -418,6 +418,22 @@ OP_TRAIT_REG(Identity, Identity)
.fallback(); .fallback();
}} // identity }} // identity


namespace { namespace assert_equal {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const AssertEqual&>(def);
mgb_assert(inputs.size() == 2);
return opr::AssertEqual::make(inputs[0],inputs[1],op.param());

}
OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node)
.fallback();

}}

namespace { namespace uniform_rng { namespace { namespace uniform_rng {
auto apply_on_var_node( auto apply_on_var_node(
const OpDef& def, const OpDef& def,


+ 6
- 6
sdk/load-and-run/dump_with_testcase_mge.py View File

@@ -19,9 +19,9 @@ import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
from megengine.utils import comp_graph_tools as cgtools from megengine.utils import comp_graph_tools as cgtools
from megengine.core.ops import builtin from megengine.core.ops import builtin
from megengine.core.tensor.core import apply
from megengine.core._imperative_rt.core2 import apply
from megengine.core.tensor.megbrain_graph import VarNode from megengine.core.tensor.megbrain_graph import VarNode
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine import tensor


logger = mge.get_logger(__name__) logger = mge.get_logger(__name__)


@@ -195,7 +195,7 @@ def make_feeds(args):
func = cg_rt.compile([node.outputs[0] for node in output_nodes]) func = cg_rt.compile([node.outputs[0] for node in output_nodes])


def make_dev_tensor(value, dtype=None, device=None): def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
return tensor(value, dtype=dtype, device=device)._dev_tensor()


def calculate(*args, **kwargs): def calculate(*args, **kwargs):
output_val = [] output_val = []
@@ -268,8 +268,8 @@ def make_feeds(args):


def assert_equal(expect, real, **kwargs): def assert_equal(expect, real, **kwargs):
op = builtin.AssertEqual(**kwargs) op = builtin.AssertEqual(**kwargs)
(res,) = apply(op, expect, real)
return res
(res,) = G.apply_normal_varnode(op, expect, real)
return G.VarNode(res)


verbose = not args.silent verbose = not args.silent


@@ -509,7 +509,7 @@ def main():
) )


def make_dev_tensor(value, dtype=None, device=None): def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
return tensor(value, dtype=dtype, device=device)._dev_tensor()


for testcase in feeds["testcases"]: for testcase in feeds["testcases"]:
assert isinstance(testcase, dict) assert isinstance(testcase, dict)


+ 1
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -231,6 +231,7 @@ def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;


def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
let extraArguments = (ins let extraArguments = (ins
MgbDTypeAttr:$dtype MgbDTypeAttr:$dtype


Loading…
Cancel
Save