@@ -14,6 +14,7 @@ import numpy as np

from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger
from ..utils.hook import HookHandler

logger = get_logger(__name__)

@@ -57,19 +58,51 @@ class Module(metaclass=ABCMeta):
    """

    def __init__(self):
        # runtime attributes
        self.training = True
        self.quantize_disabled = False

        # hooks
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    @abstractmethod
    def forward(self, inputs):
        pass

    def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
        """Register a hook to handle forward inputs. `hook` should be a function that
        receives `module` and `inputs`, then returns a modified `inputs` or `None`.

        Note that `inputs` contains the positional arguments only; keyword arguments
        are not passed to the hook.

        :param hook: a function that receives `module` and `inputs`, then returns
            a modified `inputs` or `None`.
        :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_pre_hooks, hook)
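
    # Usage sketch (illustrative, not part of this patch; `Identity` and `some_tensor`
    # are assumptions). A pre-hook sees the tuple of positional inputs and may return
    # a replacement tuple (or a single value, which gets wrapped into a tuple):
    #
    #     class Identity(Module):
    #         def forward(self, x):
    #             return x
    #
    #     m = Identity()
    #     handle = m.register_forward_pre_hook(
    #         lambda module, inputs: tuple(x * 2 for x in inputs)
    #     )
    #     out = m(some_tensor)  # forward now receives the doubled inputs
    #     handle.remove()       # detach the hook via the returned HookHandler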

    def register_forward_hook(self, hook: Callable) -> HookHandler:
        """Register a hook to handle forward results. `hook` should be a function that
        receives `module`, `inputs` and `outputs`, then returns a modified `outputs` or `None`.

        This method returns a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_hooks, hook)
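
    # Usage sketch (illustrative; `m` set up as in the pre-hook example above).
    # A forward hook receives (module, inputs, outputs); returning None keeps the
    # original outputs, returning a value replaces them:
    #
    #     def log_outputs(module, inputs, outputs):
    #         logger.info("outputs: %s", outputs)
    #         return None
    #
    #     handle = m.register_forward_hook(log_outputs)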

    def __call__(self, *inputs, **kwargs):
        # ToDo: Convert numpy or scalar
        # Maybe ToDo: set training phase
        # Maybe ToDo: set computing graph
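        # pre-hooks may rewrite the positional inputs; a single returned value is
        # wrapped into a tuple so forward() always receives a tuple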
        for hook in self._forward_pre_hooks.values():
            modified_inputs = hook(self, inputs)
            if modified_inputs is not None:
                if not isinstance(modified_inputs, tuple):
                    modified_inputs = (modified_inputs,)
                inputs = modified_inputs

        outputs = self.forward(*inputs, **kwargs)
        # Maybe ToDo: set connectivity metadata
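
        # forward hooks may replace the outputs; a None return keeps them unchanged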
        for hook in self._forward_hooks.values():
            modified_outputs = hook(self, inputs, outputs)
            if modified_outputs is not None:
                outputs = modified_outputs
        return outputs

    def _flatten(

@@ -191,29 +224,6 @@ class Module(metaclass=ABCMeta):
            with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
        )

    def replace_param(
        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
    ):
        offset = 0
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        for key in sorted(module_dict):
            hash_id = id(module_dict[key])
            if hash_id in seen:
                continue
            seen.add(hash_id)
            if isinstance(module_dict[key], Parameter):
                if start_pos + offset in params:
                    assert module_dict[key].shape == params[start_pos + offset].shape
                    module_dict[key] = params[start_pos + offset]
                offset += 1
            if isinstance(module_dict[key], Module):
                offset += module_dict[key].replace_param(
                    params, start_pos + offset, seen
                )
        return offset

    def named_buffers(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Buffer]]:

@@ -327,6 +337,32 @@ class Module(metaclass=ABCMeta):
        self.apply(fn)

    def replace_param(
        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
    ):
        """Replace the module's parameters with `params`, used by :class:`~.ParamPack`
        to speed up multi-machine training.
        """
        offset = 0
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        for key in sorted(module_dict):
            hash_id = id(module_dict[key])
            if hash_id in seen:
                continue
            seen.add(hash_id)
            if isinstance(module_dict[key], Parameter):
                if start_pos + offset in params:
                    assert module_dict[key].shape == params[start_pos + offset].shape
                    module_dict[key] = params[start_pos + offset]
                offset += 1
            if isinstance(module_dict[key], Module):
                offset += module_dict[key].replace_param(
                    params, start_pos + offset, seen
                )
        return offset
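
    # Note on the indexing above (descriptive, inferred from the code): `params` is
    # keyed by absolute parameter position. Positions are assigned by walking
    # attributes in sorted-name order, counting every Parameter and recursing into
    # submodules, starting from `start_pos`; the return value is the number of
    # parameter slots consumed, which lets a parent offset its remaining children.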

    def state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Returns a dictionary containing the whole state of the module.
        """