Browse Source

docs(mge/distributed): add distributed.server docs

GitOrigin-RevId: 929d6adfcc
release-1.1
Megvii Engine Team 4 years ago
parent
commit
f5f86a05c4
1 changed files with 94 additions and 38 deletions
  1. +94
    -38
      imperative/python/megengine/distributed/server.py

+ 94
- 38
imperative/python/megengine/distributed/server.py View File

@@ -21,6 +21,12 @@ from .util import get_free_ports




class Methods: class Methods:
"""Distributed Server Method.
Used for exchanging information between distributed nodes.

:param mm_server_port: multiple machine rpc server port.
"""

def __init__(self, mm_server_port): def __init__(self, mm_server_port):
self.lock = threading.Lock() self.lock = threading.Lock()
self.mm_server_port = mm_server_port self.mm_server_port = mm_server_port
@@ -31,51 +37,65 @@ class Methods:
self.dict_barrier_event = defaultdict(threading.Event) self.dict_barrier_event = defaultdict(threading.Event)


def connect(self): def connect(self):
"""Method for checking connection success."""
return True return True


def get_mm_server_port(self): def get_mm_server_port(self):
"""Get multiple machine rpc server port."""
return self.mm_server_port return self.mm_server_port


def set_is_grad(self, rank_peer, is_grad):
def set_is_grad(self, key, is_grad):
"""Mark whether send/recv needs gradients, by key.
:param key: key to match send/recv op.
:param is_grad: whether this op needs grad.
"""
with self.lock: with self.lock:
future = self.dict_is_grad[rank_peer]
future = self.dict_is_grad[key]
future.set(is_grad) future.set(is_grad)
return True return True


def check_is_grad(self, rank_peer):
def check_is_grad(self, key):
"""Check whether send/recv needs gradients.
:param key: key to match send/recv op.
"""
with self.lock: with self.lock:
future = self.dict_is_grad[rank_peer]
future = self.dict_is_grad[key]
ret = future.get() ret = future.get()
with self.lock: with self.lock:
del self.dict_is_grad[rank_peer]
del self.dict_is_grad[key]
return ret return ret


def set_remote_tracer(self, rank_peer, tracer_set):
def set_remote_tracer(self, key, tracer_set):
"""Set tracer dict for tracing send/recv op.

:param key: key to match send/recv op.
:param tracer_set: valid tracer set.
"""
with self.lock: with self.lock:
future = self.dict_remote_tracer[rank_peer]
future = self.dict_remote_tracer[key]
future.set(tracer_set) future.set(tracer_set)
return True return True


def check_remote_tracer(self, rank_peer):
def check_remote_tracer(self, key):
"""Get tracer dict for send/recv op.
:param key: key to match send/recv op.
"""
with self.lock: with self.lock:
future = self.dict_remote_tracer[rank_peer]
future = self.dict_remote_tracer[key]
ret = future.get() ret = future.get()
with self.lock: with self.lock:
del self.dict_remote_tracer[rank_peer]
del self.dict_remote_tracer[key]
return ret return ret


def set_pack_list(self, key, pack_list):
with self.lock:
future = self.dict_pack_list[key]
future.set(pack_list)
return True

def get_pack_list(self, key):
with self.lock:
future = self.dict_pack_list[key]
return future.get()

def group_barrier(self, key, size): def group_barrier(self, key, size):
"""A barrier that waits for all group members.
:param key: group key to match each other.
:param size: group size.
"""
with self.lock: with self.lock:
self.dict_barrier_counter[key] += 1 self.dict_barrier_counter[key] += 1
counter = self.dict_barrier_counter[key] counter = self.dict_barrier_counter[key]
@@ -94,12 +114,23 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):




def start_server(py_server_port, mm_server_port): def start_server(py_server_port, mm_server_port):
"""Start python distributed server and multiple machine server.
:param py_server_port: python server port.
:param mm_server_port: multiple machine server port.
"""
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port)) server.register_instance(Methods(mm_server_port))
server.serve_forever() server.serve_forever()




class Server: class Server:
"""Distributed Server for distributed training.
Should be running at master node.

:param port: python server port.
"""

def __init__(self, port): def __init__(self, port):
self.py_server_port = get_free_ports(1)[0] if port == 0 else port self.py_server_port = get_free_ports(1)[0] if port == 0 else port
self.mm_server_port = create_mm_server("0.0.0.0", 0) self.mm_server_port = create_mm_server("0.0.0.0", 0)
@@ -112,12 +143,19 @@ class Server:




class Client: class Client:
"""Distributed Client for distributed training.

:param master_ip: ip address of master node.
:param port: port of server at master node.
"""

def __init__(self, master_ip, port): def __init__(self, master_ip, port):
self.master_ip = master_ip self.master_ip = master_ip
self.port = port self.port = port
self.connect() self.connect()


def connect(self): def connect(self):
"""Check connection success."""
while True: while True:
try: try:
self.proxy = ServerProxy( self.proxy = ServerProxy(
@@ -129,25 +167,43 @@ class Client:
time.sleep(1) time.sleep(1)


def get_mm_server_port(self): def get_mm_server_port(self):
"""Get multiple machine server port."""
return self.proxy.get_mm_server_port() return self.proxy.get_mm_server_port()


def set_is_grad(self, rank_peer, is_grad):
self.proxy.set_is_grad(rank_peer, is_grad)

def check_is_grad(self, rank_peer):
return self.proxy.check_is_grad(rank_peer)

def set_remote_tracer(self, rank_peer, tracer_set):
self.proxy.set_remote_tracer(rank_peer, tracer_set)

def check_remote_tracer(self, rank_peer):
return self.proxy.check_remote_tracer(rank_peer)

def set_pack_list(self, key, pack_list):
self.proxy.set_pack_list(key, pack_list)

def get_pack_list(self, key):
return self.proxy.get_pack_list(key)
def set_is_grad(self, key, is_grad):
"""Mark whether send/recv needs gradients, by key.
:param key: key to match send/recv op.
:param is_grad: whether this op needs grad.
"""
self.proxy.set_is_grad(key, is_grad)

def check_is_grad(self, key):
"""Check whether send/recv needs gradients.
:param key: key to match send/recv op.
"""
return self.proxy.check_is_grad(key)

def set_remote_tracer(self, key, tracer_set):
"""Set tracer dict for tracing send/recv op.

:param key: key to match send/recv op.
:param tracer_set: valid tracer set.
"""
self.proxy.set_remote_tracer(key, tracer_set)

def check_remote_tracer(self, key):
"""Get tracer dict for send/recv op.
:param key: key to match send/recv op.
"""
return self.proxy.check_remote_tracer(key)


def group_barrier(self, key, size): def group_barrier(self, key, size):
"""A barrier that waits for all group members.
:param key: group key to match each other.
:param size: group size.
"""
self.proxy.group_barrier(key, size) self.proxy.group_barrier(key, size)

Loading…
Cancel
Save