|
@@ -49,7 +49,7 @@ class Methods: |
|
|
def set_is_grad(self, key, is_grad): |
|
|
def set_is_grad(self, key, is_grad): |
|
|
""" |
|
|
""" |
|
|
Mark send/recv need gradiants by key. |
|
|
Mark send/recv need gradiants by key. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
:param key: key to match send/recv op. |
|
|
:param is_grad: whether this op need grad. |
|
|
:param is_grad: whether this op need grad. |
|
|
""" |
|
|
""" |
|
@@ -61,7 +61,7 @@ class Methods: |
|
|
def check_is_grad(self, key): |
|
|
def check_is_grad(self, key): |
|
|
""" |
|
|
""" |
|
|
Check whether send/recv need gradiants. |
|
|
Check whether send/recv need gradiants. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
:param key: key to match send/recv op. |
|
|
""" |
|
|
""" |
|
|
with self.lock: |
|
|
with self.lock: |
|
@@ -86,7 +86,7 @@ class Methods: |
|
|
def check_remote_tracer(self, key): |
|
|
def check_remote_tracer(self, key): |
|
|
""" |
|
|
""" |
|
|
Get tracer dict for send/recv op. |
|
|
Get tracer dict for send/recv op. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
:param key: key to match send/recv op. |
|
|
""" |
|
|
""" |
|
|
with self.lock: |
|
|
with self.lock: |
|
@@ -99,7 +99,7 @@ class Methods: |
|
|
def group_barrier(self, key, size): |
|
|
def group_barrier(self, key, size): |
|
|
""" |
|
|
""" |
|
|
A barrier wait for all group member. |
|
|
A barrier wait for all group member. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: group key to match each other. |
|
|
:param key: group key to match each other. |
|
|
:param size: group size. |
|
|
:param size: group size. |
|
|
""" |
|
|
""" |
|
@@ -136,7 +136,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): |
|
|
def _start_server(py_server_port, mm_server_port, queue): |
|
|
def _start_server(py_server_port, mm_server_port, queue): |
|
|
""" |
|
|
""" |
|
|
Start python distributed server and multiple machine server. |
|
|
Start python distributed server and multiple machine server. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param py_server_port: python server port. |
|
|
:param py_server_port: python server port. |
|
|
:param mm_server_port: multiple machine server port. |
|
|
:param mm_server_port: multiple machine server port. |
|
|
:param queue: server port will put in this queue, puts exception when process fails. |
|
|
:param queue: server port will put in this queue, puts exception when process fails. |
|
@@ -205,7 +205,7 @@ class Client: |
|
|
def set_is_grad(self, key, is_grad): |
|
|
def set_is_grad(self, key, is_grad): |
|
|
""" |
|
|
""" |
|
|
Mark send/recv need gradiants by key. |
|
|
Mark send/recv need gradiants by key. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
:param key: key to match send/recv op. |
|
|
:param is_grad: whether this op need grad. |
|
|
:param is_grad: whether this op need grad. |
|
|
""" |
|
|
""" |
|
@@ -214,7 +214,7 @@ class Client: |
|
|
def check_is_grad(self, key): |
|
|
def check_is_grad(self, key): |
|
|
""" |
|
|
""" |
|
|
Check whether send/recv need gradiants. |
|
|
Check whether send/recv need gradiants. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
:param key: key to match send/recv op. |
|
|
""" |
|
|
""" |
|
|
return self.proxy.check_is_grad(key) |
|
|
return self.proxy.check_is_grad(key) |
|
@@ -231,7 +231,7 @@ class Client: |
|
|
def check_remote_tracer(self, key): |
|
|
def check_remote_tracer(self, key): |
|
|
""" |
|
|
""" |
|
|
Get tracer dict for send/recv op. |
|
|
Get tracer dict for send/recv op. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: key to match send/recv op. |
|
|
:param key: key to match send/recv op. |
|
|
""" |
|
|
""" |
|
|
return self.proxy.check_remote_tracer(key) |
|
|
return self.proxy.check_remote_tracer(key) |
|
@@ -239,7 +239,7 @@ class Client: |
|
|
def group_barrier(self, key, size): |
|
|
def group_barrier(self, key, size): |
|
|
""" |
|
|
""" |
|
|
A barrier wait for all group member. |
|
|
A barrier wait for all group member. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param key: group key to match each other. |
|
|
:param key: group key to match each other. |
|
|
:param size: group size. |
|
|
:param size: group size. |
|
|
""" |
|
|
""" |
|
@@ -252,3 +252,22 @@ class Client: |
|
|
def user_get(self, key): |
|
|
def user_get(self, key): |
|
|
"""Get user defined key-value pairs across processes.""" |
|
|
"""Get user defined key-value pairs across processes.""" |
|
|
return self.proxy.user_get(key) |
|
|
return self.proxy.user_get(key) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(port=0, verbose=True): |
|
|
|
|
|
mm_server_port = create_mm_server("0.0.0.0", 0) |
|
|
|
|
|
server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose) |
|
|
|
|
|
server.register_instance(Methods(mm_server_port)) |
|
|
|
|
|
_, port = server.server_address |
|
|
|
|
|
print("serving on port", port) |
|
|
|
|
|
server.serve_forever() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
import argparse |
|
|
|
|
|
|
|
|
|
|
|
ap = argparse.ArgumentParser() |
|
|
|
|
|
ap.add_argument("-p", "--port", type=int, default=0) |
|
|
|
|
|
ap.add_argument("-v", "--verbose", type=bool, default=True) |
|
|
|
|
|
args = ap.parse_args() |
|
|
|
|
|
main(port=args.port, verbose=args.verbose) |