Browse Source

feat(mge): restore Function

GitOrigin-RevId: dd455238ba
release-1.2
Megvii Engine Team 4 years ago
parent
commit
b5e46ae92f
7 changed files with 35 additions and 13 deletions
  1. +6
    -5
      imperative/python/megengine/__init__.py
  2. +23
    -1
      imperative/python/megengine/core/autodiff/grad.py
  3. +1
    -1
      imperative/python/megengine/quantization/fake_quant.py
  4. +1
    -1
      imperative/python/megengine/quantization/internal_fake_quant.py
  5. +1
    -1
      imperative/python/megengine/quantization/utils.py
  6. +2
    -2
      imperative/python/test/unit/core/test_function.py
  7. +1
    -2
      imperative/python/test/unit/quantization/test_fake_quant.py

+ 6
- 5
imperative/python/megengine/__init__.py View File

@@ -6,11 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import atexit
import ctypes
import os
import sys
import platform
import ctypes
import atexit
import sys

if sys.platform == "win32":
lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
@@ -71,14 +71,15 @@ if sys.platform == "win32":

kernel32.SetErrorMode(old_error_mode)

from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .core._imperative_rt.core2 import sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .utils import comp_graph_tools as cgtools
from .utils import persistent_cache
from .version import __version__
from .utils import persistent_cache, comp_graph_tools as cgtools

_set_fork_exec_path_for_timed_func(
sys.executable,


+ 23
- 1
imperative/python/megengine/core/autodiff/grad.py View File

@@ -16,7 +16,7 @@ import numpy as np

import megengine as mge

from .._imperative_rt import core2
from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -211,3 +211,25 @@ class Grad:

def __exit__(self, _1, _2, _3):
del self._impl


class Function(ops.PyOpBase):
def _default_rule(self, *args):
ret = self.forward(*args)
self.__single_output = isinstance(ret, core2.Tensor)
return ret

def _grad_rule(self, *args):
return self._default_rule(*args), self.backward

def __call__(self, *args):
ret = core2.apply(self, *args)
if self.__single_output:
(ret,) = ret
return ret

def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__.update(state)

+ 1
- 1
imperative/python/megengine/quantization/fake_quant.py View File

@@ -11,8 +11,8 @@ from typing import Iterable
import numpy as np

from .. import functional as F
from ..core.autodiff.grad import Function
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.function import Function
from ..module import Module
from ..tensor import Parameter, Tensor
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict


+ 1
- 1
imperative/python/megengine/quantization/internal_fake_quant.py View File

@@ -12,7 +12,7 @@ from functools import partial
import numpy as np

from .. import functional as F
from ..core.tensor.function import Function
from ..core.autodiff.grad import Function
from .fake_quant import _FakeQuantize
from .observer import MinMaxObserver
from .qconfig import QConfig


+ 1
- 1
imperative/python/megengine/quantization/utils.py View File

@@ -12,11 +12,11 @@ from typing import Dict
import numpy as np

from .. import functional as F
from ..core.autodiff.grad import Function
from ..core.ops import builtin
from ..core.tensor import megbrain_graph
from ..core.tensor.core import apply
from ..core.tensor.dtype import _metadata_dict
from ..core.tensor.function import Function
from ..tensor import Tensor




+ 2
- 2
imperative/python/test/unit/core/test_function.py View File

@@ -15,7 +15,7 @@ import megengine.optimizer as optimizer
from megengine import Parameter
from megengine import Tensor as tensor
from megengine import tensor
from megengine.core.tensor.function import Function
from megengine.core.autodiff.grad import Function
from megengine.module import Module


@@ -239,7 +239,7 @@ def test_none_in_out_grad():

def backward(self, grad_a, grad_b):
assert grad_b is None
return (grad_a, 0.0)
return (grad_a, None)

class Simple(Module):
def __init__(self, a, b):


+ 1
- 2
imperative/python/test/unit/quantization/test_fake_quant.py View File

@@ -11,8 +11,7 @@ import pytest

import megengine as mge
from megengine import tensor
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.function import Function
from megengine.core.autodiff.grad import Function, Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.quantization.fake_quant import TQT_Function
from megengine.quantization.internal_fake_quant import *


Loading…
Cancel
Save