|
|
@@ -36,6 +36,7 @@ class Methods: |
|
|
|
self.dict_barrier_counter = defaultdict(int) |
|
|
|
self.dict_barrier_event = defaultdict(threading.Event) |
|
|
|
self.user_dict = defaultdict(partial(Future, False)) |
|
|
|
self.bcast_dict = {} |
|
|
|
|
|
|
|
def connect(self): |
|
|
|
"""Method for checking connection success.""" |
|
|
@@ -127,6 +128,23 @@ class Methods: |
|
|
|
future = self.user_dict[key] |
|
|
|
return future.get() |
|
|
|
|
|
|
|
def bcast_val(self, val, key, size): |
|
|
|
with self.lock: |
|
|
|
if key not in self.bcast_dict: |
|
|
|
self.bcast_dict[key] = [Future(False), size] |
|
|
|
arr = self.bcast_dict[key] |
|
|
|
if val is not None: |
|
|
|
arr[0].set(val) |
|
|
|
val = None |
|
|
|
else: |
|
|
|
val = arr[0].get() |
|
|
|
with self.lock: |
|
|
|
cnt = arr[1] - 1 |
|
|
|
arr[1] = cnt |
|
|
|
if cnt == 0: |
|
|
|
del self.bcast_dict[key] |
|
|
|
return val |
|
|
|
|
|
|
|
|
|
|
|
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): |
|
|
|
pass |
|
|
@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue): |
|
|
|
""" |
|
|
|
try: |
|
|
|
mm_server_port = create_mm_server("0.0.0.0", 0) |
|
|
|
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) |
|
|
|
server = ThreadXMLRPCServer( |
|
|
|
("0.0.0.0", py_server_port), logRequests=False, allow_none=True |
|
|
|
) |
|
|
|
server.register_instance(Methods(mm_server_port)) |
|
|
|
_, py_server_port = server.server_address |
|
|
|
queue.put((py_server_port, mm_server_port)) |
|
|
@@ -185,13 +205,14 @@ class Client: |
|
|
|
self.master_ip = master_ip |
|
|
|
self.port = port |
|
|
|
self.connect() |
|
|
|
self.bcast_dict = defaultdict(lambda: 0) |
|
|
|
|
|
|
|
def connect(self): |
|
|
|
"""Check connection success.""" |
|
|
|
while True: |
|
|
|
try: |
|
|
|
self.proxy = ServerProxy( |
|
|
|
"http://{}:{}".format(self.master_ip, self.port) |
|
|
|
"http://{}:{}".format(self.master_ip, self.port), allow_none=True |
|
|
|
) |
|
|
|
if self.proxy.connect(): |
|
|
|
break |
|
|
@@ -247,22 +268,17 @@ class Client: |
|
|
|
|
|
|
|
def user_set(self, key, val): |
|
|
|
"""Set user defined key-value pairs across processes.""" |
|
|
|
self.proxy.user_set(key, val) |
|
|
|
return self.proxy.user_set(key, val) |
|
|
|
|
|
|
|
def user_get(self, key): |
|
|
|
"""Get user defined key-value pairs across processes.""" |
|
|
|
return self.proxy.user_get(key) |
|
|
|
|
|
|
|
def bcast_val(self, val, key, size): |
|
|
|
if val is not None: |
|
|
|
self.user_set(key + "_sync", val) |
|
|
|
self.group_barrier(key, size) |
|
|
|
self.group_barrier(key, size) |
|
|
|
else: |
|
|
|
self.group_barrier(key, size) |
|
|
|
val = self.user_get(key + "_sync") |
|
|
|
self.group_barrier(key, size) |
|
|
|
return val |
|
|
|
idx = self.bcast_dict[key] + 1 |
|
|
|
self.bcast_dict[key] = idx |
|
|
|
key = key + "_bcast_" + str(idx) |
|
|
|
return self.proxy.bcast_val(val, key, size) |
|
|
|
|
|
|
|
|
|
|
|
def main(port=0, verbose=True): |
|
|
|