Browse Source

feat(mge/parampack): add user-defined key to pack params

GitOrigin-RevId: 7d51dcae23
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
99fae95e02
2 changed files with 39 additions and 7 deletions
  1. +24
    -7
      python_module/megengine/module/parampack.py
  2. +15
    -0
      python_module/test/integration/test_parampack.py

+ 24
- 7
python_module/megengine/module/parampack.py View File

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
from typing import Iterable, Optional
from typing import Callable, Iterable, Optional, Tuple


import numpy as np import numpy as np


@@ -35,16 +35,18 @@ class ParamPack(Module):
nr_ignore_first: int = 8, nr_ignore_first: int = 8,
max_size_per_group: int = 10, max_size_per_group: int = 10,
max_nr_params_per_group: int = 100, max_nr_params_per_group: int = 100,
group_func: Callable = lambda name, param: 0,
): ):
super().__init__() super().__init__()
self._model = model self._model = model
self._nr_ignore_first = nr_ignore_first self._nr_ignore_first = nr_ignore_first
self._max_size_per_group = max_size_per_group self._max_size_per_group = max_size_per_group
self._max_nr_params_per_group = max_nr_params_per_group self._max_nr_params_per_group = max_nr_params_per_group
self._group_func = group_func
self._grouped_params = [] self._grouped_params = []
self._packed_params = [] self._packed_params = []


params = model.parameters()
params = model.named_parameters()
self._pack_params(params) self._pack_params(params)


def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
@@ -52,20 +54,33 @@ class ParamPack(Module):
if requires_grad is None or param.requires_grad == requires_grad: if requires_grad is None or param.requires_grad == requires_grad:
yield param yield param


def _pack_params(self, params: Iterable[Parameter]):
def named_parameters(
    self, requires_grad: Optional[bool] = None
) -> Iterable[Tuple[str, Parameter]]:
    """Yield ``(name, parameter)`` pairs for all packed parameters.

    Names are synthesized as ``"packed_param_<idx>"`` from the parameter's
    position in ``self._packed_params``; the original parameter names from
    the wrapped model are not preserved.

    :param requires_grad: if not ``None``, only yield parameters whose
        ``requires_grad`` flag equals this value; if ``None``, yield all
        packed parameters.
    """
    for idx, param in enumerate(self._packed_params):
        if requires_grad is None or param.requires_grad == requires_grad:
            yield "packed_param_" + str(idx), param

def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
groups = collections.defaultdict(list) groups = collections.defaultdict(list)
ignored = 0 ignored = 0
param_id = 0 param_id = 0
for param in params:
for name, param in params:
if self._nr_ignore_first > ignored: if self._nr_ignore_first > ignored:
ignored += 1 ignored += 1
self._grouped_params.append([{"shape": param.shape, "id": param_id}]) self._grouped_params.append([{"shape": param.shape, "id": param_id}])
param.pack_group_key = self._group_func(name, param)
self._packed_params.append(param) self._packed_params.append(param)
else: else:
key = (param.dtype, param.device, param.requires_grad)
key = (
param.dtype,
param.device,
param.requires_grad,
self._group_func(name, param),
)
groups[key].append({"tensor": param, "id": param_id}) groups[key].append({"tensor": param, "id": param_id})
param_id += 1 param_id += 1
for (dtype, device, requires_grad) in groups.keys():
for (dtype, device, requires_grad, group_key) in groups.keys():
dtype_sz = np.dtype(dtype).itemsize dtype_sz = np.dtype(dtype).itemsize
align = device.mem_align align = device.mem_align
if align < dtype_sz: if align < dtype_sz:
@@ -74,7 +89,7 @@ class ParamPack(Module):
assert align % dtype_sz == 0 assert align % dtype_sz == 0
align //= dtype_sz align //= dtype_sz


group = groups[(dtype, device, requires_grad)]
group = groups[(dtype, device, requires_grad, group_key)]
while group: while group:
aligned_pos = [] aligned_pos = []
offset = 0 offset = 0
@@ -98,6 +113,7 @@ class ParamPack(Module):
group = group[idx:] group = group[idx:]
if idx == 1: if idx == 1:
# ignore param packs with only one item # ignore param packs with only one item
params[0]["tensor"].pack_group_key = group_key
self._packed_params.append(params[0]["tensor"]) self._packed_params.append(params[0]["tensor"])
self._grouped_params.append( self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
@@ -114,6 +130,7 @@ class ParamPack(Module):
dtype=dtype, dtype=dtype,
requires_grad=requires_grad, requires_grad=requires_grad,
) )
new_param.pack_group_key = group_key
self._packed_params.append(new_param) self._packed_params.append(new_param)
self._grouped_params.append( self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params] [{"shape": i["tensor"].shape, "id": i["id"]} for i in params]


+ 15
- 0
python_module/test/integration/test_parampack.py View File

@@ -257,3 +257,18 @@ def test_correctness_parampack():
pred1 = infer1(data).numpy() pred1 = infer1(data).numpy()
pred2 = infer2(data).numpy() pred2 = infer2(data).numpy()
assert np.allclose(pred1, pred2) assert np.allclose(pred1, pred2)


def test_parampack_group_func():
    """Every parameter exposed by a ParamPack built with a custom
    ``group_func`` should carry a ``pack_group_key`` attribute, whether
    reached via ``parameters()`` or ``named_parameters()``."""
    packed = ParamPack(
        XORNet(),
        nr_ignore_first=1,
        max_size_per_group=10,
        max_nr_params_per_group=100,
        group_func=lambda name, param: "weight" in name,
    )
    for param in packed.parameters(requires_grad=True):
        assert param.pack_group_key is not None
    for _, param in packed.named_parameters(requires_grad=True):
        assert param.pack_group_key is not None

Loading…
Cancel
Save