
fix(trace): link io-op to avoid deadlock

GitOrigin-RevId: 872cb6b715
Branch: release-1.1
Author: Megvii Engine Team (4 years ago)
Parent commit: 495472954d

3 changed files with 77 additions and 44 deletions:
  1. imperative/python/megengine/distributed/functional.py (+1, -1)
  2. imperative/python/megengine/jit/tracing.py (+54, -22)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+22, -21)

imperative/python/megengine/distributed/functional.py (+1, -1)

@@ -300,7 +300,7 @@ def remote_recv(
     device = get_default_device()
     # dummy input
     if inp == None:
-        inp = tensor([0])
+        inp = tensor([0], device=device)
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
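
Why this one-liner matters: the dummy input is a real input of the receive op, and `tensor([0])` without an explicit device need not land on the `device` the op is bound to. A minimal sketch of the calling pattern this supports, assuming a multi-process group launched elsewhere (the `remote_recv`/`remote_send` signatures are the ones used in this diff; `pipeline_step` and the module `m` are hypothetical):

    import numpy as np
    import megengine.distributed as dist

    def pipeline_step(m, x):
        # per-rank body of a send/recv pipeline: receive from the previous
        # rank, run the local module, forward the result to the next rank
        rank = dist.get_rank()
        if rank != 0:
            # remote_recv allocates its dummy input internally; with this fix
            # the dummy tensor lives on the receiving device
            x = dist.functional.remote_recv(rank - 1, shape=(1, 4), dtype=np.float32)
        y = m(x)
        if rank != dist.get_world_size() - 1:
            y = dist.functional.remote_send(y, dest_rank=rank + 1)
        return y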


imperative/python/megengine/jit/tracing.py (+54, -22)

@@ -18,7 +18,13 @@ import weakref
 import numpy as np

 from ..core._imperative_rt import GraphProfiler
-from ..core._imperative_rt.ops import OprAttr
+from ..core._imperative_rt.ops import (
+    CollectiveComm,
+    OprAttr,
+    RemoteRecv,
+    RemoteSend,
+    VirtualDep,
+)
 from ..core._trace_option import set_symbolic_shape
 from ..core._wrap import device as as_device
 from ..core.ops.special import Const
@@ -92,6 +98,9 @@ class TensorInfo:
         self.data_reader = None


+_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
+
+
 class trace:
     """
     Wraps a callable and provide:
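
A note on the new module-level set: `type(op) in _io_op_types` is an exact-class membership test, so subclasses of these ops would not match (unlike `isinstance`). A standalone sketch with toy classes (not the real ops):

    class RemoteSendLike:
        pass

    class SubclassedSend(RemoteSendLike):
        pass

    _io_like = {RemoteSendLike}

    assert type(RemoteSendLike()) in _io_like      # the exact class matches
    assert type(SubclassedSend()) not in _io_like  # a subclass would not
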
@@ -143,8 +152,8 @@ class trace:
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = []
-        self._lazy_eval_tensor_count = 0
+        self._lazy_eval_tensors = weakref.WeakSet()
+        self._lazy_eval_links = None
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
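
The container change carries the cleanup in `_lazy_eval` below: a `weakref.WeakSet` drops garbage-collected tensors on its own and stores duplicates once, so the old `weakref.ref`/`visited`/`None` bookkeeping becomes unnecessary. A standalone sketch of those semantics (toy class, plain Python):

    import weakref

    class FakeTensor:  # stand-in for LazyEvalTensor; any weakly referenceable object
        pass

    live, dead = FakeTensor(), FakeTensor()
    tensors = weakref.WeakSet([live, dead, live])  # the duplicate is stored once
    del dead                   # collected entries vanish (immediately on CPython)
    assert len(tensors) == 1 and live in tensors
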
@@ -286,27 +295,22 @@ class trace:
             apply.enable(apply_const_symbolic_mode)
             self._lazy_eval_graph = G.Graph()
             self._apply_graph_options(self._lazy_eval_graph)
+            self._lazy_eval_links = ()

     def _take_escaped_tensors(self):
         escaped_tensors = tuple(self._active_tensors)
         self._active_tensors.clear()
         return escaped_tensors

-    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
-        active_lazy_eval_tensors = []
-        visited = set()
-        readers = []
-        for x in lazy_eval_tensors:
-            x = x()
-            if x is None or x in visited:
-                continue
-            reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
-            readers.append(reader)
-            active_lazy_eval_tensors.append(x)
-            visited.add(x)
-        lazy_eval_graph.compile(*readers)
+    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
+        readers = [
+            G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
+            for x in lazy_eval_tensors
+        ]
+        self._apply_graph_options(lazy_eval_graph)
+        lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
-        for r, x in zip(readers, active_lazy_eval_tensors):
+        for r, x in zip(readers, lazy_eval_tensors):
             assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))

     @contextlib.contextmanager
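
The decisive line is `lazy_eval_graph.compile(*lazy_eval_links, *readers)`: a send op's output is never read by the Python program, so a compiler that keeps only nodes reachable from the requested outputs could prune the send, leaving the peer rank blocked forever in its matching recv; that is one way the deadlock in the commit title arises. A toy reachability sketch of the pruning argument (hypothetical graph structure, not the MegEngine API):

    def reachable(requested, deps):
        # nodes a compiler keeps when walking back from the requested outputs
        seen, stack = set(), list(requested)
        while stack:
            node = stack.pop()
            if node not in seen:
                seen.add(node)
                stack.extend(deps.get(node, ()))
        return seen

    deps = {"loss": ["matmul"], "send": ["matmul"]}     # the send result is unread
    assert "send" not in reachable(["loss"], deps)      # pruned: the peer deadlocks
    assert "send" in reachable(["loss", "send"], deps)  # kept once linked in
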
@@ -333,11 +337,18 @@ class trace:
                 if self._inputs_to_restore:
                     for x in self._inputs_to_restore:
                         x._TraceMixin__restore()
-                if self._symbolic and self._lazy_eval_tensors:
+                if self._symbolic and (
+                    self._lazy_eval_tensors or self._lazy_eval_links
+                ):
                     # eval lazy eval tensors
-                    self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
+                    self._lazy_eval(
+                        self._lazy_eval_graph,
+                        tuple(self._lazy_eval_tensors),
+                        self._lazy_eval_links,
+                    )
                     self._lazy_eval_graph = None
                     self._lazy_eval_tensors = None
+                    self._lazy_eval_links = None
                 self._untraced = False
             else:
                 # compiled_tensor leaks
@@ -438,8 +449,10 @@ class trace:
                 links += opnode.outputs[1:]

         for op, ihandles, ohandles in self._seq:
+            require_links = type(op) in _io_op_types
+
             ivars = []
-            for h in ihandles:
+            for i, h in enumerate(ihandles):
                 info = self._tinfo[h]
                 if not hasattr(info, "varnode"):
                     assert info.external
@@ -455,9 +468,14 @@ class trace:
) )
need_reset_nodes.append(opnode) need_reset_nodes.append(opnode)
info.varnode, *links = opnode.outputs info.varnode, *links = opnode.outputs
if require_links and i == 0 and len(links) > 0:
info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
links = (info.varnode,)


ivars.append(info.varnode) ivars.append(info.varnode)
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
if require_links and len(ovars) > 0:
links = (ovars[0],)
assert len(ovars) == len(ohandles) assert len(ovars) == len(ohandles)
for h, v in zip(ohandles, ovars): for h, v in zip(ohandles, ovars):
info = self._tinfo[h] info = self._tinfo[h]
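
Taken together, these two hunks thread a `links` tuple through the recorded op sequence: an io-op's first input is rebased onto the current links with `VirtualDep` (which forwards the value but adds control dependencies), and the io-op's first output then becomes the next link. The result is a happens-before chain that replays io-ops in exactly the order they were recorded, on every rank. Reduced to plain Python (toy stand-ins for graph vars):

    def chain_io_ops(io_ops):
        # returns the happens-before edges the VirtualDep linking creates
        links, edges = (), []
        for op in io_ops:
            if links:
                edges.append((links[0], op))  # op waits for the previous io-op
            links = (op,)                     # its output is the next link
        return edges

    assert chain_io_ops(["send0", "recv1", "send2"]) == [
        ("send0", "recv1"),
        ("recv1", "send2"),
    ]
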
@@ -502,6 +520,8 @@ class trace:
                 info.data_read = True

     def __call__(self, *args, **kwargs):
+        if is_tracing():
+            return self.__wrapped__(*args, **kwargs)
         with self._setup():
             if self._capture_as_const:
                 self._process_inputs(*args, **kwargs)
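
The early-out makes `trace` re-entrant: a traced function invoked while another trace is already recording runs its wrapped body inline rather than opening a nested trace, which is what lets the reworked test below wrap an entire training step (including distributed helpers) in one `@trace`. A toy version of the guard (hypothetical names; a module flag stands in for `is_tracing()`):

    _tracing = False  # stand-in for megengine's is_tracing()

    def trace_like(fn):
        def wrapper(*args, **kwargs):
            global _tracing
            if _tracing:                    # nested call: just inline the body
                return fn(*args, **kwargs)
            _tracing = True                 # the outermost call owns the trace
            try:
                return fn(*args, **kwargs)
            finally:
                _tracing = False
        return wrapper
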
@@ -938,9 +958,21 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
         or graph.make_const(x._dev_tensor())
         for x in args
     ]
+
+    require_links = type(op) in _io_op_types
+
+    if require_links and active_trace._lazy_eval_links:
+        assert len(ivars) > 0, "op should has at least one input"
+        ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
+        active_trace._lazy_eval_links = (ivars[0],)
+
     ovars = apply(op, *ivars)
+
+    if require_links:
+        active_trace._lazy_eval_links = (ovars[0],)
+
     outputs = [LazyEvalTensor(v) for v in ovars]
-    active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
+    active_trace._lazy_eval_tensors.update(outputs)
     return outputs
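
`apply_symbolic_mode` mirrors the compile-path discipline while tracing lazily: before an io-op fires, its first input is hung off the previous link via `VirtualDep`, and afterwards its first output replaces the link, so even a send whose Python result is discarded stays anchored in `_lazy_eval_links` and survives compilation. A hedged usage sketch (a two-rank launch is assumed to exist elsewhere; the signatures are the ones used in this diff):

    import numpy as np
    import megengine.distributed as dist
    from megengine.jit import trace

    @trace(symbolic=True)
    def step(x):
        if dist.get_rank() == 0:
            # result intentionally unused; the link chain keeps the op alive
            dist.functional.remote_send(x, dest_rank=1)
            return x
        return dist.functional.remote_recv(0, shape=(1, 4), dtype=np.float32)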

@@ -951,7 +983,7 @@ apply.disable(apply_symbolic_mode)

 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
     ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
-    active_trace._lazy_eval_tensors.append(weakref.ref(ret))
+    active_trace._lazy_eval_tensors.add(ret)
     return (ret,)

imperative/python/test/unit/autodiff/test_grad_manger.py (+22, -21)

@@ -18,6 +18,7 @@ import megengine.optimizer as optim
 from megengine.autodiff import GradManager
 from megengine.core._imperative_rt.imperative import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.jit import trace


 def test_basic():
@@ -75,27 +76,27 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

+        @trace(symbolic=True)
         def train_func(x):
-            if rank != 0:
-                x = dist.functional.remote_recv(
-                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
-                )
-            print(rank, "x", x)
-            y = m(x)
-            print(rank, "y", y)
-            if rank != size - 1:
-                y = dist.functional.remote_send(y, dest_rank=rank + 1)
-            return y
-
-        with gm:
-            y = train_func(x)
-            if rank == size - 1:
-                y = y.mean()
-                gm.backward(y)
-            else:
-                gm.backward()
-        opt.step().clear_grad()
-        # sync because send is the last job
-        sync()
+            with gm:
+                if rank != 0:
+                    x = dist.functional.remote_recv(
+                        rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                    )
+                y = m(x)
+                if rank != size - 1:
+                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
+                if rank == size - 1:
+                    y = y.mean()
+                    gm.backward(y)
+                else:
+                    gm.backward()
+            opt.step().clear_grad()
+
+        for i in range(3):
+            train_func(x)
+
+        for param in m.parameters():
+            param.numpy()

     worker()
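
The reworked test exercises exactly this machinery: the whole receive/forward/backward/step sequence runs under `@trace(symbolic=True)`, is called three times (traced once, replayed twice), and the final `param.numpy()` reads force pending device work to finish before the worker exits. A reduced single-process analogue of the call pattern (toy function, no distributed setup):

    import megengine
    import megengine.functional as F
    from megengine.jit import trace

    @trace(symbolic=True)
    def step(x):
        return F.relu(x) * 2

    x = megengine.tensor([[-1.0, 2.0]])
    for _ in range(3):
        y = step(x)   # the first call records the graph, later calls replay it
    print(y.numpy())  # reading the result synchronizes with the device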
