Browse Source

fix(mge): make parampack run with tensor symbolic shape

GitOrigin-RevId: 6fc313785d
release-1.2
Megvii Engine Team 4 years ago
parent
commit
b9762d714c
3 changed files with 3 additions and 4 deletions
  1. +0
    -1
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +2
    -2
      imperative/python/megengine/distributed/helper.py
  3. +1
    -1
      imperative/python/test/integration/test_dp_correctness.py

+ 0
- 1
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -14,7 +14,6 @@ import numpy as np

from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const


+ 2
- 2
imperative/python/megengine/distributed/helper.py View File

@@ -218,7 +218,7 @@ class AllreduceCallback:
if len(self._packing_list[dtype]) == 0:
return
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
shapes = [p.shape for p in self._packing_list[dtype]]
shapes = [p._tuple_shape for p in self._packing_list[dtype]]
reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method
)
@@ -241,7 +241,7 @@ class AllreduceCallback:
dtype_str = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize
self._packing_list[dtype_str].append(param)
self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
if self._packing_size[dtype_str] > self._param_pack_thd:
self._pack(dtype_str)
return self._futures_dict[param]


+ 1
- 1
imperative/python/test/integration/test_dp_correctness.py View File

@@ -194,7 +194,7 @@ def run_test(
worker(max_err)


@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"


Loading…
Cancel
Save