Browse Source

fix(mge/param_pack): release old parameters

GitOrigin-RevId: 40a1f044e9
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
f077a5292c
2 changed files with 13 additions and 5 deletions
  1. +5
    -1
      python_module/megengine/core/tensor_nn.py
  2. +8
    -4
      python_module/megengine/module/parampack.py

+ 5
- 1
python_module/megengine/core/tensor_nn.py View File

@@ -35,4 +35,8 @@ class Parameter(Tensor):
def shape(self):
r"""Return shape of parameter.
"""
return self._symvar.imm_shape
if self._Tensor__val is not None:
return self._Tensor__val.shape
elif self._Tensor__sym is not None:
return self._Tensor__sym.imm_shape
return None

+ 8
- 4
python_module/megengine/module/parampack.py View File

@@ -56,7 +56,7 @@ class ParamPack(Module):
for param in params:
if self._nr_ignore_first > ignored:
ignored += 1
self._grouped_params.append([{"tensor": param, "id": param_id}])
self._grouped_params.append([{"shape": param.shape, "id": param_id}])
self._packed_params.append(param)
else:
key = (param.dtype, param.device, param.requires_grad)
@@ -96,7 +96,9 @@ class ParamPack(Module):
if idx == 1:
# ignore param packs with only one item
self._packed_params.append(params[0]["tensor"])
self._grouped_params.append(params)
self._grouped_params.append(
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
)
continue

packed_value = np.zeros((offset,), dtype=dtype)
@@ -110,7 +112,9 @@ class ParamPack(Module):
requires_grad=requires_grad,
)
self._packed_params.append(new_param)
self._grouped_params.append(params)
self._grouped_params.append(
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
)

def forward(self, *args, **kwargs):
replace_param = dict()
@@ -120,7 +124,7 @@ class ParamPack(Module):
if len(grouped_params) == 1:
continue
split = param_pack_split(
packed_param._symvar, [i["tensor"].shape for i in grouped_params]
packed_param._symvar, [i["shape"] for i in grouped_params]
)
split = [
Parameter(Tensor(i, requires_grad=packed_param.requires_grad))


Loading…
Cancel
Save