Browse Source

feat(mge): support python -m megengine.distributed.server

GitOrigin-RevId: f1e5c8e3cf
release-1.2
Megvii Engine Team 4 years ago
parent
commit
9c92701f63
1 changed files with 28 additions and 9 deletions
  1. +28
    -9
      imperative/python/megengine/distributed/server.py

+ 28
- 9
imperative/python/megengine/distributed/server.py View File

@@ -49,7 +49,7 @@ class Methods:
def set_is_grad(self, key, is_grad):
"""
Mark send/recv need gradiants by key.
:param key: key to match send/recv op.
:param is_grad: whether this op need grad.
"""
@@ -61,7 +61,7 @@ class Methods:
def check_is_grad(self, key):
"""
Check whether send/recv need gradiants.
:param key: key to match send/recv op.
"""
with self.lock:
@@ -86,7 +86,7 @@ class Methods:
def check_remote_tracer(self, key):
"""
Get tracer dict for send/recv op.
:param key: key to match send/recv op.
"""
with self.lock:
@@ -99,7 +99,7 @@ class Methods:
def group_barrier(self, key, size):
"""
A barrier wait for all group member.
:param key: group key to match each other.
:param size: group size.
"""
@@ -136,7 +136,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
def _start_server(py_server_port, mm_server_port, queue):
"""
Start python distributed server and multiple machine server.
:param py_server_port: python server port.
:param mm_server_port: multiple machine server port.
:param queue: server port will put in this queue, puts exception when process fails.
@@ -205,7 +205,7 @@ class Client:
def set_is_grad(self, key, is_grad):
"""
Mark send/recv need gradiants by key.
:param key: key to match send/recv op.
:param is_grad: whether this op need grad.
"""
@@ -214,7 +214,7 @@ class Client:
def check_is_grad(self, key):
"""
Check whether send/recv need gradiants.
:param key: key to match send/recv op.
"""
return self.proxy.check_is_grad(key)
@@ -231,7 +231,7 @@ class Client:
def check_remote_tracer(self, key):
"""
Get tracer dict for send/recv op.
:param key: key to match send/recv op.
"""
return self.proxy.check_remote_tracer(key)
@@ -239,7 +239,7 @@ class Client:
def group_barrier(self, key, size):
"""
A barrier wait for all group member.
:param key: group key to match each other.
:param size: group size.
"""
@@ -252,3 +252,22 @@ class Client:
def user_get(self, key):
"""Get user defined key-value pairs across processes."""
return self.proxy.user_get(key)


def main(port=0, verbose=True):
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
server.register_instance(Methods(mm_server_port))
_, port = server.server_address
print("serving on port", port)
server.serve_forever()


if __name__ == "__main__":
import argparse

ap = argparse.ArgumentParser()
ap.add_argument("-p", "--port", type=int, default=0)
ap.add_argument("-v", "--verbose", type=bool, default=True)
args = ap.parse_args()
main(port=args.port, verbose=args.verbose)

Loading…
Cancel
Save