@@ -9,20 +9,20 @@ | |||||
import os | import os | ||||
_use_tensor_shape = False | |||||
if os.environ.get("MEGENGINE_USE_TENSOR_SHAPE"): | |||||
_use_tensor_shape = True | |||||
_use_symbolic_shape = False | |||||
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"): | |||||
_use_symbolic_shape = True | |||||
def use_tensor_shape() -> bool: | |||||
def use_symbolic_shape() -> bool: | |||||
"""Returns whether tensor.shape returns a tensor instead of a tuple | """Returns whether tensor.shape returns a tensor instead of a tuple | ||||
""" | """ | ||||
return _use_tensor_shape | |||||
return _use_symbolic_shape | |||||
def set_tensor_shape(option: bool): | |||||
def set_symbolic_shape(option: bool): | |||||
""" Sets whether tensor.shape returns a tensor instead of a tuple | """ Sets whether tensor.shape returns a tensor instead of a tuple | ||||
""" | """ | ||||
global _use_tensor_shape | |||||
_use_tensor_shape = option | |||||
global _use_symbolic_shape | |||||
_use_symbolic_shape = option |
@@ -10,7 +10,7 @@ from typing import Iterable | |||||
import numpy as np | import numpy as np | ||||
from .._trace_option import use_tensor_shape | |||||
from .._trace_option import use_symbolic_shape | |||||
from ..ops import builtin | from ..ops import builtin | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from .core import TensorBase, TensorWrapperBase, apply | from .core import TensorBase, TensorWrapperBase, apply | ||||
@@ -58,7 +58,7 @@ def check_bool_index(tensor, tuple_val): | |||||
) | ) | ||||
) | ) | ||||
i = i.reshape(-1) | i = i.reshape(-1) | ||||
if not use_tensor_shape(): | |||||
if not use_symbolic_shape(): | |||||
cur_shape = ( | cur_shape = ( | ||||
cur_shape[:idx] | cur_shape[:idx] | ||||
+ (i.shape[0],) | + (i.shape[0],) | ||||
@@ -76,7 +76,7 @@ def check_bool_index(tensor, tuple_val): | |||||
offset += 1 | offset += 1 | ||||
tensor = tensor.reshape(cur_shape) | tensor = tensor.reshape(cur_shape) | ||||
tdim += tot | tdim += tot | ||||
if use_tensor_shape(): | |||||
if use_symbolic_shape(): | |||||
cur_shape = make_shape_tuple(cur_shape) | cur_shape = make_shape_tuple(cur_shape) | ||||
new_tuple_val.append(i) | new_tuple_val.append(i) | ||||
else: | else: | ||||
@@ -11,7 +11,7 @@ import collections | |||||
import numpy as np | import numpy as np | ||||
from .._trace_option import use_tensor_shape | |||||
from .._trace_option import use_symbolic_shape | |||||
from ..ops import builtin | from ..ops import builtin | ||||
from ..ops.builtin import GetVarShape | from ..ops.builtin import GetVarShape | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
@@ -342,7 +342,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
def __len__(self): | def __len__(self): | ||||
shape = self.shape | shape = self.shape | ||||
if use_tensor_shape(): | |||||
if use_symbolic_shape(): | |||||
shape = shape.numpy() | shape = shape.numpy() | ||||
if shape: | if shape: | ||||
return int(shape[0]) | return int(shape[0]) | ||||
@@ -372,7 +372,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
@property | @property | ||||
def size(self): | def size(self): | ||||
if use_tensor_shape(): | |||||
if use_symbolic_shape(): | |||||
return self.shape.prod() | return self.shape.prod() | ||||
return np.prod(self.shape).item() | return np.prod(self.shape).item() | ||||
@@ -462,7 +462,7 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||||
@property | @property | ||||
def shape(self): | def shape(self): | ||||
if use_tensor_shape(): | |||||
if use_symbolic_shape(): | |||||
return apply(GetVarShape(), self)[0] | return apply(GetVarShape(), self)[0] | ||||
else: | else: | ||||
return self.__wrapped__.shape | return self.__wrapped__.shape | ||||
@@ -19,7 +19,7 @@ import numpy as np | |||||
from ..core._imperative_rt import GraphProfiler | from ..core._imperative_rt import GraphProfiler | ||||
from ..core._imperative_rt.ops import OprAttr | from ..core._imperative_rt.ops import OprAttr | ||||
from ..core._trace_option import set_tensor_shape | |||||
from ..core._trace_option import set_symbolic_shape | |||||
from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | ||||
@@ -121,7 +121,7 @@ class trace: | |||||
sublinear_memory_config: SublinearMemoryConfig = None, | sublinear_memory_config: SublinearMemoryConfig = None, | ||||
profiling: bool = False, | profiling: bool = False, | ||||
opt_level: int = None, | opt_level: int = None, | ||||
tensor_shape: bool = True, | |||||
symbolic_shape: bool = True, | |||||
): | ): | ||||
self.__wrapped__ = function | self.__wrapped__ = function | ||||
self._symbolic = symbolic | self._symbolic = symbolic | ||||
@@ -130,7 +130,7 @@ class trace: | |||||
self._profiling = profiling | self._profiling = profiling | ||||
self._profiler = None | self._profiler = None | ||||
self._graph_opt_level = opt_level | self._graph_opt_level = opt_level | ||||
self._tensor_shape = tensor_shape | |||||
self._symbolic_shape = symbolic_shape | |||||
self._reset() | self._reset() | ||||
@@ -152,7 +152,7 @@ class trace: | |||||
self._output_bindings = None | self._output_bindings = None | ||||
self._output_names = None | self._output_names = None | ||||
set_tensor_shape(self._tensor_shape) | |||||
set_symbolic_shape(self._symbolic_shape) | |||||
def _new_handle(self): | def _new_handle(self): | ||||
handle = len(self._tinfo) | handle = len(self._tinfo) | ||||
@@ -18,7 +18,7 @@ import megengine as mge | |||||
import megengine.autodiff as ad | import megengine.autodiff as ad | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import jit | from megengine import jit | ||||
from megengine.core._trace_option import set_tensor_shape | |||||
from megengine.core._trace_option import set_symbolic_shape | |||||
from megengine.core.tensor.utils import make_shape_tuple | from megengine.core.tensor.utils import make_shape_tuple | ||||
from megengine.functional.debug_param import set_conv_execution_strategy | from megengine.functional.debug_param import set_conv_execution_strategy | ||||
from megengine.jit import SublinearMemoryConfig | from megengine.jit import SublinearMemoryConfig | ||||
@@ -13,7 +13,7 @@ import pytest | |||||
import megengine.core.ops.builtin | import megengine.core.ops.builtin | ||||
import megengine.core.tensor.raw_tensor | import megengine.core.tensor.raw_tensor | ||||
from megengine.core._trace_option import use_tensor_shape | |||||
from megengine.core._trace_option import use_symbolic_shape | |||||
from megengine.core.ops._internal import all_ops | from megengine.core.ops._internal import all_ops | ||||
from megengine.core.tensor import Tensor | from megengine.core.tensor import Tensor | ||||
from megengine.core.tensor.core import apply | from megengine.core.tensor.core import apply | ||||
@@ -532,7 +532,7 @@ def test_advance_indexing_with_bool(): | |||||
np.testing.assert_equal(a, aa.numpy()) | np.testing.assert_equal(a, aa.numpy()) | ||||
# XXX: trace does not expect empty condtake tensor | # XXX: trace does not expect empty condtake tensor | ||||
if not use_tensor_shape(): | |||||
if not use_symbolic_shape(): | |||||
a = np.ones((2, 2), dtype=np.int32) | a = np.ones((2, 2), dtype=np.int32) | ||||
b = np.array([[False, False], [False, False]]) | b = np.array([[False, False], [False, False]]) | ||||
aa = Tensor(a) | aa = Tensor(a) | ||||
@@ -17,7 +17,7 @@ import megengine.core.ops.builtin as builtin | |||||
import megengine.core.tensor.dtype as dtype | import megengine.core.tensor.dtype as dtype | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import Parameter, Tensor, is_cuda_available, tensor | from megengine import Parameter, Tensor, is_cuda_available, tensor | ||||
from megengine.core._trace_option import use_tensor_shape | |||||
from megengine.core._trace_option import use_symbolic_shape | |||||
from megengine.core.autodiff.grad import Grad | from megengine.core.autodiff.grad import Grad | ||||
from megengine.core.tensor.utils import make_shape_tuple | from megengine.core.tensor.utils import make_shape_tuple | ||||
@@ -15,7 +15,7 @@ from utils import opr_test | |||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.core._trace_option import use_tensor_shape | |||||
from megengine.core._trace_option import use_symbolic_shape | |||||
from megengine.core.tensor.utils import astensor1d | from megengine.core.tensor.utils import astensor1d | ||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
@@ -16,7 +16,7 @@ import pytest | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.distributed as dist | import megengine.distributed as dist | ||||
from megengine import Tensor | from megengine import Tensor | ||||
from megengine.core._trace_option import use_tensor_shape | |||||
from megengine.core._trace_option import use_symbolic_shape | |||||
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm | from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm | ||||
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) | _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6) | ||||
@@ -15,7 +15,7 @@ import pytest | |||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import cgtools, tensor | from megengine import cgtools, tensor | ||||
from megengine.core._trace_option import set_tensor_shape | |||||
from megengine.core._trace_option import set_symbolic_shape | |||||
from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
from megengine.core.tensor.core import apply | from megengine.core.tensor.core import apply | ||||
from megengine.core.tensor.raw_tensor import as_raw_tensor | from megengine.core.tensor.raw_tensor import as_raw_tensor | ||||
@@ -238,7 +238,7 @@ def test_optimize_for_inference(): | |||||
def test_optimize_for_inference_broadcast(): | def test_optimize_for_inference_broadcast(): | ||||
a = tensor(np.ones(1, dtype=np.float32)) | a = tensor(np.ones(1, dtype=np.float32)) | ||||
@trace(capture_as_const=True, tensor_shape=True) | |||||
@trace(capture_as_const=True, symbolic_shape=True) | |||||
def f(): | def f(): | ||||
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) | (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) | ||||
return b | return b | ||||
@@ -248,7 +248,7 @@ def test_optimize_for_inference_broadcast(): | |||||
def test_trace_cvt_bool(): | def test_trace_cvt_bool(): | ||||
set_tensor_shape(True) | |||||
set_symbolic_shape(True) | |||||
x = tensor([0], dtype=np.int32) | x = tensor([0], dtype=np.int32) | ||||
@trace(symbolic=True) | @trace(symbolic=True) | ||||
@@ -261,7 +261,7 @@ def test_trace_cvt_bool(): | |||||
def test_trace_reshape(): | def test_trace_reshape(): | ||||
for symbolic in [False, True]: | for symbolic in [False, True]: | ||||
set_tensor_shape(True) | |||||
set_symbolic_shape(True) | |||||
x1 = tensor(np.random.randn(2, 10, 10)) | x1 = tensor(np.random.randn(2, 10, 10)) | ||||
x2 = tensor(np.random.randn(4, 10, 10)) | x2 = tensor(np.random.randn(4, 10, 10)) | ||||
x3 = tensor(np.random.randn(8, 10, 10)) | x3 = tensor(np.random.randn(8, 10, 10)) | ||||
@@ -344,7 +344,7 @@ def test_raise_on_trace(): | |||||
def test_trace_broadcast(): | def test_trace_broadcast(): | ||||
for symbolic in [False, True]: | for symbolic in [False, True]: | ||||
set_tensor_shape(True) | |||||
set_symbolic_shape(True) | |||||
x1 = tensor(np.random.randn(3, 1, 1)) | x1 = tensor(np.random.randn(3, 1, 1)) | ||||
x2 = tensor(np.random.randn(1, 4, 1)) | x2 = tensor(np.random.randn(1, 4, 1)) | ||||
x3 = tensor(np.random.randn(1, 1, 5)) | x3 = tensor(np.random.randn(1, 1, 5)) | ||||
@@ -382,7 +382,7 @@ def test_trace_nms(): | |||||
def test_trace_valid_broadcast(): | def test_trace_valid_broadcast(): | ||||
set_tensor_shape(True) | |||||
set_symbolic_shape(True) | |||||
x1 = tensor(np.random.randn(1, 1)) | x1 = tensor(np.random.randn(1, 1)) | ||||
x2 = tensor(np.random.randn(1, 2)) | x2 = tensor(np.random.randn(1, 2)) | ||||
shape = (tensor([2]), tensor([2])) | shape = (tensor([2]), tensor([2])) | ||||