@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
-from typing import Iterable, Optional
+from typing import Callable, Iterable, Optional, Tuple
 
 import numpy as np
 
@@ -35,16 +35,18 @@ class ParamPack(Module):
         nr_ignore_first: int = 8,
         max_size_per_group: int = 10,
         max_nr_params_per_group: int = 100,
+        group_func: Callable = lambda name, param: 0,
     ):
         super().__init__()
         self._model = model
         self._nr_ignore_first = nr_ignore_first
         self._max_size_per_group = max_size_per_group
         self._max_nr_params_per_group = max_nr_params_per_group
+        self._group_func = group_func
         self._grouped_params = []
         self._packed_params = []
 
-        params = model.parameters()
+        params = model.named_parameters()
         self._pack_params(params)
 
     def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
@@ -52,20 +54,33 @@ class ParamPack(Module):
             if requires_grad is None or param.requires_grad == requires_grad:
                 yield param
 
-    def _pack_params(self, params: Iterable[Parameter]):
+    def named_parameters(
+        self, requires_grad: Optional[bool] = None
+    ) -> Iterable[Tuple[str, Parameter]]:
+        for idx, param in enumerate(self._packed_params):
+            if requires_grad is None or param.requires_grad == requires_grad:
+                yield "packed_param_" + str(idx), param
+
+    def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
         groups = collections.defaultdict(list)
         ignored = 0
         param_id = 0
-        for param in params:
+        for name, param in params:
             if self._nr_ignore_first > ignored:
                 ignored += 1
                 self._grouped_params.append([{"shape": param.shape, "id": param_id}])
+                param.pack_group_key = self._group_func(name, param)
                 self._packed_params.append(param)
             else:
-                key = (param.dtype, param.device, param.requires_grad)
+                key = (
+                    param.dtype,
+                    param.device,
+                    param.requires_grad,
+                    self._group_func(name, param),
+                )
                 groups[key].append({"tensor": param, "id": param_id})
             param_id += 1
-        for (dtype, device, requires_grad) in groups.keys():
+        for (dtype, device, requires_grad, group_key) in groups.keys():
             dtype_sz = np.dtype(dtype).itemsize
             align = device.mem_align
             if align < dtype_sz:
@@ -74,7 +89,7 @@ class ParamPack(Module):
                 assert align % dtype_sz == 0
                 align //= dtype_sz
 
-            group = groups[(dtype, device, requires_grad)]
+            group = groups[(dtype, device, requires_grad, group_key)]
             while group:
                 aligned_pos = []
                 offset = 0
@@ -98,6 +113,7 @@ class ParamPack(Module):
                 group = group[idx:]
                 if idx == 1:
                     # ignore param packs with only one item
+                    params[0]["tensor"].pack_group_key = group_key
                     self._packed_params.append(params[0]["tensor"])
                     self._grouped_params.append(
                         [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
@@ -114,6 +130,7 @@ class ParamPack(Module):
                     dtype=dtype,
                     requires_grad=requires_grad,
                 )
+                new_param.pack_group_key = group_key
                 self._packed_params.append(new_param)
                 self._grouped_params.append(
                     [{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
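
A minimal usage sketch of the new group_func hook, not part of this patch: it assumes ParamPack and Linear stay exported from megengine.module (placeholder imports and model) and shows how the callable's return value joins (dtype, device, requires_grad) in the packing key, so only parameters with the same group key are packed together.

    # Sketch only: import paths and the Linear placeholder are assumptions,
    # not confirmed by this diff; adjust to the actual package layout.
    from megengine.module import Linear, ParamPack

    def group_by_role(name: str, param) -> int:
        # Return any hashable value; parameters are packed only with others
        # sharing (dtype, device, requires_grad, group_func(name, param)).
        return 1 if name.endswith("bias") else 0

    net = Linear(16, 16)                  # placeholder model
    packed = ParamPack(
        net,
        nr_ignore_first=0,                # pack every parameter so grouping applies
        group_func=group_by_role,         # biases and weights land in separate packs
    )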