|
|
@@ -21,6 +21,12 @@ from .util import get_free_ports |
|
|
|
|
|
|
|
|
|
|
|
class Methods: |
|
|
|
"""Distributed Server Method. |
|
|
|
Used for exchange information between distributed nodes. |
|
|
|
|
|
|
|
:param mm_server_port: multiple machine rpc server port. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, mm_server_port): |
|
|
|
self.lock = threading.Lock() |
|
|
|
self.mm_server_port = mm_server_port |
|
|
@@ -31,51 +37,65 @@ class Methods: |
|
|
|
self.dict_barrier_event = defaultdict(threading.Event) |
|
|
|
|
|
|
|
def connect(self): |
|
|
|
"""Method for checking connection success.""" |
|
|
|
return True |
|
|
|
|
|
|
|
def get_mm_server_port(self): |
|
|
|
"""Get multiple machine rpc server port.""" |
|
|
|
return self.mm_server_port |
|
|
|
|
|
|
|
def set_is_grad(self, rank_peer, is_grad): |
|
|
|
def set_is_grad(self, key, is_grad): |
|
|
|
"""Mark send/recv need gradiants by key. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
:param is_grad: whether this op need grad. |
|
|
|
""" |
|
|
|
with self.lock: |
|
|
|
future = self.dict_is_grad[rank_peer] |
|
|
|
future = self.dict_is_grad[key] |
|
|
|
future.set(is_grad) |
|
|
|
return True |
|
|
|
|
|
|
|
def check_is_grad(self, rank_peer): |
|
|
|
def check_is_grad(self, key): |
|
|
|
"""Check whether send/recv need gradiants. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
""" |
|
|
|
with self.lock: |
|
|
|
future = self.dict_is_grad[rank_peer] |
|
|
|
future = self.dict_is_grad[key] |
|
|
|
ret = future.get() |
|
|
|
with self.lock: |
|
|
|
del self.dict_is_grad[rank_peer] |
|
|
|
del self.dict_is_grad[key] |
|
|
|
return ret |
|
|
|
|
|
|
|
def set_remote_tracer(self, rank_peer, tracer_set): |
|
|
|
def set_remote_tracer(self, key, tracer_set): |
|
|
|
"""Set tracer dict for tracing send/recv op. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
:param tracer_set: valid tracer set. |
|
|
|
""" |
|
|
|
with self.lock: |
|
|
|
future = self.dict_remote_tracer[rank_peer] |
|
|
|
future = self.dict_remote_tracer[key] |
|
|
|
future.set(tracer_set) |
|
|
|
return True |
|
|
|
|
|
|
|
def check_remote_tracer(self, rank_peer): |
|
|
|
def check_remote_tracer(self, key): |
|
|
|
"""Get tracer dict for send/recv op. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
""" |
|
|
|
with self.lock: |
|
|
|
future = self.dict_remote_tracer[rank_peer] |
|
|
|
future = self.dict_remote_tracer[key] |
|
|
|
ret = future.get() |
|
|
|
with self.lock: |
|
|
|
del self.dict_remote_tracer[rank_peer] |
|
|
|
del self.dict_remote_tracer[key] |
|
|
|
return ret |
|
|
|
|
|
|
|
def set_pack_list(self, key, pack_list): |
|
|
|
with self.lock: |
|
|
|
future = self.dict_pack_list[key] |
|
|
|
future.set(pack_list) |
|
|
|
return True |
|
|
|
|
|
|
|
def get_pack_list(self, key): |
|
|
|
with self.lock: |
|
|
|
future = self.dict_pack_list[key] |
|
|
|
return future.get() |
|
|
|
|
|
|
|
def group_barrier(self, key, size): |
|
|
|
"""A barrier wait for all group member. |
|
|
|
|
|
|
|
:param key: group key to match each other. |
|
|
|
:param size: group size. |
|
|
|
""" |
|
|
|
with self.lock: |
|
|
|
self.dict_barrier_counter[key] += 1 |
|
|
|
counter = self.dict_barrier_counter[key] |
|
|
@@ -94,12 +114,23 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): |
|
|
|
|
|
|
|
|
|
|
|
def start_server(py_server_port, mm_server_port): |
|
|
|
"""Start python distributed server and multiple machine server. |
|
|
|
|
|
|
|
:param py_server_port: python server port. |
|
|
|
:param mm_server_port: multiple machine server port. |
|
|
|
""" |
|
|
|
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) |
|
|
|
server.register_instance(Methods(mm_server_port)) |
|
|
|
server.serve_forever() |
|
|
|
|
|
|
|
|
|
|
|
class Server: |
|
|
|
"""Distributed Server for distributed training. |
|
|
|
Should be running at master node. |
|
|
|
|
|
|
|
:param port: python server port. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, port): |
|
|
|
self.py_server_port = get_free_ports(1)[0] if port == 0 else port |
|
|
|
self.mm_server_port = create_mm_server("0.0.0.0", 0) |
|
|
@@ -112,12 +143,19 @@ class Server: |
|
|
|
|
|
|
|
|
|
|
|
class Client: |
|
|
|
"""Distributed Client for distributed training. |
|
|
|
|
|
|
|
:param master_ip: ip address of master node. |
|
|
|
:param port: port of server at master node. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, master_ip, port): |
|
|
|
self.master_ip = master_ip |
|
|
|
self.port = port |
|
|
|
self.connect() |
|
|
|
|
|
|
|
def connect(self): |
|
|
|
"""Check connection success.""" |
|
|
|
while True: |
|
|
|
try: |
|
|
|
self.proxy = ServerProxy( |
|
|
@@ -129,25 +167,43 @@ class Client: |
|
|
|
time.sleep(1) |
|
|
|
|
|
|
|
def get_mm_server_port(self): |
|
|
|
"""Get multiple machine server port.""" |
|
|
|
return self.proxy.get_mm_server_port() |
|
|
|
|
|
|
|
def set_is_grad(self, rank_peer, is_grad): |
|
|
|
self.proxy.set_is_grad(rank_peer, is_grad) |
|
|
|
|
|
|
|
def check_is_grad(self, rank_peer): |
|
|
|
return self.proxy.check_is_grad(rank_peer) |
|
|
|
|
|
|
|
def set_remote_tracer(self, rank_peer, tracer_set): |
|
|
|
self.proxy.set_remote_tracer(rank_peer, tracer_set) |
|
|
|
|
|
|
|
def check_remote_tracer(self, rank_peer): |
|
|
|
return self.proxy.check_remote_tracer(rank_peer) |
|
|
|
|
|
|
|
def set_pack_list(self, key, pack_list): |
|
|
|
self.proxy.set_pack_list(key, pack_list) |
|
|
|
|
|
|
|
def get_pack_list(self, key): |
|
|
|
return self.proxy.get_pack_list(key) |
|
|
|
def set_is_grad(self, key, is_grad): |
|
|
|
"""Mark send/recv need gradiants by key. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
:param is_grad: whether this op need grad. |
|
|
|
""" |
|
|
|
self.proxy.set_is_grad(key, is_grad) |
|
|
|
|
|
|
|
def check_is_grad(self, key): |
|
|
|
"""Check whether send/recv need gradiants. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
""" |
|
|
|
return self.proxy.check_is_grad(key) |
|
|
|
|
|
|
|
def set_remote_tracer(self, key, tracer_set): |
|
|
|
"""Set tracer dict for tracing send/recv op. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
:param tracer_set: valid tracer set. |
|
|
|
""" |
|
|
|
self.proxy.set_remote_tracer(key, tracer_set) |
|
|
|
|
|
|
|
def check_remote_tracer(self, key): |
|
|
|
"""Get tracer dict for send/recv op. |
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
|
""" |
|
|
|
return self.proxy.check_remote_tracer(key) |
|
|
|
|
|
|
|
def group_barrier(self, key, size): |
|
|
|
"""A barrier wait for all group member. |
|
|
|
|
|
|
|
:param key: group key to match each other. |
|
|
|
:param size: group size. |
|
|
|
""" |
|
|
|
self.proxy.group_barrier(key, size) |