Browse Source

perf(dist): speed up bcast_val

GitOrigin-RevId: 21c4123b09
release-1.4
Megvii Engine Team 4 years ago
parent
commit
6964576e89
1 changed file with 28 additions and 12 deletions
  1. +28
    -12
      imperative/python/megengine/distributed/server.py

+ 28
- 12
imperative/python/megengine/distributed/server.py View File

@@ -36,6 +36,7 @@ class Methods:
self.dict_barrier_counter = defaultdict(int) self.dict_barrier_counter = defaultdict(int)
self.dict_barrier_event = defaultdict(threading.Event) self.dict_barrier_event = defaultdict(threading.Event)
self.user_dict = defaultdict(partial(Future, False)) self.user_dict = defaultdict(partial(Future, False))
self.bcast_dict = {}


def connect(self): def connect(self):
"""Method for checking connection success.""" """Method for checking connection success."""
@@ -127,6 +128,23 @@ class Methods:
future = self.user_dict[key] future = self.user_dict[key]
return future.get() return future.get()


def bcast_val(self, val, key, size):
    """Exchange one value among the ``size`` callers that share ``key``.

    The caller that passes a non-None ``val`` stores it in the future kept
    for ``key``; callers that pass None read the value from that future.
    Every call decrements the remaining-caller count, and the entry is
    dropped from ``self.bcast_dict`` once the count reaches zero.

    Args:
        val: value to publish, or None to receive the published value.
        key: identifier of this broadcast round.
        size: number of calls expected for ``key`` (used only by the call
            that creates the entry).

    Returns:
        None for the publishing call, otherwise the value read from the
        future.
    """
    # Fetch the [future, remaining] slot for this key, creating it on the
    # first call; the lock keeps concurrent first calls from racing.
    with self.lock:
        try:
            slot = self.bcast_dict[key]
        except KeyError:
            slot = self.bcast_dict[key] = [Future(False), size]

    # The future is touched outside the lock.
    # NOTE(review): presumably Future.get() blocks until set() -- confirm.
    future = slot[0]
    if val is None:
        result = future.get()
    else:
        future.set(val)
        result = None

    # Count this caller as done; the last one removes the entry.
    with self.lock:
        slot[1] -= 1
        if slot[1] == 0:
            del self.bcast_dict[key]
    return result



class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue):
""" """
try: try:
mm_server_port = create_mm_server("0.0.0.0", 0) mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server = ThreadXMLRPCServer(
("0.0.0.0", py_server_port), logRequests=False, allow_none=True
)
server.register_instance(Methods(mm_server_port)) server.register_instance(Methods(mm_server_port))
_, py_server_port = server.server_address _, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port)) queue.put((py_server_port, mm_server_port))
@@ -185,13 +205,14 @@ class Client:
self.master_ip = master_ip self.master_ip = master_ip
self.port = port self.port = port
self.connect() self.connect()
self.bcast_dict = defaultdict(lambda: 0)


def connect(self): def connect(self):
"""Check connection success.""" """Check connection success."""
while True: while True:
try: try:
self.proxy = ServerProxy( self.proxy = ServerProxy(
"http://{}:{}".format(self.master_ip, self.port)
"http://{}:{}".format(self.master_ip, self.port), allow_none=True
) )
if self.proxy.connect(): if self.proxy.connect():
break break
@@ -247,22 +268,17 @@ class Client:


def user_set(self, key, val):
    """Set user defined key-value pairs across processes.

    Forwards the pair to the server through ``self.proxy`` exactly once
    and returns whatever the remote ``user_set`` call returns.
    """
    return self.proxy.user_set(key, val)


def user_get(self, key):
    """Get user defined key-value pairs across processes.

    Returns the result of the remote ``user_get`` call made through
    ``self.proxy`` for ``key``.
    """
    return self.proxy.user_get(key)


def bcast_val(self, val, key, size):
    """Broadcast ``val`` among ``size`` callers with a single server call.

    Args:
        val: value to publish, or None to receive the published value.
        key: caller-chosen name of the broadcast.
        size: number of participants, forwarded to the server unchanged.

    Returns:
        Whatever the remote ``bcast_val`` call returns.
    """
    # A per-key call counter is appended to the key, so repeated
    # broadcasts under the same name use distinct server-side keys.
    idx = self.bcast_dict[key] + 1
    self.bcast_dict[key] = idx
    key = key + "_bcast_" + str(idx)
    return self.proxy.bcast_val(val, key, size)




def main(port=0, verbose=True): def main(port=0, verbose=True):


Loading…
Cancel
Save