@@ -7,7 +7,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# pylint: disable=redefined-builtin | # pylint: disable=redefined-builtin | ||||
from . import metric, vision | |||||
from . import metric, utils, vision | |||||
from .elemwise import * | from .elemwise import * | ||||
from .math import * | from .math import * | ||||
from .nn import * | from .nn import * | ||||
@@ -11,6 +11,7 @@ from typing import Iterable, Union | |||||
import numpy as np | import numpy as np | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .elemwise import abs, maximum, minimum | |||||
from .math import topk as _topk | from .math import topk as _topk | ||||
from .tensor import broadcast_to, transpose | from .tensor import broadcast_to, transpose | ||||
@@ -0,0 +1,57 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from ..core._imperative_rt.core2 import apply | |||||
from ..core._imperative_rt.core2 import sync as _sync | |||||
from ..core.ops.builtin import AssertEqual | |||||
from ..tensor import Tensor | |||||
from .elemwise import abs, maximum, minimum | |||||
def _assert_equal(
    expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
):
    r"""
    Checks that ``actual`` matches ``expect`` within ``maxerr`` and returns the
    expected value (the first input).

    This is a symbolically traceable counterpart of Python's ``assert``
    (comparable to ``numpy.testing.assert_equal``). A plain ``assert`` on states
    and outputs is enough to validate a model in ordinary Python code. A *dumped*
    model (or code running in a :class:`~jit.trace` context) may however need to
    be validated on different backends where no Python code can be executed;
    :func:`~functional.utils._assert_equal` is meant for that situation.

    :param expect: expected tensor value
    :param actual: tensor to check value
    :param maxerr: max allowed error; error is defined as the minimal of absolute and relative error
    :param verbose: whether to print maxerr to stdout during opr exec
    :return: expected tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, 2, 3], np.float32)
        y = tensor([1, 2, 3], np.float32)
        print(F.utils._assert_equal(x, y, maxerr=0).numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3.]

    """
    # Element-wise |expect - actual|, scaled by the smaller magnitude of the two
    # inputs but never by less than 1; the largest scaled value is the error.
    abs_diff = abs(expect - actual)
    smaller_magnitude = minimum(abs(expect), abs(actual))
    denominator = maximum(smaller_magnitude, Tensor(1.0, dtype="float32"))
    err = (abs_diff / denominator).max()

    assert_op = AssertEqual(maxerr=maxerr, verbose=verbose)
    outputs = apply(assert_op, expect, actual, err)
    checked = outputs[0]
    _sync()  # sync interpreter to get exception
    return checked
@@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import ( | |||||
unset_compiled, | unset_compiled, | ||||
unset_tracing, | unset_tracing, | ||||
) | ) | ||||
from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend | |||||
from ..core._imperative_rt.ops import ( | |||||
AssertEqual, | |||||
CollectiveComm, | |||||
RemoteRecv, | |||||
RemoteSend, | |||||
) | |||||
from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
from ..core.ops.builtin import BackwardGraph, OpDef | from ..core.ops.builtin import BackwardGraph, OpDef | ||||
@@ -110,7 +115,7 @@ class TensorInfo: | |||||
self.data_reader = None | self.data_reader = None | ||||
# Set of op classes grouped as I/O op types. The two assignments below are the
# before/after sides of the same change: the second one additionally lists
# AssertEqual next to CollectiveComm, RemoteSend and RemoteRecv.
_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv} | |||||
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv} | |||||
class trace: | class trace: | ||||
@@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape | |||||
from megengine.core.autodiff.grad import Grad | from megengine.core.autodiff.grad import Grad | ||||
from megengine.core.tensor.utils import make_shape_tuple | from megengine.core.tensor.utils import make_shape_tuple | ||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
from megengine.jit import trace | |||||
def test_where(): | def test_where(): | ||||
@@ -746,3 +747,18 @@ def test_ones(val): | |||||
shp = tensor(val) | shp = tensor(val) | ||||
np_shp = np.array(val) | np_shp = np.array(val) | ||||
np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) | np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) | ||||
def test_assert_equal():
    """Tensors differing by less than the default ``maxerr`` pass the check."""
    shape = (2, 3, 4, 5)
    x = F.ones(shape, dtype=np.float32)
    # y is about 1e-5 away from x, which is below the default maxerr of 1e-4
    y = F.zeros(shape, dtype=np.float32) + 1.00001
    z = F.utils._assert_equal(x, y)
    # _assert_equal is documented to return the expected (first) tensor, so the
    # result must carry the values of x rather than those of y.
    np.testing.assert_equal(z.numpy(), x.numpy())
def test_assert_not_equal():
    """Tensors differing by more than the default ``maxerr`` raise RuntimeError."""
    shape = (2, 3, 4, 5)
    x = F.ones(shape, dtype=np.float32)
    # y is 0.1 away from x, far above the default maxerr of 1e-4
    y = F.zeros(shape, dtype=np.float32) + 1.1
    with pytest.raises(RuntimeError):
        # the call is expected to raise, so its result is intentionally not bound
        F.utils._assert_equal(x, y)
