Browse Source

perf(dist): speed up bcast_val

GitOrigin-RevId: 21c4123b09
release-1.4
Megvii Engine Team 4 years ago
parent
commit
6964576e89
1 changed file with 28 additions and 12 deletions
  1. +28
    -12
      imperative/python/megengine/distributed/server.py

+ 28
- 12
imperative/python/megengine/distributed/server.py View File

@@ -36,6 +36,7 @@ class Methods:
self.dict_barrier_counter = defaultdict(int)
self.dict_barrier_event = defaultdict(threading.Event)
self.user_dict = defaultdict(partial(Future, False))
self.bcast_dict = {}

def connect(self):
"""Method for checking connection success."""
@@ -127,6 +128,23 @@ class Methods:
future = self.user_dict[key]
return future.get()

def bcast_val(self, val, key, size):
    """Share one value among the ``size`` callers that use the same ``key``.

    A caller passing a non-``None`` ``val`` stores it in the slot's future
    and gets ``None`` back. A caller passing ``None`` returns whatever the
    future's ``get()`` yields (presumably blocking until the value has been
    set -- ``Future`` is defined elsewhere; confirm).

    :param val: value to publish, or ``None`` to receive the published value.
    :param key: identifier of this broadcast round.
    :param size: number of callers expected for this ``key``; used as the
        initial countdown by whichever caller creates the slot.
    :return: ``None`` for the publishing caller, the published value otherwise.
    """
    # First arrival creates the [future, remaining-callers] slot; everyone
    # else picks up the existing one.
    with self.lock:
        slot = self.bcast_dict.get(key)
        if slot is None:
            slot = [Future(False), size]
            self.bcast_dict[key] = slot

    # The future is used outside the lock so a waiting receiver does not
    # hold ``self.lock`` while the sender is trying to register.
    shared = slot[0]
    if val is None:
        result = shared.get()
    else:
        shared.set(val)
        result = None

    # Count this caller out; the last one to leave drops the slot.
    with self.lock:
        slot[1] -= 1
        if slot[1] == 0:
            del self.bcast_dict[key]
    return result


class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    """``SimpleXMLRPCServer`` combined with ``ThreadingMixIn``.

    Adds no behaviour of its own; it only mixes the two base classes.
    """
@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue):
"""
try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server = ThreadXMLRPCServer(
("0.0.0.0", py_server_port), logRequests=False, allow_none=True
)
server.register_instance(Methods(mm_server_port))
_, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port))
@@ -185,13 +205,14 @@ class Client:
self.master_ip = master_ip
self.port = port
self.connect()
self.bcast_dict = defaultdict(lambda: 0)

def connect(self):
"""Check connection success."""
while True:
try:
self.proxy = ServerProxy(
"http://{}:{}".format(self.master_ip, self.port)
"http://{}:{}".format(self.master_ip, self.port), allow_none=True
)
if self.proxy.connect():
break
@@ -247,22 +268,17 @@ class Client:

def user_set(self, key, val):
    """Set user defined key-value pairs across processes.

    Forwards the pair to the server's ``user_set`` RPC exactly once and
    returns whatever that call returns.

    :param key: key under which the value is stored on the server.
    :param val: value to store.
    :return: the result of the remote ``user_set`` call.
    """
    # A single call only: issuing the RPC twice would set the same key twice.
    return self.proxy.user_set(key, val)

def user_get(self, key):
    """Get user defined key-value pairs across processes.

    :param key: key to look up through the server's ``user_get`` RPC.
    :return: the value the remote call returns for ``key``.
    """
    fetched = self.proxy.user_get(key)
    return fetched

def bcast_val(self, val, key, size):
    """Broadcast a value within a group through one server round trip.

    Each call bumps a per-``key`` counter kept on this client and appends it
    to the key, so repeated broadcasts on the same ``key`` use distinct
    server-side keys. The request is then passed to the server's
    ``bcast_val`` RPC.

    :param val: value to publish, or ``None`` (presumably meaning "receive";
        confirm against the server method). XML-RPC can only marshal ``None``
        when the proxy was created with ``allow_none=True``.
    :param key: base identifier of the broadcast.
    :param size: number of participants, forwarded unchanged to the server.
    :return: whatever the remote ``bcast_val`` call returns.
    """
    # Per-key round number; relies on ``self.bcast_dict`` defaulting to 0.
    idx = self.bcast_dict[key] + 1
    self.bcast_dict[key] = idx
    round_key = key + "_bcast_" + str(idx)
    return self.proxy.bcast_val(val, round_key, size)


def main(port=0, verbose=True):


Loading…
Cancel
Save