You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_distributed.py 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import multiprocessing as mp
  10. import subprocess
  11. import sys
  12. import numpy as np
  13. def worker(master_ip, master_port, world_size, rank, dev, trace):
  14. import megengine.distributed as dist
  15. import megengine.functional as F
  16. from megengine import is_cuda_available
  17. from megengine import jit
  18. from megengine.module import Linear, Module
  19. from megengine.optimizer import SGD
  20. if not is_cuda_available():
  21. return
  22. class MLP(Module):
  23. def __init__(self):
  24. super().__init__()
  25. self.fc0 = Linear(3 * 224 * 224, 500)
  26. self.fc1 = Linear(500, 10)
  27. def forward(self, x):
  28. x = self.fc0(x)
  29. x = F.relu(x)
  30. x = self.fc1(x)
  31. return x
  32. dist.init_process_group(
  33. master_ip=master_ip, master_port=3456, world_size=world_size, rank=rank, dev=dev
  34. )
  35. net = MLP()
  36. opt = SGD(net.parameters(requires_grad=True), lr=0.02)
  37. data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
  38. label = np.random.randint(0, 10, size=(64,)).astype(np.int32)
  39. jit.trace.enabled = trace
  40. @jit.trace()
  41. def train_func(data, label):
  42. pred = net(data)
  43. loss = F.cross_entropy_with_softmax(pred, label)
  44. opt.backward(loss)
  45. return loss
  46. for i in range(5):
  47. opt.zero_grad()
  48. loss = train_func(data, label)
  49. opt.step()
  50. def start_workers(worker, world_size, trace=False):
  51. def run_subproc(rank):
  52. cmd = "from test.integration.test_distributed import worker\n"
  53. cmd += "worker('localhost', 3456, {}, {}, {}, {})".format(
  54. world_size, rank, rank, "True" if trace else "False"
  55. )
  56. cmd = ["python3", "-c", cmd]
  57. ret = subprocess.run(
  58. cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True
  59. )
  60. assert ret.returncode == 0, "subprocess failed"
  61. procs = []
  62. for rank in range(world_size):
  63. p = mp.Process(target=run_subproc, args=(rank,))
  64. p.start()
  65. procs.append(p)
  66. for p in procs:
  67. p.join()
  68. assert p.exitcode == 0
  69. def test_distributed():
  70. start_workers(worker, 2, trace=True)
  71. start_workers(worker, 2, trace=False)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台