Browse Source

feat(mge): remove F.identity

GitOrigin-RevId: 858be627ac
release-1.1
Megvii Engine Team 4 years ago
parent
commit
cf0507e1ba
4 changed files with 10 additions and 20 deletions
  1. +0
    -13
      imperative/python/megengine/functional/tensor.py
  2. +7
    -4
      imperative/python/megengine/functional/utils.py
  3. +2
    -2
      imperative/python/megengine/module/identity.py
  4. +1
    -1
      imperative/python/test/unit/functional/test_tensor.py

+ 0
- 13
imperative/python/megengine/functional/tensor.py View File

@@ -42,7 +42,6 @@ __all__ = [
"full",
"full_like",
"gather",
"identity",
"linspace",
"ones",
"ones_like",
@@ -178,18 +177,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
return full(inp.shape, value, dtype=inp.dtype, device=inp.device)


def identity(inp: Tensor) -> Tensor:
"""Applies an identity transformation to input tensor.

:param inp: input tensor.
:return: output tensor.
"""
op = builtin.Identity()
(data,) = convert_inputs(inp)
(output,) = apply(op, data)
return output


def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
"""
Broadcasts a tensor to given shape.


+ 7
- 4
imperative/python/megengine/functional/utils.py View File

@@ -11,7 +11,8 @@ from typing import Iterable, Union

import numpy as np

from ..core.ops.builtin import Copy
from ..core._wrap import device as as_device
from ..core.ops.builtin import Copy, Identity
from ..core.tensor import Tensor
from ..core.tensor.core import apply
from .math import topk as _topk
@@ -63,12 +64,12 @@ def accuracy(
return accs


def copy(inp, cn):
def copy(inp, device=None):
r"""
Copies tensor to another device.

:param inp: input tensor.
:param cn: destination device.
:param device: destination device.

Examples:

@@ -88,4 +89,6 @@ def copy(inp, cn):

[1 2 3]
"""
return apply(Copy(comp_node=cn), inp)[0]
if device is None:
return apply(Identity(), inp)[0]
return apply(Copy(comp_node=as_device(device).to_c()), inp)[0]

+ 2
- 2
imperative/python/megengine/module/identity.py View File

@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..functional import identity
from ..functional import copy
from .module import Module


@@ -14,4 +14,4 @@ class Identity(Module):
r"""A placeholder identity operator that will ignore any argument."""

def forward(self, x):
return identity(x)
return copy(x)

+ 1
- 1
imperative/python/test/unit/functional/test_tensor.py View File

@@ -314,7 +314,7 @@ def test_device():

def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
y = F.identity(x)
y = F.copy(x)
np.testing.assert_equal(y.numpy(), x)




Loading…
Cancel
Save