Browse Source

fix(mge/api): check input dim of dot and mark output as scalar

GitOrigin-RevId: a3ba7e099c
release-1.2
Megvii Engine Team 4 years ago
parent
commit
44742e32e9
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      imperative/python/megengine/functional/nn.py

+ 8
- 3
imperative/python/megengine/functional/nn.py View File

@@ -16,7 +16,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d
from ..core.tensor.utils import astensor1d, setscalar
from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
from ..random import uniform
@@ -1133,7 +1133,8 @@ def matmul(
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional, scalar input can be automatically broadcasted.
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
Refer to :func:`~.matmul` for more general usage.

:param inp1: first vector.
:param inp2: second vector.
@@ -1156,12 +1157,16 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:

.. testoutput::

[55.]
55.

"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
setscalar(result)
return result




Loading…
Cancel
Save