
pytorch.py 16 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
import os
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.cpp_extension import load as load_torch_extension

import megengine._internal as mgb
from megengine._internal import CompGraph
from megengine._internal.mgb import CompGraphCallbackValueProxy

from ...core import Parameter, Tensor, get_default_device
from ..module import Module
from .utils import device_to_torch_device, torch_dtype_to_numpy_dtype
# A global dict mapping oprs to their copies during graph copy
_copy_dict = {}


# the C++ extension is compiled on first use and cached for the process
@functools.lru_cache(None)
def _get_torch_mem_fwd_lib():
    source_file = os.path.join(os.path.dirname(__file__), "torch_mem_fwd.cpp")
    return load_torch_extension(
        "torch_mem_fwd",
        [source_file],
        extra_include_paths=[mgb.config.get_include_path()],
    )


def inp_mem_fwd(pubapi_dev_tensor_ptr: int) -> torch.Tensor:
    """Forward a MegBrain tensor to a torch tensor.

    :param pubapi_dev_tensor_ptr: pointer to the MegBrain tensor
    """
    return _get_torch_mem_fwd_lib().inp_mem_fwd(pubapi_dev_tensor_ptr)


def oup_mem_fwd(
    pubapi_dev_tensor_ptr: int, tensor: torch.Tensor, keep_data_ptr: bool = True
) -> None:
    """Forward a torch tensor to a contiguous MegBrain tensor.

    :param pubapi_dev_tensor_ptr: pointer to the MegBrain tensor
    :param tensor: the input torch tensor
    :param keep_data_ptr: if True, no memory copy is allowed, so the input
        torch tensor must itself be contiguous; defaults to True
    """
    _get_torch_mem_fwd_lib().oup_mem_fwd(pubapi_dev_tensor_ptr, tensor, keep_data_ptr)
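

# Note: the two helpers above share device memory across frameworks rather
# than copying it. A rough mental model (hypothetical `mgb_tensor`, assumed
# to be an existing MegBrain device tensor):
#
#   t = inp_mem_fwd(mgb_tensor.pubapi_dev_tensor_ptr)   # torch view of mgb memory
#   t.add_(1)                                           # mutates the mgb tensor too
#   oup_mem_fwd(mgb_tensor.pubapi_dev_tensor_ptr, t)    # hand torch memory back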


def torch_param_to_mge(
    name: str, param: torch.nn.Parameter, device, comp_graph: CompGraph
) -> Parameter:
    """Convert a torch parameter to a megengine parameter.

    :param name: parameter name
    :param param: torch parameter
    :param device: the device the megengine parameter is on; it should be
        physically the same device as that of the torch parameter
    :param comp_graph: the owner graph of the megengine parameter
    :return: megengine parameter
    """
    assert isinstance(param, torch.nn.Parameter)
    dtype = torch_dtype_to_numpy_dtype(param.dtype)
    mge_param = Parameter(None, dtype=dtype)
    shared_nd = mge_param._Tensor__val
    oup_mem_fwd(shared_nd.pubapi_dev_tensor_ptr, param.data, True)
    return mge_param
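

# Note: since oup_mem_fwd above is called with keep_data_ptr=True, the
# returned megengine Parameter aliases the torch parameter's storage; an
# in-place update made through either framework is visible to both.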


class _PyTorchSubgraphGradOpr(mgb.craniotome.CraniotomeBase):
    __nr_inputs__ = None
    __nr_outputs__ = None
    __allow_duplicate__ = False
    __disable_sys_mem_alloc__ = True
    __is_dynamic_output_shape__ = True

    _forward_opr = None  # type: PyTorchSubgraphImplOpr
    _shape_infer_func = None
    _condensed_out_grad_idx = None  # type: List[Optional[int]]
    _forward_input_cnt = None
    _forward_output_cnt = None
    _output_grad_cnt = None
    _param_cnt = None

    def setup(
        self, forward_opr, condensed_out_grad_idx: List[Optional[int]], infer_shape=None
    ):
        self._forward_opr = forward_opr
        self._forward_input_cnt = forward_opr.input_cnt
        self._forward_output_cnt = forward_opr.output_cnt
        self._param_cnt = forward_opr.param_cnt
        self._output_grad_cnt = sum(idx is not None for idx in condensed_out_grad_idx)
        # inputs: forward inputs + params + forward outputs + non-None output grads
        self.__nr_inputs__ = (
            self._forward_input_cnt
            + self._param_cnt
            + self._forward_output_cnt
            + self._output_grad_cnt
        )
        # outputs: one grad per forward input and per param
        self.__nr_outputs__ = self._forward_input_cnt + self._param_cnt
        self._condensed_out_grad_idx = condensed_out_grad_idx
        self._shape_infer_func = infer_shape
        if infer_shape is not None:
            type(self).__is_dynamic_output_shape__ = False

    def execute(
        self,
        inputs: Tuple[CompGraphCallbackValueProxy, ...],
        outputs: Tuple[mgb.SharedND, ...],
    ):
        # if the forward opr has not executed in this process (e.g. after a
        # graph copy), run it first so its retained tensors are available
        if self._forward_opr._last_forward_outputs is None:
            self._forward_opr.execute(inputs[: self.__nr_outputs__], None)
        assert self._forward_opr._last_forward_inputs is not None
        assert self._forward_opr._last_forward_outputs is not None
        out_grads = [
            inp_mem_fwd(inputs[idx].pubapi_dev_tensor_ptr) if idx is not None else None
            for idx in self._condensed_out_grad_idx
        ]
        grads = torch.autograd.grad(
            self._forward_opr._last_forward_outputs,
            self._forward_opr._last_forward_inputs
            + self._forward_opr._last_forward_params,
            out_grads,  # type: ignore
            only_inputs=True,
            allow_unused=True,
        )
        for ovar, oten in zip(outputs, grads):
            oup_mem_fwd(ovar.pubapi_dev_tensor_ptr, oten)

    def grad(self, wrt_idx, inputs, outputs, out_grad):
        raise NotImplementedError("Taking grad of a grad opr is not supported")

    def infer_shape(self, inp_shapes):
        if callable(self._shape_infer_func):
            return self._shape_infer_func(inp_shapes)
        raise NotImplementedError(
            "No shape inference function specified on _PyTorchSubgraphGradOpr"
        )

    def copy(self):
        ret = type(self)()
        d0 = self.__dict__.copy()
        d0.pop("this")
        d0.pop("_forward_opr")
        # if the paired forward opr was copied first, link to its copy;
        # otherwise keep the original and let its later copy re-link us
        later_copy = self._forward_opr in _copy_dict
        if later_copy:
            assert len(_copy_dict) == 1
            forward_opr_copy = _copy_dict[self._forward_opr]
        else:
            forward_opr_copy = self._forward_opr
        ret.__dict__["_forward_opr"] = forward_opr_copy
        ret.__dict__.update(copy.deepcopy(d0))
        _copy_dict[self] = ret
        if later_copy:
            forward_opr_copy._grad_opr = ret
            _copy_dict.clear()
        return ret
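

# Note on copying: a forward opr and its grad opr hold references to each
# other, so copy() runs in two phases coordinated through the module-level
# _copy_dict: whichever partner is copied first registers its copy there;
# when the other is copied later, it looks the copy up, re-links the pair
# and clears the dict.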


class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase):
    # pylint: disable=abstract-method
    """An operator wrapping a pytorch module for use inside a MegBrain graph."""

    __nr_inputs__ = None  # type: int
    __nr_outputs__ = None  # type: int
    __allow_duplicate__ = False
    __disable_sys_mem_alloc__ = True
    __is_dynamic_output_shape__ = True

    _grad_opr = None
    _func = None  # type: Callable[[Any], Any]
    input_cnt = None  # type: int
    output_cnt = None  # type: int
    param_cnt = None  # type: int
    _shape_infer_func = None
    _last_forward_inputs = None
    _last_forward_outputs = None  # type: List[torch.Tensor]
    _last_forward_params = None  # type: List[torch.Tensor]

    def setup(self, *, input_cnt, output_cnt, func, params, infer_shape=None):
        """Set up the operator from the accepted kwargs.

        :param input_cnt: input count of the torch module
        :param output_cnt: output count of the torch module
        :param func: a callable that accepts inputs and returns outputs,
            usually the torch module itself
        :param params: parameters of the torch module
        :param infer_shape: a callable that infers output shapes from input
            shapes; defaults to None
        """
        param_cnt = len(params)
        self.input_cnt = input_cnt
        self.output_cnt = output_cnt
        self.param_cnt = param_cnt
        self.__nr_inputs__ = input_cnt + param_cnt
        self.__nr_outputs__ = output_cnt
        self._func = func
        self._shape_infer_func = infer_shape
        if infer_shape is not None:
            type(self).__is_dynamic_output_shape__ = False
        self._last_forward_params = params

    def execute(
        self,
        inputs: Tuple[CompGraphCallbackValueProxy, ...],
        outputs: Optional[Tuple[mgb.SharedND, ...]],
    ):
        """Execute the operator: read values from *inputs*, forward them to
        torch tensors, run ``self._func`` and forward the results to *outputs*.

        :param inputs: values for each input var
        :param outputs: values for each output var
        """
        input_value_proxys = inputs[: self.input_cnt]
        input_torch_tensors = [
            inp_mem_fwd(ivar.pubapi_dev_tensor_ptr).requires_grad_()
            for ivar in input_value_proxys
        ]
        output_torch_tensors = self._func(*input_torch_tensors)
        if isinstance(output_torch_tensors, torch.Tensor):
            output_torch_tensors = [output_torch_tensors]
        # `execute` may be called by _PyTorchSubgraphGradOpr with None outputs
        if outputs:
            for ovar, oten in zip(outputs, output_torch_tensors):
                oup_mem_fwd(ovar.pubapi_dev_tensor_ptr, oten)
        # retain input / output tensors for backward
        self._last_forward_inputs = input_torch_tensors
        self._last_forward_outputs = output_torch_tensors
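
    # The tensors retained above keep the torch autograd graph alive between
    # this forward pass and the later torch.autograd.grad call in
    # _PyTorchSubgraphGradOpr.execute; the inputs are created with
    # requires_grad_() precisely so that this graph gets recorded.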

    def grad(
        self,
        wrt_idx,
        inputs: Tuple[mgb.SymbolVar, ...],
        outputs: Tuple[mgb.SymbolVar, ...],
        out_grads: Tuple[mgb.SymbolVar, ...],
    ):
        """Generate a grad opr which computes gradients via
        torch.autograd.grad, and cache it.

        :param wrt_idx: the index of the input var with respect to which the
            gradient should be computed
        :param inputs: operator inputs
        :param outputs: operator outputs
        :param out_grads: gradients of each output var
        :return: an initialized grad opr
        """
        if self._grad_opr is None:
            # drop the None entries from out_grads, recording for each output
            # either None or the index its grad occupies in the grad opr's
            # input list (which starts after inputs, params and outputs)
            condensed_out_grad = []
            condensed_out_grad_idx = []  # type: List[Optional[int]]
            idx = self.__nr_inputs__ + len(outputs)
            for out_grad in out_grads:
                if out_grad is None:
                    condensed_out_grad_idx.append(None)
                else:
                    condensed_out_grad.append(out_grad)
                    condensed_out_grad_idx.append(idx)
                    idx += 1
            self._grad_opr = _PyTorchSubgraphGradOpr.make(
                *(inputs + outputs + tuple(condensed_out_grad)),
                forward_opr=self,
                condensed_out_grad_idx=condensed_out_grad_idx,
            )
        return self._grad_opr
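
    # Worked example of the grad-opr input layout (hypothetical sizes): with
    # 2 inputs, 1 param and 2 outputs, __nr_inputs__ == 3, so the grad opr
    # sees [in0, in1, param0, out0, out1, <non-None out grads>] and the first
    # non-None output grad occupies index 3 + 2 == 5; for out_grads of
    # (g0, None), condensed_out_grad_idx would be [5, None].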

    def infer_shape(self, inp_shapes):
        """Infer output shapes from input shapes.

        :param inp_shapes: input shapes as a tuple
        :return: output shapes
        """
        if callable(self._shape_infer_func):
            return self._shape_infer_func(inp_shapes)
        raise NotImplementedError(
            "No shape inference function specified on PyTorchSubgraphImplOpr"
        )

    def copy(self):
        ret = type(self)()
        d0 = self.__dict__.copy()
        d0.pop("this")
        ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs")
        ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs")
        ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params")
        ret.__dict__["_func"] = d0.pop("_func")
        d0.pop("_grad_opr")
        # mirror of _PyTorchSubgraphGradOpr.copy: link to the partner's copy
        # if it was copied first, otherwise keep the original for now
        later_copy = self._grad_opr in _copy_dict
        if later_copy:
            assert len(_copy_dict) == 1
            grad_opr_copy = _copy_dict[self._grad_opr]
        else:
            grad_opr_copy = self._grad_opr
        ret.__dict__["_grad_opr"] = grad_opr_copy
        ret.__dict__.update(copy.deepcopy(d0))
        _copy_dict[self] = ret
        if later_copy:
            grad_opr_copy._forward_opr = ret
            _copy_dict.clear()
        return ret


class PyTorchModule(Module):
    """Wrap a pytorch module as a megengine module.

    :param torch_module: torch module to be wrapped
    :param device: target device this module would be on
    :param output_cnt: output count of this module
    :param infer_shape: a callable that infers output shapes from input shapes
    :param comp_graph: target computing graph this module would be in
    """

    __torch_module = None  # type: torch.nn.Module
    __output_cnt = None
    __infer_shape = None
    __comp_graph = None
    __device = None
    _torch_params = None
    _param_inputs = None
    _name_param_list = None  # type: List[Tuple[str, Parameter]]

    def __init__(
        self,
        torch_module,
        device=None,
        output_cnt=1,
        *,
        infer_shape=None,
        comp_graph=None
    ):
        super().__init__()
        if not isinstance(torch_module, torch.nn.Module):
            raise TypeError(
                "torch_module should be an instance of torch.nn.Module "
                "or one of its subclasses"
            )
        self.__torch_module = torch_module
        if not isinstance(output_cnt, int):
            raise TypeError("output_cnt must be int")
        if output_cnt <= 0:
            raise ValueError("output_cnt must be greater than zero")
        self.__output_cnt = output_cnt
        if infer_shape and not callable(infer_shape):
            raise TypeError("infer_shape should either be None or a callable object")
        self.__infer_shape = infer_shape
        if comp_graph and not isinstance(comp_graph, mgb.CompGraph):
            raise TypeError("comp_graph should either be None or a mgb.CompGraph")
        self.__comp_graph = comp_graph
        self._torch_params = []
        self._param_inputs = []
        self._name_param_list = []
        if device is None:
            device = get_default_device()
        if isinstance(device, str):
            device = mgb.comp_node(device)
        # assigning through the property also moves the torch module to the
        # device and (re-)initializes the parameters
        self.device = device

    def init_params(self):
        """Forward torch parameters to megengine parameters and store them;
        called from the constructor and from the device setter.
        """
        self._torch_params = []
        self._param_inputs = []
        self._name_param_list = []
        for name, torch_param in self.__torch_module.named_parameters(recurse=True):
            formatted_name = "_torch_{}_{}".format(id(self.__torch_module), name)
            mge_param = torch_param_to_mge(
                formatted_name, torch_param, self.device, self.__comp_graph
            )
            self._param_inputs.append(mge_param)
            self._torch_params.append(torch_param)
            self._name_param_list.append((name, mge_param))

    def get_param_by_name(self, param_name: str) -> Parameter:
        """Find a parameter by its name.

        :param param_name: name of the parameter
        :return: the parameter
        """
        for name, param in self._name_param_list:
            if param_name == name:
                return param
        raise KeyError("Cannot find param: {}".format(param_name))

    def forward(self, *inputs):
        """Apply the module to the given inputs.

        :return: output vars
        """
        param_inputs = [param._symvar for param in self._param_inputs]
        inputs = [tensor._symvar for tensor in list(inputs)] + param_inputs
        out = PyTorchSubgraphImplOpr.make(
            *inputs,
            input_cnt=len(inputs) - len(param_inputs),
            output_cnt=self.__output_cnt,
            func=self.__torch_module.forward,
            params=self._torch_params,
            infer_shape=self.__infer_shape,
        )
        if isinstance(out, mgb.SymbolVar):
            return Tensor(out)
        assert isinstance(out, collections.abc.Iterable)
        return [Tensor(sym) for sym in out]

    def get_device(self):
        """Get the device this module is on."""
        return self.__device

    def set_device(self, device: mgb.CompNode):
        """Set the device and move the torch module to the corresponding device."""
        torch_device = device_to_torch_device(device)
        self.__torch_module.to(device=torch_device)
        self.__device = device
        self.init_params()

    device = property(get_device, set_device)
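

# A minimal usage sketch (illustrative only, not part of the module; names
# such as `torch_linear` are hypothetical and the Tensor constructor shown
# assumes the megengine 0.x API):
#
#   import numpy as np
#   import torch
#
#   torch_linear = torch.nn.Linear(3, 5)
#   wrapped = PyTorchModule(torch_linear)              # wrap as megengine Module
#   inp = Tensor(np.random.randn(4, 3).astype("float32"))
#   out = wrapped(inp)                                 # megengine Tensor, shape (4, 5)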

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.