|
@@ -9,6 +9,7 @@ |
|
|
import collections |
|
|
import collections |
|
|
import functools |
|
|
import functools |
|
|
import itertools |
|
|
import itertools |
|
|
|
|
|
import weakref |
|
|
from typing import Union |
|
|
from typing import Union |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
@@ -100,6 +101,14 @@ class MGBIndexWrapper: |
|
|
)(wrap_idx(idx)) |
|
|
)(wrap_idx(idx)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Guard: |
|
|
|
|
|
def __init__(self, deleter): |
|
|
|
|
|
self.deleter = deleter |
|
|
|
|
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
|
|
self.deleter() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Tensor: |
|
|
class Tensor: |
|
|
r"""The main data container in MegEngine. |
|
|
r"""The main data container in MegEngine. |
|
|
Use :func:`~.tensor` to create a Tensor with existed data. |
|
|
Use :func:`~.tensor` to create a Tensor with existed data. |
|
@@ -111,6 +120,7 @@ class Tensor: |
|
|
self._reset(val, requires_grad=requires_grad) |
|
|
self._reset(val, requires_grad=requires_grad) |
|
|
|
|
|
|
|
|
def _reset(self, val=None, *, requires_grad=None): |
|
|
def _reset(self, val=None, *, requires_grad=None): |
|
|
|
|
|
self.__sym_override = None |
|
|
if val is None: |
|
|
if val is None: |
|
|
self.__val = None |
|
|
self.__val = None |
|
|
self.__sym = None |
|
|
self.__sym = None |
|
@@ -154,17 +164,20 @@ class Tensor: |
|
|
return self.numpy().item() |
|
|
return self.numpy().item() |
|
|
|
|
|
|
|
|
def _attach(self, comp_graph, *, volatile=True): |
|
|
def _attach(self, comp_graph, *, volatile=True): |
|
|
|
|
|
sym = self.__sym_override or self.__sym |
|
|
|
|
|
if sym: |
|
|
|
|
|
if sym.owner_graph != comp_graph: |
|
|
|
|
|
raise RuntimeError("internal error") |
|
|
|
|
|
return sym |
|
|
if self.__val: |
|
|
if self.__val: |
|
|
return self.__val.symvar(comp_graph, volatile=volatile) |
|
|
return self.__val.symvar(comp_graph, volatile=volatile) |
|
|
if self.__sym: |
|
|
|
|
|
if self.__sym.owner_graph != comp_graph: |
|
|
|
|
|
raise RuntimeError("internal error") |
|
|
|
|
|
return self.__sym |
|
|
|
|
|
else: |
|
|
else: |
|
|
raise ValueError("uninitialized") |
|
|
raise ValueError("uninitialized") |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def _symvar(self): |
|
|
def _symvar(self): |
|
|
|
|
|
if self.__sym_override: |
|
|
|
|
|
return self.__sym_override |
|
|
if self.__sym: |
|
|
if self.__sym: |
|
|
assert not self.__val |
|
|
assert not self.__val |
|
|
return self.__sym |
|
|
return self.__sym |
|
@@ -174,10 +187,26 @@ class Tensor: |
|
|
return self._attach(get_default_graph()) |
|
|
return self._attach(get_default_graph()) |
|
|
|
|
|
|
|
|
def __mgb_symvar__(self, comp_graph=None, **_): |
|
|
def __mgb_symvar__(self, comp_graph=None, **_): |
|
|
|
|
|
if self.__sym_override: |
|
|
|
|
|
return self.__sym_override |
|
|
if self.__val and comp_graph: |
|
|
if self.__val and comp_graph: |
|
|
return self._attach(comp_graph) |
|
|
return self._attach(comp_graph) |
|
|
return self._symvar # read by mgb.opr |
|
|
return self._symvar # read by mgb.opr |
|
|
|
|
|
|
|
|
|
|
|
def _override_symvar_during_trace(self, trace, symvar): |
|
|
|
|
|
assert self.__val and not self.__sym |
|
|
|
|
|
assert trace is type(trace)._active_instance |
|
|
|
|
|
deleters = trace._user_cache.setdefault(Tensor, set()) |
|
|
|
|
|
self_ref = weakref.ref(self) |
|
|
|
|
|
|
|
|
|
|
|
def restore(): |
|
|
|
|
|
self = self_ref() |
|
|
|
|
|
if self is not None: |
|
|
|
|
|
self.__sym_override = None |
|
|
|
|
|
|
|
|
|
|
|
deleters.add(Guard(restore)) |
|
|
|
|
|
self.__sym_override = symvar |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def dtype(self): |
|
|
def dtype(self): |
|
|
r"""Return the data type of the tensor. |
|
|
r"""Return the data type of the tensor. |
|
|