Browse Source

refactor(mge): tensor_shape -> symbolic_shape

GitOrigin-RevId: 366dc048bf
release-1.1
Megvii Engine Team 4 years ago
parent
commit
9748aebeea
10 changed files with 31 additions and 31 deletions
  1. +8
    -8
      imperative/python/megengine/core/_trace_option.py
  2. +3
    -3
      imperative/python/megengine/core/tensor/indexing.py
  3. +4
    -4
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  4. +4
    -4
      imperative/python/megengine/jit/tracing.py
  5. +1
    -1
      imperative/python/test/integration/test_correctness.py
  6. +2
    -2
      imperative/python/test/unit/core/test_indexing_op.py
  7. +1
    -1
      imperative/python/test/unit/functional/test_functional.py
  8. +1
    -1
      imperative/python/test/unit/functional/test_tensor.py
  9. +1
    -1
      imperative/python/test/unit/module/test_batchnorm.py
  10. +6
    -6
      imperative/python/test/unit/test_tracing.py

+ 8
- 8
imperative/python/megengine/core/_trace_option.py View File

@@ -9,20 +9,20 @@

import os

_use_tensor_shape = False
if os.environ.get("MEGENGINE_USE_TENSOR_SHAPE"):
_use_tensor_shape = True
_use_symbolic_shape = False
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
_use_symbolic_shape = True


def use_tensor_shape() -> bool:
def use_symbolic_shape() -> bool:
"""Returns whether tensor.shape returns a tensor instead of a tuple

"""
return _use_tensor_shape
return _use_symbolic_shape


def set_tensor_shape(option: bool):
def set_symbolic_shape(option: bool):
""" Sets whether tensor.shape returns a tensor instead of a tuple
"""
global _use_tensor_shape
_use_tensor_shape = option
global _use_symbolic_shape
_use_symbolic_shape = option

+ 3
- 3
imperative/python/megengine/core/tensor/indexing.py View File

@@ -10,7 +10,7 @@ from typing import Iterable

import numpy as np

from .._trace_option import use_tensor_shape
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply
@@ -58,7 +58,7 @@ def check_bool_index(tensor, tuple_val):
)
)
i = i.reshape(-1)
if not use_tensor_shape():
if not use_symbolic_shape():
cur_shape = (
cur_shape[:idx]
+ (i.shape[0],)
@@ -76,7 +76,7 @@ def check_bool_index(tensor, tuple_val):
offset += 1
tensor = tensor.reshape(cur_shape)
tdim += tot
if use_tensor_shape():
if use_symbolic_shape():
cur_shape = make_shape_tuple(cur_shape)
new_tuple_val.append(i)
else:


+ 4
- 4
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -11,7 +11,7 @@ import collections

import numpy as np

from .._trace_option import use_tensor_shape
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.builtin import GetVarShape
from ..ops.special import Const
@@ -342,7 +342,7 @@ class ArrayMethodMixin(abc.ABC):

def __len__(self):
shape = self.shape
if use_tensor_shape():
if use_symbolic_shape():
shape = shape.numpy()
if shape:
return int(shape[0])
@@ -372,7 +372,7 @@ class ArrayMethodMixin(abc.ABC):

@property
def size(self):
if use_tensor_shape():
if use_symbolic_shape():
return self.shape.prod()
return np.prod(self.shape).item()

@@ -462,7 +462,7 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):

@property
def shape(self):
if use_tensor_shape():
if use_symbolic_shape():
return apply(GetVarShape(), self)[0]
else:
return self.__wrapped__.shape


+ 4
- 4
imperative/python/megengine/jit/tracing.py View File

@@ -19,7 +19,7 @@ import numpy as np

from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_tensor_shape
from ..core._trace_option import set_symbolic_shape
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -121,7 +121,7 @@ class trace:
sublinear_memory_config: SublinearMemoryConfig = None,
profiling: bool = False,
opt_level: int = None,
tensor_shape: bool = True,
symbolic_shape: bool = True,
):
self.__wrapped__ = function
self._symbolic = symbolic
@@ -130,7 +130,7 @@ class trace:
self._profiling = profiling
self._profiler = None
self._graph_opt_level = opt_level
self._tensor_shape = tensor_shape
self._symbolic_shape = symbolic_shape

self._reset()

@@ -152,7 +152,7 @@ class trace:
self._output_bindings = None
self._output_names = None

set_tensor_shape(self._tensor_shape)
set_symbolic_shape(self._symbolic_shape)

def _new_handle(self):
handle = len(self._tinfo)


+ 1
- 1
imperative/python/test/integration/test_correctness.py View File

@@ -18,7 +18,7 @@ import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import jit
from megengine.core._trace_option import set_tensor_shape
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.tensor.utils import make_shape_tuple
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.jit import SublinearMemoryConfig


+ 2
- 2
imperative/python/test/unit/core/test_indexing_op.py View File

@@ -13,7 +13,7 @@ import pytest

import megengine.core.ops.builtin
import megengine.core.tensor.raw_tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops._internal import all_ops
from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply
@@ -532,7 +532,7 @@ def test_advance_indexing_with_bool():
np.testing.assert_equal(a, aa.numpy())

# XXX: trace does not expect empty condtake tensor
if not use_tensor_shape():
if not use_symbolic_shape():
a = np.ones((2, 2), dtype=np.int32)
b = np.array([[False, False], [False, False]])
aa = Tensor(a)


+ 1
- 1
imperative/python/test/unit/functional/test_functional.py View File

@@ -17,7 +17,7 @@ import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple



+ 1
- 1
imperative/python/test/unit/functional/test_tensor.py View File

@@ -15,7 +15,7 @@ from utils import opr_test

import megengine.functional as F
from megengine import tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork



+ 1
- 1
imperative/python/test/unit/module/test_batchnorm.py View File

@@ -16,7 +16,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.core._trace_option import use_symbolic_shape
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm

_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)


+ 6
- 6
imperative/python/test/unit/test_tracing.py View File

@@ -15,7 +15,7 @@ import pytest
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
from megengine import cgtools, tensor
from megengine.core._trace_option import set_tensor_shape
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
@@ -238,7 +238,7 @@ def test_optimize_for_inference():
def test_optimize_for_inference_broadcast():
a = tensor(np.ones(1, dtype=np.float32))

@trace(capture_as_const=True, tensor_shape=True)
@trace(capture_as_const=True, symbolic_shape=True)
def f():
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
return b
@@ -248,7 +248,7 @@ def test_optimize_for_inference_broadcast():


def test_trace_cvt_bool():
set_tensor_shape(True)
set_symbolic_shape(True)
x = tensor([0], dtype=np.int32)

@trace(symbolic=True)
@@ -261,7 +261,7 @@ def test_trace_cvt_bool():

def test_trace_reshape():
for symbolic in [False, True]:
set_tensor_shape(True)
set_symbolic_shape(True)
x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10))
@@ -344,7 +344,7 @@ def test_raise_on_trace():

def test_trace_broadcast():
for symbolic in [False, True]:
set_tensor_shape(True)
set_symbolic_shape(True)
x1 = tensor(np.random.randn(3, 1, 1))
x2 = tensor(np.random.randn(1, 4, 1))
x3 = tensor(np.random.randn(1, 1, 5))
@@ -382,7 +382,7 @@ def test_trace_nms():


def test_trace_valid_broadcast():
set_tensor_shape(True)
set_symbolic_shape(True)
x1 = tensor(np.random.randn(1, 1))
x2 = tensor(np.random.randn(1, 2))
shape = (tensor([2]), tensor([2]))


Loading…
Cancel
Save