Browse Source

fix(mge/param_pack): release old parameters

GitOrigin-RevId: 40a1f044e9
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
f077a5292c
2 changed files with 13 additions and 5 deletions
  1. +5
    -1
      python_module/megengine/core/tensor_nn.py
  2. +8
    -4
      python_module/megengine/module/parampack.py

+ 5
- 1
python_module/megengine/core/tensor_nn.py View File

@@ -35,4 +35,8 @@ class Parameter(Tensor):
def shape(self):
r"""Return shape of parameter.
"""
return self._symvar.imm_shape
if self._Tensor__val is not None:
return self._Tensor__val.shape
elif self._Tensor__sym is not None:
return self._Tensor__sym.imm_shape
return None

+ 8
- 4
python_module/megengine/module/parampack.py View File

@@ -56,7 +56,7 @@ class ParamPack(Module):
for param in params:
if self._nr_ignore_first > ignored:
ignored += 1
self._grouped_params.append([{"tensor": param, "id": param_id}])
self._grouped_params.append([{"shape": param.shape, "id": param_id}])
self._packed_params.append(param)
else:
key = (param.dtype, param.device, param.requires_grad)
@@ -96,7 +96,9 @@ class ParamPack(Module):
if idx == 1:
# ignore param packs with only one item
self._packed_params.append(params[0]["tensor"])
self._grouped_params.append(params)
self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
)
continue


packed_value = np.zeros((offset,), dtype=dtype)
@@ -110,7 +112,9 @@ class ParamPack(Module):
requires_grad=requires_grad,
)
self._packed_params.append(new_param)
self._grouped_params.append(params)
self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
)


def forward(self, *args, **kwargs):
replace_param = dict()
@@ -120,7 +124,7 @@ class ParamPack(Module):
if len(grouped_params) == 1:
continue
split = param_pack_split(
packed_param._symvar, [i["tensor"].shape for i in grouped_params]
packed_param._symvar, [i["shape"] for i in grouped_params]
)
split = [
Parameter(Tensor(i, requires_grad=packed_param.requires_grad))


Loading…
Cancel
Save