Browse Source

feat(functional/ops): add _assert_equal

GitOrigin-RevId: b7ce4158b7
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
46cad4d3c8
9 changed files with 110 additions and 13 deletions
  1. +1
    -1
      imperative/python/megengine/functional/__init__.py
  2. +1
    -0
      imperative/python/megengine/functional/metric.py
  3. +57
    -0
      imperative/python/megengine/functional/utils.py
  4. +7
    -2
      imperative/python/megengine/jit/tracing.py
  5. +16
    -0
      imperative/python/test/unit/functional/test_functional.py
  6. +11
    -9
      imperative/src/impl/ops/specializations.cpp
  7. +8
    -1
      imperative/src/impl/proxy_graph.cpp
  8. +5
    -0
      imperative/src/impl/proxy_graph.h
  9. +4
    -0
      imperative/src/impl/proxy_graph_detail.cpp

+ 1
- 1
imperative/python/megengine/functional/__init__.py View File

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from . import metric, vision
from . import metric, utils, vision
from .elemwise import * from .elemwise import *
from .math import * from .math import *
from .nn import * from .nn import *


+ 1
- 0
imperative/python/megengine/functional/metric.py View File

@@ -11,6 +11,7 @@ from typing import Iterable, Union
import numpy as np import numpy as np


from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import abs, maximum, minimum
from .math import topk as _topk from .math import topk as _topk
from .tensor import broadcast_to, transpose from .tensor import broadcast_to, transpose




+ 57
- 0
imperative/python/megengine/functional/utils.py View File

@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core.ops.builtin import AssertEqual
from ..tensor import Tensor
from .elemwise import abs, maximum, minimum


def _assert_equal(
expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
):
r"""
Asserts two tensors equal and returns expected value (first input).
It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``).
If we want to verify the correctness of model, just ``assert`` its states and outputs.
While sometimes we need to verify the correctness at different backends for *dumped* model
(or in :class:`~jit.trace` context), and no python code could be executed in that case.
Thus we have to use :func:`~functional.utils._assert_equal` instead.

:param expect: expected tensor value
:param actual: tensor to check value
:param maxerr: max allowed error; error is defined as the minimal of absolute and relative error
:param verbose: whether to print maxerr to stdout during opr exec
:return: expected tensor

Examples:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor([1, 2, 3], np.float32)
y = tensor([1, 2, 3], np.float32)
print(F.utils._assert_equal(x, y, maxerr=0).numpy())

Outputs:

.. testoutput::

[1. 2. 3.]
"""
err = (
abs(expect - actual)
/ maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32"))
).max()
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception
return result

+ 7
- 2
imperative/python/megengine/jit/tracing.py View File

@@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import (
unset_compiled, unset_compiled,
unset_tracing, unset_tracing,
) )
from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
from ..core._imperative_rt.ops import (
AssertEqual,
CollectiveComm,
RemoteRecv,
RemoteSend,
)
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef from ..core.ops.builtin import BackwardGraph, OpDef
@@ -110,7 +115,7 @@ class TensorInfo:
self.data_reader = None self.data_reader = None




_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}




class trace: class trace:


+ 16
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace




def test_where(): def test_where():
@@ -746,3 +747,18 @@ def test_ones(val):
shp = tensor(val) shp = tensor(val)
np_shp = np.array(val) np_shp = np.array(val)
np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) np.testing.assert_equal(F.ones(shp), np.ones(np_shp))


def test_assert_equal():
shape = (2, 3, 4, 5)
x = F.ones(shape, dtype=np.float32)
y = F.zeros(shape, dtype=np.float32) + 1.00001
z = F.utils._assert_equal(x, y)


def test_assert_not_equal():
shape = (2, 3, 4, 5)
x = F.ones(shape, dtype=np.float32)
y = F.zeros(shape, dtype=np.float32) + 1.1
with pytest.raises(RuntimeError):
z = F.utils._assert_equal(x, y)

+ 11
- 9
imperative/src/impl/ops/specializations.cpp View File

@@ -451,20 +451,22 @@ OP_TRAIT_REG(Identity, Identity)


namespace { namespace assert_equal { namespace { namespace assert_equal {
auto apply_on_var_node( auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const AssertEqual&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);

const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final<AssertEqual>();
if (inputs.size() == 2) {
return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
} else {
// workaround for MiniGraph, which only allow one opr in the graph
mgb_assert(inputs.size() == 3);
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
} }
}


OP_TRAIT_REG(AssertEqual, AssertEqual) OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.fallback(); .fallback();

}}
}} // assert_equal


namespace { namespace uniform_rng { namespace { namespace uniform_rng {
auto apply_on_var_node( auto apply_on_var_node(


+ 8
- 1
imperative/src/impl/proxy_graph.cpp View File

@@ -445,6 +445,12 @@ public:


size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();} size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();}


void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
if (!ProxyGraph::tm_async_error) {
std::swap(async_exc, tm_async_error);
}
}

std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);} std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);}
SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part( SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);} const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);}
@@ -457,7 +463,6 @@ public:
size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);} size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);}
size_t clear_device_memory() override {mgb_assert(0);} size_t clear_device_memory() override {mgb_assert(0);}
void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);} void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);}
void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {mgb_assert(0);}
}; };


std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0; std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;
@@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
} }
} }


thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;

} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb




+ 5
- 0
imperative/src/impl/proxy_graph.h View File

@@ -24,6 +24,9 @@ namespace imperative {
class ProxyGraph : public NonCopyableObj { class ProxyGraph : public NonCopyableObj {
public: public:
static ProxyGraph* get_default_graph(); static ProxyGraph* get_default_graph();
static std::unique_ptr<MegBrainError> get_async_error() {
return std::move(tm_async_error);
}


/********************** Physical Tensor API **********************/ /********************** Physical Tensor API **********************/


@@ -98,6 +101,8 @@ private:
std::unique_ptr<ExecEnv> m_env; std::unique_ptr<ExecEnv> m_env;
std::unique_ptr<StaticInferManager> m_static_infer_manager; std::unique_ptr<StaticInferManager> m_static_infer_manager;
std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer; std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;

static thread_local std::unique_ptr<MegBrainError> tm_async_error;
}; };


} // namespace imperative } // namespace imperative


+ 4
- 0
imperative/src/impl/proxy_graph_detail.cpp View File

@@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def,
} }
} }
exec(def, inputs, outputs); exec(def, inputs, outputs);
auto async_error = ProxyGraph::get_async_error();
if (async_error) {
throw *async_error;
}
return outputs; return outputs;
} }




Loading…
Cancel
Save