|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import numpy as np
-
-
def assertTensorClose(
    v0, v1, *, max_err: float = 1e-6, allow_special_values: bool = False, name=None
):
    """
    Assert that two tensors have equal dtype and shape and element-wise close values.

    Both inputs are converted to contiguous ``float32`` copies before comparing.
    The per-element error is ``|v0 - v1| / max(|v0|, |v1|, 1)``: a relative error
    where either magnitude exceeds 1 and an absolute error otherwise.

    :param v0: first tensor; must have a ``dtype`` attribute and be convertible
        with :func:`numpy.ascontiguousarray`.
    :param v1: second tensor; must have the same dtype and shape as :attr:`v0`.
    :param max_err: largest allowed per-element error (see above).
    :param allow_special_values: whether to allow :attr:`v0` and :attr:`v1` to
        contain inf and nan values. If allowed, nan, +inf and -inf must occur at
        exactly the same positions in both tensors; those positions are then
        left out of the numeric comparison.
    :param name: optional name shown in the failure message.
    :raises AssertionError: if dtypes, shapes or special-value positions differ,
        if a nan/inf is present while ``allow_special_values`` is False, or if
        any element's error exceeds ``max_err``.
    """
    __tracebackhide__ = True  # pylint: disable=unused-variable

    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if v0.dtype != v1.dtype:
        raise AssertionError(
            "Two Tensor must have same dtype, but the inputs are {} and {}".format(
                v0.dtype, v1.dtype
            )
        )

    # Copies are required: the special-value handling below writes into them.
    v0 = np.ascontiguousarray(v0, dtype=np.float32).copy()
    v1 = np.ascontiguousarray(v1, dtype=np.float32).copy()

    # Compare shapes before any element-wise work: indexing one array with a
    # mask built from the other would raise IndexError on a mismatch.
    if v0.shape != v1.shape:
        raise AssertionError(
            "Two tensor must have same shape({} v.s. {})".format(v0.shape, v1.shape)
        )

    if allow_special_values:
        # Masks are built from BOTH tensors so a special value present only in
        # v1 is reported as well; left in place it would yield a nan error,
        # and nan never compares greater than max_err.
        for kind, is_special in (
            ("nan", np.isnan),
            ("inf", np.isposinf),
            ("-inf", np.isneginf),
        ):
            mask0 = is_special(v0)
            mask1 = is_special(v1)
            if not np.array_equal(mask0, mask1):
                bad = tuple(int(i) for i in np.argwhere(mask0 != mask1)[0])
                raise AssertionError(
                    "{} positions differ between tensors: "
                    "first mismatch at idx={} v0={} v1={}".format(
                        kind, bad, v0[bad], v1[bad]
                    )
                )
            # Masks are identical here; zero them out of the comparison.
            v0[mask0] = 0
            v1[mask0] = 0
    elif not (np.isfinite(v0).all() and np.isfinite(v1).all()):
        # Checked element-wise: testing the sum instead would overflow to inf
        # for large but finite float32 values.
        raise AssertionError(
            "tensor contains nan or inf but allow_special_values is False: "
            "v0={} v1={}".format(v0, v1)
        )

    vdiv = np.maximum(np.maximum(np.abs(v0), np.abs(v1)), 1)
    err = np.abs(v0 - v1) / vdiv
    check = err > max_err
    if check.any():
        # argwhere lists hits in C order; report the first one as plain ints.
        idx = tuple(int(i) for i in np.argwhere(check)[0])
        label = "tensor" if name is None else "tensor {}".format(name)
        raise AssertionError(
            "{} not equal: "
            "shape={} nonequal_idx={} v0={} v1={} err={}".format(
                label, v0.shape, idx, v0[idx], v1[idx], err[idx]
            )
        )
|