|
|
@@ -16,6 +16,14 @@ from .._internal.opr import param_pack_split |
|
|
|
|
|
|
|
|
|
|
|
class ParamPack(Module): |
|
|
|
r"""Pack module's parameters |
|
|
|
|
|
|
|
:param model: the module you want to pack parameters. |
|
|
|
:param nr_ignore_first: how many parameters will be unpacked at first. |
|
|
|
:param max_size_per_group: upper bound of packed parameters' size in MB. |
|
|
|
:param max_nr_params_per_group: upper bound of the number of parameters of each group. |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self, |
|
|
|
model: Module, |
|
|
|
nr_ignore_first:int = 8, |
|
|
|