Browse Source

fix(mge/jit): fix add_update semantic

GitOrigin-RevId: f541ac7c6d
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
b3b14fdfe1
4 changed files with 53 additions and 5 deletions
  1. +33
    -4
      python_module/megengine/core/tensor.py
  2. +4
    -1
      python_module/megengine/functional/graph.py
  3. +2
    -0
      python_module/megengine/jit/__init__.py
  4. +14
    -0
      python_module/test/unit/jit/test_jit.py

+ 33
- 4
python_module/megengine/core/tensor.py View File

@@ -9,6 +9,7 @@
import collections
import functools
import itertools
import weakref
from typing import Union

import numpy as np
@@ -100,6 +101,14 @@ class MGBIndexWrapper:
)(wrap_idx(idx))


class Guard:
def __init__(self, deleter):
self.deleter = deleter

def __del__(self):
self.deleter()


class Tensor:
r"""The main data container in MegEngine.
Use :func:`~.tensor` to create a Tensor with existed data.
@@ -111,6 +120,7 @@ class Tensor:
self._reset(val, requires_grad=requires_grad)

def _reset(self, val=None, *, requires_grad=None):
self.__sym_override = None
if val is None:
self.__val = None
self.__sym = None
@@ -154,17 +164,20 @@ class Tensor:
return self.numpy().item()

def _attach(self, comp_graph, *, volatile=True):
sym = self.__sym_override or self.__sym
if sym:
if sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return sym
if self.__val:
return self.__val.symvar(comp_graph, volatile=volatile)
if self.__sym:
if self.__sym.owner_graph != comp_graph:
raise RuntimeError("internal error")
return self.__sym
else:
raise ValueError("uninitialized")

@property
def _symvar(self):
if self.__sym_override:
return self.__sym_override
if self.__sym:
assert not self.__val
return self.__sym
@@ -174,10 +187,26 @@ class Tensor:
return self._attach(get_default_graph())

def __mgb_symvar__(self, comp_graph=None, **_):
if self.__sym_override:
return self.__sym_override
if self.__val and comp_graph:
return self._attach(comp_graph)
return self._symvar # read by mgb.opr

def _override_symvar_during_trace(self, trace, symvar):
assert self.__val and not self.__sym
assert trace is type(trace)._active_instance
deleters = trace._user_cache.setdefault(Tensor, set())
self_ref = weakref.ref(self)

def restore():
self = self_ref()
if self is not None:
self.__sym_override = None

deleters.add(Guard(restore))
self.__sym_override = symvar

@property
def dtype(self):
r"""Return the data type of the tensor.


+ 4
- 1
python_module/megengine/functional/graph.py View File

@@ -13,7 +13,7 @@ import megengine._internal as mgb

from ..core.graph import get_default_graph
from ..core.tensor import Tensor, wrap_io_tensor
from ..jit import barrier, mark_impure
from ..jit import barrier, mark_impure, trace


@wrap_io_tensor
@@ -112,6 +112,9 @@ def add_update(
)
mark_impure(u)

if trace._active_instance:
dest._override_symvar_during_trace(trace._active_instance, u)

return Tensor(u)




+ 2
- 0
python_module/megengine/jit/__init__.py View File

@@ -367,10 +367,12 @@ class trace:
raise RuntimeError("nested trace is unsupported")
self._status = self._STARTED
type(self)._active_instance = self
self._user_cache = {}
try:
yield
finally:
self._status = self._FINISHED
self._user_cache = None
type(self)._active_instance = None

def _run_wrapped(self):


+ 14
- 0
python_module/test/unit/jit/test_jit.py View File

@@ -16,6 +16,7 @@ import pytest
import megengine as mge
import megengine._internal as mgb
import megengine.module as M
from megengine import functional as F
from megengine import jit, tensor
from megengine.core.tensor import Tensor
from megengine.jit import SublinearMemoryConfig
@@ -57,6 +58,19 @@ def test_symbolic():
f.trace(0)


def test_add_update_semantic():
for symbolic in [False, True]:
x = tensor(0)

@jit.trace(symbolic=symbolic)
def f():
F.add_update(x, 1)
return x + 1

np.testing.assert_equal(f().numpy(), [2])
np.testing.assert_equal(f().numpy(), [3])


def test_dump():
@jit.trace(symbolic=True)
def f(x, y):


Loading…
Cancel
Save