Browse Source

feat(mge/distributed): add user_pop function to save device memory

BREAKING CHANGE:

GitOrigin-RevId: 0a8e406da5
release-1.5
Megvii Engine Team huangxinda 3 years ago
parent
commit
df79334cae
2 changed files with 28 additions and 0 deletions
  1. +14
    -0
      imperative/python/megengine/distributed/server.py
  2. +14
    -0
      imperative/python/test/unit/distributed/test_distributed.py

+ 14
- 0
imperative/python/megengine/distributed/server.py View File

@@ -145,6 +145,16 @@ class Methods:
del self.bcast_dict[key] del self.bcast_dict[key]
return val return val


def _del(self, key):
with self.lock:
del self.user_dict[key]

# thread safe function
def user_pop(self, key):
ret = self.user_get(key)
self._del(key)
return ret



class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
@@ -274,6 +284,10 @@ class Client:
"""Get user defined key-value pairs across processes.""" """Get user defined key-value pairs across processes."""
return self.proxy.user_get(key) return self.proxy.user_get(key)


def user_pop(self, key):
"""Get user defined key-value pairs and delete the resources when the get is done"""
return self.proxy.user_pop(key)

def bcast_val(self, val, key, size): def bcast_val(self, val, key, size):
idx = self.bcast_dict[key] + 1 idx = self.bcast_dict[key] + 1
self.bcast_dict[key] = idx self.bcast_dict[key] = idx


+ 14
- 0
imperative/python/test/unit/distributed/test_distributed.py View File

@@ -219,3 +219,17 @@ def test_collect_results(early_return, output_size):
else [[dev] * output_size for dev in range(world_size)] else [[dev] * output_size for dev in range(world_size)]
) )
assert results == expects assert results == expects


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_pop():
@dist.launcher
def worker():
# set in race condition
dist.get_client().user_set("foo", 1)
if dist.get_rank() == 1:
ret = dist.get_client().user_pop("foo")
assert ret == 1

worker()

Loading…
Cancel
Save