From f5f86a05c43f7f266d2f629f0fde54569a2b5145 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 28 Sep 2020 15:45:19 +0800 Subject: [PATCH] docs(mge/distributed): add distributed.server docs GitOrigin-RevId: 929d6adfcc2e5301c8bedf592871acc3e06ea126 --- imperative/python/megengine/distributed/server.py | 132 +++++++++++++++------- 1 file changed, 94 insertions(+), 38 deletions(-) diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index d8f199a6..6c51ae7f 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -21,6 +21,12 @@ from .util import get_free_ports class Methods: + """Distributed Server Method. + Used for exchange information between distributed nodes. + + :param mm_server_port: multiple machine rpc server port. + """ + def __init__(self, mm_server_port): self.lock = threading.Lock() self.mm_server_port = mm_server_port @@ -31,51 +37,65 @@ class Methods: self.dict_barrier_event = defaultdict(threading.Event) def connect(self): + """Method for checking connection success.""" return True def get_mm_server_port(self): + """Get multiple machine rpc server port.""" return self.mm_server_port - def set_is_grad(self, rank_peer, is_grad): + def set_is_grad(self, key, is_grad): + """Mark send/recv need gradiants by key. + + :param key: key to match send/recv op. + :param is_grad: whether this op need grad. + """ with self.lock: - future = self.dict_is_grad[rank_peer] + future = self.dict_is_grad[key] future.set(is_grad) return True - def check_is_grad(self, rank_peer): + def check_is_grad(self, key): + """Check whether send/recv need gradiants. + + :param key: key to match send/recv op. + """ with self.lock: - future = self.dict_is_grad[rank_peer] + future = self.dict_is_grad[key] ret = future.get() with self.lock: - del self.dict_is_grad[rank_peer] + del self.dict_is_grad[key] return ret - def set_remote_tracer(self, rank_peer, tracer_set): + def set_remote_tracer(self, key, tracer_set): + """Set tracer dict for tracing send/recv op. + + :param key: key to match send/recv op. + :param tracer_set: valid tracer set. + """ with self.lock: - future = self.dict_remote_tracer[rank_peer] + future = self.dict_remote_tracer[key] future.set(tracer_set) return True - def check_remote_tracer(self, rank_peer): + def check_remote_tracer(self, key): + """Get tracer dict for send/recv op. + + :param key: key to match send/recv op. + """ with self.lock: - future = self.dict_remote_tracer[rank_peer] + future = self.dict_remote_tracer[key] ret = future.get() with self.lock: - del self.dict_remote_tracer[rank_peer] + del self.dict_remote_tracer[key] return ret - def set_pack_list(self, key, pack_list): - with self.lock: - future = self.dict_pack_list[key] - future.set(pack_list) - return True - - def get_pack_list(self, key): - with self.lock: - future = self.dict_pack_list[key] - return future.get() - def group_barrier(self, key, size): + """A barrier wait for all group member. + + :param key: group key to match each other. + :param size: group size. + """ with self.lock: self.dict_barrier_counter[key] += 1 counter = self.dict_barrier_counter[key] @@ -94,12 +114,23 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): def start_server(py_server_port, mm_server_port): + """Start python distributed server and multiple machine server. + + :param py_server_port: python server port. + :param mm_server_port: multiple machine server port. + """ server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server.register_instance(Methods(mm_server_port)) server.serve_forever() class Server: + """Distributed Server for distributed training. + Should be running at master node. + + :param port: python server port. + """ + def __init__(self, port): self.py_server_port = get_free_ports(1)[0] if port == 0 else port self.mm_server_port = create_mm_server("0.0.0.0", 0) @@ -112,12 +143,19 @@ class Server: class Client: + """Distributed Client for distributed training. + + :param master_ip: ip address of master node. + :param port: port of server at master node. + """ + def __init__(self, master_ip, port): self.master_ip = master_ip self.port = port self.connect() def connect(self): + """Check connection success.""" while True: try: self.proxy = ServerProxy( @@ -129,25 +167,43 @@ class Client: time.sleep(1) def get_mm_server_port(self): + """Get multiple machine server port.""" return self.proxy.get_mm_server_port() - def set_is_grad(self, rank_peer, is_grad): - self.proxy.set_is_grad(rank_peer, is_grad) - - def check_is_grad(self, rank_peer): - return self.proxy.check_is_grad(rank_peer) - - def set_remote_tracer(self, rank_peer, tracer_set): - self.proxy.set_remote_tracer(rank_peer, tracer_set) - - def check_remote_tracer(self, rank_peer): - return self.proxy.check_remote_tracer(rank_peer) - - def set_pack_list(self, key, pack_list): - self.proxy.set_pack_list(key, pack_list) - - def get_pack_list(self, key): - return self.proxy.get_pack_list(key) + def set_is_grad(self, key, is_grad): + """Mark send/recv need gradiants by key. + + :param key: key to match send/recv op. + :param is_grad: whether this op need grad. + """ + self.proxy.set_is_grad(key, is_grad) + + def check_is_grad(self, key): + """Check whether send/recv need gradiants. + + :param key: key to match send/recv op. + """ + return self.proxy.check_is_grad(key) + + def set_remote_tracer(self, key, tracer_set): + """Set tracer dict for tracing send/recv op. + + :param key: key to match send/recv op. + :param tracer_set: valid tracer set. + """ + self.proxy.set_remote_tracer(key, tracer_set) + + def check_remote_tracer(self, key): + """Get tracer dict for send/recv op. + + :param key: key to match send/recv op. + """ + return self.proxy.check_remote_tracer(key) def group_barrier(self, key, size): + """A barrier wait for all group member. + + :param key: group key to match each other. + :param size: group size. + """ self.proxy.group_barrier(key, size)