@@ -451,20 +451,22 @@ OP_TRAIT_REG(Identity, Identity) | |||||
// File-local (anonymous namespace) trait registration for the AssertEqual op.
// NOTE(review): this span is a diff rendering, so the pre-change and the
// post-change body of apply_on_var_node both appear below, one after the other.
namespace { namespace assert_equal { | namespace { namespace assert_equal { | ||||
// Builds the graph operator for an AssertEqual OpDef via opr::AssertEqual::make
// from the given input var nodes.
auto apply_on_var_node( | auto apply_on_var_node( | ||||
// --- pre-change body: exactly two inputs, operator config named via op.make_name() ---
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const AssertEqual&>(def); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{op.make_name()}; | |||||
return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config); | |||||
// --- post-change body: accepts two inputs, or three when the caller supplies
// an extra third var node (the Python helper in this change passes its
// precomputed error tensor as the third input) ---
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& op = def.cast_final<AssertEqual>(); | |||||
if (inputs.size() == 2) { | |||||
return opr::AssertEqual::make(inputs[0], inputs[1], op.param()); | |||||
} else { | |||||
// workaround for MiniGraph, which only allows one opr in the graph | |||||
mgb_assert(inputs.size() == 3); | |||||
// third input is forwarded as-is; the trailing {} is an empty operator config
return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {}); | |||||
} | } | ||||
} | |||||
// Registers apply_on_var_node for the AssertEqual trait. The effect of
// .fallback() is defined by the trait registry, not shown here
// (presumably it fills the remaining trait hooks with defaults; TODO confirm).
OP_TRAIT_REG(AssertEqual, AssertEqual) | OP_TRAIT_REG(AssertEqual, AssertEqual) | ||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.fallback(); | .fallback(); | ||||
}} | |||||
}} // assert_equal | |||||
namespace { namespace uniform_rng { | namespace { namespace uniform_rng { | ||||
auto apply_on_var_node( | auto apply_on_var_node( | ||||
@@ -445,6 +445,12 @@ public: | |||||
// Reports the number of operators in the graph as the size of m_opr_refkeeper.
size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();} | size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();} | ||||
void record_async_error(std::unique_ptr<MegBrainError> async_exc) override { | |||||
if (!ProxyGraph::tm_async_error) { | |||||
std::swap(async_exc, tm_async_error); | |||||
} | |||||
} | |||||
// Graph compilation entry points are not implemented by this class: both
// overrides unconditionally hit mgb_assert(0).
std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);} | std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);} | ||||
SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part( | SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part( | ||||
const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);} | const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);} | ||||
@@ -457,7 +463,6 @@ public: | |||||
// Unimplemented overrides: each one unconditionally hits mgb_assert(0).
size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);} | size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);} | ||||
size_t clear_device_memory() override {mgb_assert(0);} | size_t clear_device_memory() override {mgb_assert(0);} | ||||
void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);} | void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);} | ||||
// Pre-change stub of record_async_error (removed side of the diff); the change
// replaces it with a definition that stores the error instead of asserting.
void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {mgb_assert(0);} | |||||
}; | }; | ||||
// Out-of-class definition of the static atomic counter m_node_id, initialised to 0.
std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0; | std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0; | ||||
@@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) { | |||||
} | } | ||||
} | } | ||||
// Out-of-class definition of the thread-local slot for a recorded async error.
// It is filled by ProxyGraphImpl::record_async_error and emptied by
// ProxyGraph::get_async_error; it starts out null (default-constructed).
thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error; | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -24,6 +24,9 @@ namespace imperative { | |||||
class ProxyGraph : public NonCopyableObj { | class ProxyGraph : public NonCopyableObj { | ||||
public: | public: | ||||
static ProxyGraph* get_default_graph(); | static ProxyGraph* get_default_graph(); | ||||
static std::unique_ptr<MegBrainError> get_async_error() { | |||||
return std::move(tm_async_error); | |||||
} | |||||
/********************** Physical Tensor API **********************/ | /********************** Physical Tensor API **********************/ | ||||
@@ -98,6 +101,8 @@ private: | |||||
std::unique_ptr<ExecEnv> m_env; | std::unique_ptr<ExecEnv> m_env; | ||||
std::unique_ptr<StaticInferManager> m_static_infer_manager; | std::unique_ptr<StaticInferManager> m_static_infer_manager; | ||||
std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer; | std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer; | ||||
// Per-thread storage for an error reported through record_async_error;
// null while no error is pending. Handed out (and cleared) by get_async_error().
static thread_local std::unique_ptr<MegBrainError> tm_async_error; | |||||
}; | }; | ||||
} // namespace imperative | } // namespace imperative | ||||
@@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def, | |||||
} | } | ||||
} | } | ||||
exec(def, inputs, outputs); | exec(def, inputs, outputs); | ||||
auto async_error = ProxyGraph::get_async_error(); | |||||
if (async_error) { | |||||
throw *async_error; | |||||
} | |||||
return outputs; | return outputs; | ||||
} | } | ||||