Browse Source

feat(mge/parampack): add user-defined key to pack params

GitOrigin-RevId: 7d51dcae23
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
99fae95e02
2 changed files with 39 additions and 7 deletions
  1. +24
    -7
      python_module/megengine/module/parampack.py
  2. +15
    -0
      python_module/test/integration/test_parampack.py

+ 24
- 7
python_module/megengine/module/parampack.py View File

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Iterable, Optional
from typing import Callable, Iterable, Optional, Tuple

import numpy as np

@@ -35,16 +35,18 @@ class ParamPack(Module):
nr_ignore_first: int = 8,
max_size_per_group: int = 10,
max_nr_params_per_group: int = 100,
group_func: Callable = lambda name, param: 0,
):
super().__init__()
self._model = model
self._nr_ignore_first = nr_ignore_first
self._max_size_per_group = max_size_per_group
self._max_nr_params_per_group = max_nr_params_per_group
self._group_func = group_func
self._grouped_params = []
self._packed_params = []

params = model.parameters()
params = model.named_parameters()
self._pack_params(params)

def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
@@ -52,20 +54,33 @@ class ParamPack(Module):
if requires_grad is None or param.requires_grad == requires_grad:
yield param

def _pack_params(self, params: Iterable[Parameter]):
def named_parameters(
    self, requires_grad: Optional[bool] = None
) -> Iterable[Tuple[str, Parameter]]:
    """Yield ``(name, parameter)`` pairs for the packed parameters.

    Each name is ``"packed_param_"`` followed by the parameter's position
    in ``self._packed_params``. The position is assigned before filtering,
    so a name stays the same whether or not other entries are filtered out.

    :param requires_grad: when ``None``, every packed parameter is yielded;
        otherwise only those whose ``requires_grad`` equals this value.
    """
    # The filter argument never changes during iteration, so decide once
    # whether filtering is needed at all.
    select_all = requires_grad is None
    for position, packed in enumerate(self._packed_params):
        if select_all or packed.requires_grad == requires_grad:
            yield "packed_param_{}".format(position), packed

def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
groups = collections.defaultdict(list)
ignored = 0
param_id = 0
for param in params:
for name, param in params:
if self._nr_ignore_first > ignored:
ignored += 1
self._grouped_params.append([{"shape": param.shape, "id": param_id}])
param.pack_group_key = self._group_func(name, param)
self._packed_params.append(param)
else:
key = (param.dtype, param.device, param.requires_grad)
key = (
param.dtype,
param.device,
param.requires_grad,
self._group_func(name, param),
)
groups[key].append({"tensor": param, "id": param_id})
param_id += 1
for (dtype, device, requires_grad) in groups.keys():
for (dtype, device, requires_grad, group_key) in groups.keys():
dtype_sz = np.dtype(dtype).itemsize
align = device.mem_align
if align < dtype_sz:
@@ -74,7 +89,7 @@ class ParamPack(Module):
assert align % dtype_sz == 0
align //= dtype_sz

group = groups[(dtype, device, requires_grad)]
group = groups[(dtype, device, requires_grad, group_key)]
while group:
aligned_pos = []
offset = 0
@@ -98,6 +113,7 @@ class ParamPack(Module):
group = group[idx:]
if idx == 1:
# ignore param packs with only one item
params[0]["tensor"].pack_group_key = group_key
self._packed_params.append(params[0]["tensor"])
self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
@@ -114,6 +130,7 @@ class ParamPack(Module):
dtype=dtype,
requires_grad=requires_grad,
)
new_param.pack_group_key = group_key
self._packed_params.append(new_param)
self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params]


+ 15
- 0
python_module/test/integration/test_parampack.py View File

@@ -257,3 +257,18 @@ def test_correctness_parampack():
pred1 = infer1(data).numpy()
pred2 = infer2(data).numpy()
assert np.allclose(pred1, pred2)


def test_parampack_group_func():
    """Check that packing with a user-supplied ``group_func`` tags parameters.

    The grouping callback splits parameters by whether their name contains
    ``"weight"``. Every trainable parameter exposed through ``parameters``
    and ``named_parameters`` must carry a ``pack_group_key`` attribute that
    is not ``None``.
    """
    packed_net = ParamPack(
        XORNet(),
        nr_ignore_first=1,
        max_size_per_group=10,
        max_nr_params_per_group=100,
        group_func=lambda name, param: "weight" in name,
    )
    for param in packed_net.parameters(requires_grad=True):
        assert param.pack_group_key is not None
    for _name, param in packed_net.named_parameters(requires_grad=True):
        assert param.pack_group_key is not None

Loading…
Cancel
Save