Browse Source

fix(mge/jit): fix add_update semantic

GitOrigin-RevId: f541ac7c6d
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
b3b14fdfe1
4 changed files with 53 additions and 5 deletions
  1. +33
    -4
      python_module/megengine/core/tensor.py
  2. +4
    -1
      python_module/megengine/functional/graph.py
  3. +2
    -0
      python_module/megengine/jit/__init__.py
  4. +14
    -0
      python_module/test/unit/jit/test_jit.py

+ 33
- 4
python_module/megengine/core/tensor.py View File

@@ -9,6 +9,7 @@
import collections import collections
import functools import functools
import itertools import itertools
import weakref
from typing import Union from typing import Union


import numpy as np import numpy as np
@@ -100,6 +101,14 @@ class MGBIndexWrapper:
)(wrap_idx(idx)) )(wrap_idx(idx))




class Guard:
def __init__(self, deleter):
self.deleter = deleter

def __del__(self):
self.deleter()


class Tensor: class Tensor:
r"""The main data container in MegEngine. r"""The main data container in MegEngine.
Use :func:`~.tensor` to create a Tensor with existed data. Use :func:`~.tensor` to create a Tensor with existed data.
@@ -111,6 +120,7 @@ class Tensor:
self._reset(val, requires_grad=requires_grad) self._reset(val, requires_grad=requires_grad)


def _reset(self, val=None, *, requires_grad=None): def _reset(self, val=None, *, requires_grad=None):
self.__sym_override = None
if val is None: if val is None:
self.__val = None self.__val = None
self.__sym = None self.__sym = None
@@ -154,17 +164,20 @@ class Tensor:
return self.numpy().item() return self.numpy().item()


def _attach(self, comp_graph, *, volatile=True): def _attach(self, comp_graph, *, volatile=True):
sym = self.__sym_override or self.__sym
if sym:
if sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return sym
if self.__val: if self.__val:
return self.__val.symvar(comp_graph, volatile=volatile) return self.__val.symvar(comp_graph, volatile=volatile)
if self.__sym:
if self.__sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return self.__sym
else: else:
raise ValueError("uninitialized") raise ValueError("uninitialized")


@property @property
def _symvar(self): def _symvar(self):
if self.__sym_override:
return self.__sym_override
if self.__sym: if self.__sym:
assert not self.__val assert not self.__val
return self.__sym return self.__sym
@@ -174,10 +187,26 @@ class Tensor:
return self._attach(get_default_graph()) return self._attach(get_default_graph())


def __mgb_symvar__(self, comp_graph=None, **_): def __mgb_symvar__(self, comp_graph=None, **_):
if self.__sym_override:
return self.__sym_override
if self.__val and comp_graph: if self.__val and comp_graph:
return self._attach(comp_graph) return self._attach(comp_graph)
return self._symvar # read by mgb.opr return self._symvar # read by mgb.opr


def _override_symvar_during_trace(self, trace, symvar):
assert self.__val and not self.__sym
assert trace is type(trace)._active_instance
deleters = trace._user_cache.setdefault(Tensor, set())
self_ref = weakref.ref(self)

def restore():
self = self_ref()
if self is not None:
self.__sym_override = None

deleters.add(Guard(restore))
self.__sym_override = symvar

@property @property
def dtype(self): def dtype(self):
r"""Return the data type of the tensor. r"""Return the data type of the tensor.


+ 4
- 1
python_module/megengine/functional/graph.py View File

@@ -13,7 +13,7 @@ import megengine._internal as mgb


from ..core.graph import get_default_graph from ..core.graph import get_default_graph
from ..core.tensor import Tensor, wrap_io_tensor from ..core.tensor import Tensor, wrap_io_tensor
from ..jit import barrier, mark_impure
from ..jit import barrier, mark_impure, trace




@wrap_io_tensor @wrap_io_tensor
@@ -112,6 +112,9 @@ def add_update(
) )
mark_impure(u) mark_impure(u)


if trace._active_instance:
dest._override_symvar_during_trace(trace._active_instance, u)

return Tensor(u) return Tensor(u)






+ 2
- 0
python_module/megengine/jit/__init__.py View File

@@ -367,10 +367,12 @@ class trace:
raise RuntimeError("nested trace is unsupported") raise RuntimeError("nested trace is unsupported")
self._status = self._STARTED self._status = self._STARTED
type(self)._active_instance = self type(self)._active_instance = self
self._user_cache = {}
try: try:
yield yield
finally: finally:
self._status = self._FINISHED self._status = self._FINISHED
self._user_cache = None
type(self)._active_instance = None type(self)._active_instance = None


def _run_wrapped(self): def _run_wrapped(self):


+ 14
- 0
python_module/test/unit/jit/test_jit.py View File

@@ -16,6 +16,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine._internal as mgb import megengine._internal as mgb
import megengine.module as M import megengine.module as M
from megengine import functional as F
from megengine import jit, tensor from megengine import jit, tensor
from megengine.core.tensor import Tensor from megengine.core.tensor import Tensor
from megengine.jit import SublinearMemoryConfig from megengine.jit import SublinearMemoryConfig
@@ -57,6 +58,19 @@ def test_symbolic():
f.trace(0) f.trace(0)




def test_add_update_semantic():
for symbolic in [False, True]:
x = tensor(0)

@jit.trace(symbolic=symbolic)
def f():
F.add_update(x, 1)
return x + 1

np.testing.assert_equal(f().numpy(), [2])
np.testing.assert_equal(f().numpy(), [3])


def test_dump(): def test_dump():
@jit.trace(symbolic=True) @jit.trace(symbolic=True)
def f(x, y): def f(x, y):


Loading…
Cancel
Save