You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # -*- coding: utf-8 -*-
  2. from ..core._imperative_rt.core2 import apply
  3. from ..core._imperative_rt.core2 import sync as _sync
  4. from ..core.ops.builtin import AssertEqual
  5. from ..tensor import Tensor
  6. from ..utils.deprecation import deprecated_func
  7. from .elemwise import abs, maximum, minimum
  8. from .tensor import ones, zeros
# Names exported by ``from ... import *``: only ``topk_accuracy`` is listed, so the
# underscore-prefixed helpers and the deprecated ``copy`` wrapper defined later in
# this module are not exported that way.
__all__ = ["topk_accuracy"]
  10. def _assert_equal(
  11. expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
  12. ):
  13. r"""Asserts two tensors equal and returns expected value (first input).
  14. It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``).
  15. If we want to verify the correctness of model, just ``assert`` its states and outputs.
  16. While sometimes we need to verify the correctness at different backends for *dumped* model
  17. (or in :class:`~jit.trace` context), and no python code could be executed in that case.
  18. Thus we have to use :func:`~functional.utils._assert_equal` instead.
  19. Args:
  20. expect: expected tensor value
  21. actual: tensor to check value
  22. maxerr: max allowed error; error is defined as the minimal of absolute and relative error
  23. verbose: whether to print maxerr to stdout during opr exec
  24. Examples:
  25. >>> x = Tensor([1, 2, 3], dtype="float32")
  26. >>> y = Tensor([1, 2, 3], dtype="float32")
  27. >>> F.utils._assert_equal(x, y, maxerr=0)
  28. Tensor([1. 2. 3.], device=xpux:0)
  29. """
  30. err = (
  31. abs(expect - actual)
  32. / maximum(
  33. minimum(abs(expect), abs(actual)),
  34. Tensor(1.0, dtype="float32", device=expect.device),
  35. )
  36. ).max()
  37. result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
  38. _sync() # sync interpreter to get exception
  39. return result
  40. def _simulate_error():
  41. x1 = zeros(100)
  42. x2 = ones(100)
  43. (ret,) = apply(AssertEqual(maxerr=0, verbose=False), x1, x2, x2)
  44. return ret
# Deprecated entry points kept in this module for backward compatibility.
# Both are built by ``deprecated_func`` (imported from ``..utils.deprecation``) with:
#   - the version string "1.3";
#   - what appears to be the module path where the function now lives;
#   - the function name;
#   - a trailing ``True`` flag.
# NOTE(review): the meaning of the trailing flag is defined in utils.deprecation.
# These wrappers presumably emit a deprecation warning and forward to the new
# location; confirm there.
topk_accuracy = deprecated_func(
    "1.3", "megengine.functional.metric", "topk_accuracy", True
)
copy = deprecated_func("1.3", "megengine.functional.tensor", "copy", True)