Browse Source

feat(mge/module): add forward hook support

GitOrigin-RevId: c0db58df13
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
13e8f00a37
3 changed files with 139 additions and 28 deletions
  1. +63
    -27
      python_module/megengine/module/module.py
  2. +23
    -0
      python_module/megengine/utils/hook.py
  3. +53
    -1
      python_module/test/unit/module/test_module.py

+ 63
- 27
python_module/megengine/module/module.py View File

@@ -14,6 +14,7 @@ import numpy as np
from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger
from ..utils.hook import HookHandler

logger = get_logger(__name__)

@@ -57,19 +58,51 @@ class Module(metaclass=ABCMeta):
"""

def __init__(self):
# Set up the default runtime flags and the empty hook registries.
# runtime attributes
self.training = True
# NOTE(review): "diabled" looks like a misspelling of "disabled"; the name is
# left untouched because code outside this view may read this attribute.
self.quantize_diabled = False

# hooks
# Ordered registries handed to HookHandler by register_forward_pre_hook /
# register_forward_hook; __call__ walks their values around forward().
self._forward_pre_hooks = OrderedDict()
self._forward_hooks = OrderedDict()

@abstractmethod
def forward(self, inputs):
    """Compute the outputs of the module; this base version does nothing.

    Declared abstract, so subclasses are expected to provide the real
    implementation.
    """

def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
    """Register a hook that handles the forward inputs before ``forward`` runs.

    The hook is called as ``hook(module, inputs)``, where `inputs` is the tuple
    of positional arguments the module was called with. Note that keyword
    arguments are not part of `inputs` and are not passed to the hook.

    :param hook: a function that receives `module` and `inputs`, then returns
        a modified `inputs` or `None`.
    :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
    """
    handler = HookHandler(self._forward_pre_hooks, hook)
    return handler

def register_forward_hook(self, hook: Callable) -> HookHandler:
    """Register a hook that handles the forward results after ``forward`` runs.

    :param hook: a function that receives `module`, `inputs` and `outputs`,
        then returns a modified `outputs` or `None`.
    :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
    """
    handler = HookHandler(self._forward_hooks, hook)
    return handler

def __call__(self, *inputs, **kwargs):
    """Run the forward pre-hooks, ``forward``, and then the forward hooks.

    Each pre-hook is called as ``hook(self, inputs)`` with the current tuple of
    positional inputs; a non-``None`` result replaces the inputs (a non-tuple
    result is wrapped into a one-element tuple). Keyword arguments are passed
    to ``forward`` unchanged and are never shown to the hooks.

    Each forward hook is called as ``hook(self, inputs, outputs)``; a
    non-``None`` result replaces the outputs.

    Both hook registries are snapshotted before iteration, so a hook may
    remove itself or register further hooks while it runs. Hooks added during
    a call first take effect on the next call.

    :return: the outputs of ``forward`` after all forward hooks were applied.
    """
    # ToDo: Convert numpy or scalar
    # Maybe ToDo: set training phase
    # Maybe ToDo: set computing graph

    # Iterate over a snapshot: mutating the live OrderedDict from inside a hook
    # would otherwise raise "OrderedDict mutated during iteration".
    for hook in tuple(self._forward_pre_hooks.values()):
        modified_inputs = hook(self, inputs)
        if modified_inputs is not None:
            if not isinstance(modified_inputs, tuple):
                modified_inputs = (modified_inputs,)
            inputs = modified_inputs

    outputs = self.forward(*inputs, **kwargs)
    # Maybe ToDo: set connectivity metadata

    for hook in tuple(self._forward_hooks.values()):
        modified_outputs = hook(self, inputs, outputs)
        if modified_outputs is not None:
            outputs = modified_outputs
    return outputs

def _flatten(
@@ -191,29 +224,6 @@ class Module(metaclass=ABCMeta):
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
)

def replace_param(
    self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
):
    """Swap this module's parameters for the entries of `params`.

    Attributes are visited in sorted name order. Every ``Parameter`` takes the
    next position, counted from `start_pos`; when that position is a key of
    `params`, the attribute is replaced by ``params[position]`` after a shape
    check. Attributes that are modules are processed recursively.

    :param params: mapping from position to the replacement parameter.
    :param start_pos: position assigned to the first parameter found here.
    :param seen: ids of objects already visited, so shared objects are
        handled only once.
    :return: the number of positions consumed by this module and its children.
    """
    if seen is None:
        seen = {id(self)}
    attrs = vars(self)
    offset = 0
    for name in sorted(attrs):
        value = attrs[name]
        value_id = id(value)
        if value_id in seen:
            continue
        seen.add(value_id)
        if isinstance(value, Parameter):
            position = start_pos + offset
            if position in params:
                assert value.shape == params[position].shape
                value = attrs[name] = params[position]
            offset += 1
        if isinstance(value, Module):
            offset += value.replace_param(params, start_pos + offset, seen)
    return offset

def named_buffers(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Buffer]]:
@@ -327,6 +337,32 @@ class Module(metaclass=ABCMeta):

self.apply(fn)

def replace_param(
    self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
):
    """Replace module's parameters with `params`, used by :class:`~.ParamPack` to
    speedup multimachine training.

    Attributes are visited in sorted name order. Every ``Parameter`` takes the
    next position, counted from `start_pos`; when that position is a key of
    `params`, the attribute is replaced by ``params[position]`` after a shape
    check. Attributes that are modules are processed recursively.

    :param params: mapping from position to the replacement parameter.
    :param start_pos: position assigned to the first parameter found here.
    :param seen: ids of objects already visited, so shared objects are
        handled only once.
    :return: the number of positions consumed by this module and its children.
    """
    if seen is None:
        seen = {id(self)}
    attrs = vars(self)
    offset = 0
    for name in sorted(attrs):
        value = attrs[name]
        value_id = id(value)
        if value_id in seen:
            continue
        seen.add(value_id)
        if isinstance(value, Parameter):
            position = start_pos + offset
            if position in params:
                assert value.shape == params[position].shape
                value = attrs[name] = params[position]
            offset += 1
        if isinstance(value, Module):
            offset += value.replace_param(params, start_pos + offset, seen)
    return offset

def state_dict(self, rst=None, prefix="", keep_var=False):
r"""Returns a dictionary containing whole states of the module.
"""


+ 23
- 0
python_module/megengine/utils/hook.py View File

@@ -0,0 +1,23 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref


class HookHandler:
    """Handle for one hook stored in a registry dict.

    Creating a handler inserts the hook into the registry under a fresh
    integer key; :meth:`remove` deletes that entry again.
    """

    # Counter shared by all handlers; its current value is the next key.
    hook_num = 0

    def __init__(self, source_dict, hook):
        """Store `hook` in `source_dict` under a newly allocated id.

        :param source_dict: registry to insert the hook into; it must support
            weak references, because only a weak reference to it is kept.
        :param hook: the object to register.
        """
        key = HookHandler.hook_num
        HookHandler.hook_num = key + 1
        self.id = key
        source_dict[key] = hook
        # Hold the registry weakly so the handler does not keep it alive.
        self.source_ref = weakref.ref(source_dict)

    def remove(self):
        """Delete the hook from its registry; do nothing if it is already gone."""
        registry = self.source_ref()
        if registry is None:
            return
        registry.pop(self.id, None)

+ 53
- 1
python_module/test/unit/module/test_module.py View File

@@ -17,6 +17,7 @@ from helpers import MLP

import megengine as mge
import megengine._internal as mgb
import megengine.functional as F
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import (
BatchNorm1d,
@@ -37,7 +38,7 @@ class MyModule(Module):
self.bn = BatchNorm2d(4)

def forward(self, x):
x = self.bn(x)
return self.bn(x)

def __init__(self):
super().__init__()
@@ -145,6 +146,57 @@ def test_module_api_iterable_stability():
assert list(m.modules()) == l


def test_module_api_hooks():
# Checks that forward pre-hooks and forward hooks are called, that their
# return values replace the inputs/outputs, and that removed hooks no longer run.
net = MyModule()
pre_hook_num = 0
post_hook_num = 0
hooks = []

def pre_hook(module, inputs):
# Counts the call and returns every positional input shifted by +1.
nonlocal pre_hook_num
pre_hook_num += 1
modified_inputs = tuple(inp + 1 for inp in inputs)
return modified_inputs

def post_hook(module, inputs, outputs):
# Counts the call and returns the outputs after `outputs += 1`.
nonlocal post_hook_num
post_hook_num += 1
outputs += 1
return outputs

# Register both hooks through net.apply and keep every returned handler
# so the hooks can be removed again below.
net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))

shape = (1, 4, 1, 1)
x = tensor(np.zeros(shape, dtype=np.float32))
y = net(x)

# One forward pass is expected to trigger each kind of hook four times.
assert pre_hook_num == 4
assert post_hook_num == 4
# Rebuild the expected result by hand with functional batch norm.
# NOTE(review): the +3 / +3 / +2 offsets presumably come from the +1 hooks
# stacking across nested modules -- confirm against MyModule's full
# definition, which is not visible here.
mean1 = Parameter(np.zeros(shape), dtype=np.float32)
bn1 = F.batch_norm2d(
x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
)
assertTensorClose(
net.i.bn.running_mean, mean1,
)
mean2 = Parameter(np.zeros(shape), dtype=np.float32)
bn2 = F.batch_norm2d(
bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
)
assertTensorClose(
net.bn.running_mean, mean2,
)
assertTensorClose(bn2 + 2, y)

# Eight handlers were collected; after removing all of them a second
# forward pass must leave both counters unchanged.
assert len(hooks) == 8
for handler in hooks:
handler.remove()
y = net(x)
assert pre_hook_num == 4
assert post_hook_num == 4


class MyModule2(Module):
class InnerModule(Module):
def __init__(self):


Loading…
Cancel
Save