Merge pull request #435 from MegEngine/try-import

Tags: v1.9.0
XindaH (via GitHub) committed 3 years ago · commit ea91babbce
18 changed files with 235 additions and 39 deletions
  1. .gitattributes                                    +3   -0
  2. .github/workflows/ci.yml                          +0   -2
  3. README.md                                         +1   -1
  4. README_CN.md                                      +3   -3
  5. ci/cmake.sh                                       +1   -2
  6. dnn/src/aarch64/relayout/opr_impl.cpp             +54  -0
  7. dnn/test/aarch64/relayout.cpp                     +3   -0
  8. imperative/python/megengine/device.py             +3   -1
  9. imperative/python/megengine/functional/math.py    +22  -18
 10. imperative/python/megengine/module/init.py        +19  -10
 11. imperative/python/test/unit/module/test_init.py   +28  -1
 12. lite/include/lite/global.h                        +15  -0
 13. lite/lite-c/include/lite-c/global_c.h             +16  -1
 14. lite/lite-c/src/global.cpp                        +15  -0
 15. lite/pylite/megenginelite/global_setting.py       +20  -0
 16. lite/pylite/test/test_network_device.py (renamed from lite/pylite/test/test_network_cuda.py)  +0 -0
 17. lite/src/global.cpp                               +31  -0
 18. lite/test/test_network.cpp                        +1   -0

.gitattributes (+3 -0)

@@ -21,3 +21,6 @@ ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text
 ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text
 imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_data_input.npy filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_data_output.npy filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_model.mge filter=lfs diff=lfs merge=lfs -text

.github/workflows/ci.yml (+0 -2)

@@ -29,7 +29,6 @@ jobs:
       uses: actions/checkout@v2
     - name: Checkout submodules
       run: |
-        apt update&&apt install ninja-build
        ./third_party/prepare.sh
        ./third_party/install-mkl.sh
    - name: Build MegEngine
@@ -58,7 +57,6 @@ jobs:
       uses: actions/checkout@v2
     - name: Checkout submodules
       run: |
-        apt update&&apt install ninja-build
        ./third_party/prepare.sh
        ./third_party/install-mkl.sh
    - name: Build MegEngine


README.md (+1 -1)

@@ -12,7 +12,7 @@ MegEngine is a fast, scalable and easy-to-use deep learning framework, with auto
 
 ## Installation
 
-**NOTE:** MegEngine now supports Python installation on Linux-64bit/Windows-64bit/MacOS(CPU-Only)-10.14+/Android 7+(CPU-Only) platforms with Python from 3.5 to 3.8. On Windows 10 you can either install the Linux distribution through [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) or install the Windows distribution directly. Many other platforms are supported for inference.
+**NOTE:** MegEngine now supports Python installation on Linux-64bit/Windows-64bit/MacOS(CPU-Only)-10.14+ platforms with Python from 3.5 to 3.8. On Windows 10 you can either install the Linux distribution through [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) or install the Windows distribution directly. Many other platforms are supported for inference.

### Binaries



README_CN.md (+3 -3)

@@ -13,7 +13,7 @@ MegEngine 是一个快速、可拓展、易于使用且支持自动求导的深
 
 ## 安装说明
 
-**注意:** MegEngine 现在支持在 Linux-64bit/Windows-64bit/macos-10.14/Android 7+ 及其以上 (MacOS/Android只支持cpu) 等平台上安装 Python 包,支持Python3.5 到 Python3.8。对于 Windows 10 用户,可以通过安装 [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验,同时我们也原生支持Windows。MegEngine 也支持在很多其它平台上进行推理运算。
+**注意:** MegEngine 现在支持在 Linux-64bit/Windows-64bit/macos-10.14及其以上 (MacOS只支持cpu) 等平台上安装 Python 包,支持Python3.5 到 Python3.8。对于 Windows 10 用户,可以通过安装 [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验,同时我们也原生支持Windows。MegEngine 也支持在很多其它平台上进行推理运算。
 
 ### 通过包管理器安装
 
@@ -26,8 +26,8 @@ python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html
 
 ## 通过源码编译安装
 
-* CMake 编译细节请参考 [BUILD_README.md](scripts/cmake-build/BUILD_README.md)
-* Python 绑定编译细节请参考 [BUILD_PYTHON_WHL_README.md](scripts/whl/BUILD_PYTHON_WHL_README.md)
+* CMake编译细节请参考 [BUILD_README.md](scripts/cmake-build/BUILD_README.md)
+* Python绑定编译细节请参考 [BUILD_PYTHON_WHL_README.md](scripts/whl/BUILD_PYTHON_WHL_README.md)

## 如何参与贡献



ci/cmake.sh (+1 -2)

@@ -27,8 +27,7 @@ function build() {
         -DMGE_WITH_DISTRIBUTED=${DMGE_WITH_DISTRIBUTED} \
         -DMGE_WITH_CUDA=${DMGE_WITH_CUDA} \
         -DMGE_WITH_TEST=ON \
-        -DCMAKE_BUILD_TYPE=RelWithDebInfo \
-        -DMGE_WITH_CUSTOM_OP=ON
+        -DCMAKE_BUILD_TYPE=RelWithDebInfo
     make -j$(($(nproc) * 2)) -I ${build_dir}
     make develop
     popd >/dev/null


dnn/src/aarch64/relayout/opr_impl.cpp (+54 -0)

@@ -363,6 +363,58 @@ static inline void trans_8x4_u16(
     vst1q_u16(dst_ptr + 3 * dst_step, row_3);
 }
 
+static inline void trans_8x3_u16(
+        const void* src, void* dst, const size_t src_step, const size_t dst_step) {
+    uint16_t* src_ptr = (uint16_t*)src;
+    uint16_t* dst_ptr = (uint16_t*)dst;
+    uint16x4_t src0 = vld1_u16(src_ptr + 0 * src_step);  // A0A1A2A3
+    uint16x4_t src1 = vld1_u16(src_ptr + 1 * src_step);  // B0B1B2B3
+    uint16x4_t src2 = vld1_u16(src_ptr + 2 * src_step);  // C0C1C2C3
+    uint16x4_t src3 = vld1_u16(src_ptr + 3 * src_step);  // D0D1D2D3
+    uint16x4_t src4 = vld1_u16(src_ptr + 4 * src_step);  // E0E1E2E3
+    uint16x4_t src5 = vld1_u16(src_ptr + 5 * src_step);  // F0F1F2F3
+    uint16x4_t src6 = vld1_u16(src_ptr + 6 * src_step);  // G0G1G2G3
+    // H0H1H2
+    uint16x4_t src7 =
+            vreinterpret_u16_u32(vld1_dup_u32((uint32_t*)(src_ptr + 7 * src_step)));
+    src7 = vld1_lane_u16(src_ptr + 7 * src_step + 2, src7, 2);
+
+    uint16x4_t ab_low = vzip1_u16(src0, src1);   // A0B0A1B1
+    uint16x4_t ab_high = vzip2_u16(src0, src1);  // A2B2A3B3
+    uint16x4_t cd_low = vzip1_u16(src2, src3);   // C0D0C1D1
+    uint16x4_t cd_high = vzip2_u16(src2, src3);  // C2D2C3D3
+    uint16x4_t ef_low = vzip1_u16(src4, src5);   // E0F0E1F1
+    uint16x4_t ef_high = vzip2_u16(src4, src5);  // E2F2E3F3
+    uint16x4_t gh_low = vzip1_u16(src6, src7);   // G0H0G1H1
+    uint16x4_t gh_high = vzip2_u16(src6, src7);  // G2H2G3
+
+    uint16x4_t abcd_0 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ab_low),
+            vreinterpret_u32_u16(cd_low)));  // A0B0C0D0
+    uint16x4_t abcd_1 = vreinterpret_u16_u32(vzip2_u32(
+            vreinterpret_u32_u16(ab_low),
+            vreinterpret_u32_u16(cd_low)));  // A1B1C1D1
+    uint16x4_t abcd_2 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ab_high),
+            vreinterpret_u32_u16(cd_high)));  // A2B2C2D2
+    uint16x4_t efgh_0 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ef_low),
+            vreinterpret_u32_u16(gh_low)));  // E0F0G0H0
+    uint16x4_t efgh_1 = vreinterpret_u16_u32(vzip2_u32(
+            vreinterpret_u32_u16(ef_low),
+            vreinterpret_u32_u16(gh_low)));  // E1F1G1H1
+    uint16x4_t efgh_2 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ef_high),
+            vreinterpret_u32_u16(gh_high)));  // E2F2G2H2
+
+    uint16x8_t row_0 = vcombine_u16(abcd_0, efgh_0);
+    uint16x8_t row_1 = vcombine_u16(abcd_1, efgh_1);
+    uint16x8_t row_2 = vcombine_u16(abcd_2, efgh_2);
+
+    vst1q_u16(dst_ptr + 0 * dst_step, row_0);
+    vst1q_u16(dst_ptr + 1 * dst_step, row_1);
+    vst1q_u16(dst_ptr + 2 * dst_step, row_2);
+}
 }  // anonymous namespace
 
 namespace megdnn {
@@ -410,6 +462,8 @@ void transpose_block<Transpose2Byte>(
         const size_t dst_stride, size_t block_h, size_t block_w) {
     if (block_h == 8 && block_w == 4) {
         trans_8x4_u16(src, dst, src_stride, dst_stride);
+    } else if (block_h == 8 && block_w == 3) {
+        trans_8x3_u16(src, dst, src_stride, dst_stride);
     } else {
         transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w);
     }
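For readers unfamiliar with the NEON zip/reinterpret idiom above, here is a scalar NumPy sketch of what trans_8x3_u16 computes; the reference function and its strides are ours, written only to mirror the kernel's parameters:

import numpy as np

# scalar reference: transpose an 8-row x 3-column uint16 tile,
# reading source rows at src_step and writing output rows at dst_step
def trans_8x3_u16_ref(src, dst, src_step, dst_step):
    for r in range(8):       # source rows A..H
        for c in range(3):   # source columns 0..2
            dst[c * dst_step + r] = src[r * src_step + c]

src = np.arange(64, dtype=np.uint16)
dst = np.zeros(64, dtype=np.uint16)
trans_8x3_u16_ref(src, dst, src_step=4, dst_step=8)
print(dst[:24].reshape(3, 8))  # rows: A0..H0, A1..H1, A2..H2

The NEON version loads four lanes per source row for speed (taking care not to over-read the last row) and simply never stores a fourth output row.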


dnn/test/aarch64/relayout.cpp (+3 -0)

@@ -40,6 +40,9 @@ TEST_F(AARCH64, Relayout) {
         TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype);
         checker.execl({src, dst});
     }
+    TensorLayout src_4_3({1, 3, 112, 256}, {3, 1, 1024, 4}, dtype::Uint16());
+    TensorLayout dst_4_3({1, 3, 112, 256}, {86016, 28672, 256, 1}, dtype::Uint16());
+    checker.execl({src_4_3, dst_4_3});
 }
 
 TEST_F(AARCH64, RelayoutNonContig) {


imperative/python/megengine/device.py (+3 -1)

@@ -50,7 +50,9 @@ _sh = _stream_helper()
 
 
 def _valid_device(inp):
-    if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
+    if isinstance(inp, str) and re.match(
+        "^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp
+    ):
         return True
     return False
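A quick sketch (not part of the change) of which device strings the widened pattern now accepts; the `tpu0` probe is just a hypothetical negative case:

import re

pattern = "^([cxg]pu|rocm|multithread)(\\d+|\\d+:\\d+|x)$"
for name in ["cpu0", "gpu1:2", "xpux", "multithread4", "rocm0", "tpu0"]:
    print(name, bool(re.match(pattern, name)))  # all True except tpu0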



imperative/python/megengine/functional/math.py (+22 -18)

@@ -1153,35 +1153,39 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
 
 
 def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
-    r"""Computes the singular value decompositions of input matrix.
+    r"""Returns a singular value decomposition ``A = USVh`` of a matrix (or a stack of matrices) ``x``, where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or a stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
 
     Args:
-        inp: input matrix, must has shape `[..., M, N]`.
+        x (Tensor): an input real tensor of shape ``(..., M, N)`` with ``x.ndim >= 2``.
+        full_matrices (bool, optional): if ``False``, ``U`` and ``Vh`` have the shapes ``(..., M, K)`` and ``(..., K, N)``, respectively, where ``K = min(M, N)``. If ``True``, the shapes are ``(..., M, M)`` and ``(..., N, N)``, respectively. Default: ``False``.
+        compute_uv (bool, optional): whether to compute ``U`` and ``Vh`` in addition to ``S``. Default: ``True``.
+
+    Note:
+        * The naive implementation does not support ``full_matrices`` and ``compute_uv`` as ``True``.
 
     Returns:
-        output matrices, `(U, sigma, V)`.
+        a tuple ``(U, S, Vh)`` containing the SVD factors of the input matrix ``x`` (``U`` and ``Vh`` are returned only when ``compute_uv`` is ``True``).
+        ``U`` contains matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array has shape ``(..., M, M)``; if ``full_matrices`` is ``False``, it has shape ``(..., M, K)``, where ``K = min(M, N)``.
 
     Examples:
-
-        .. testcode::
-
-            import numpy as np
-            from megengine import tensor
-            import megengine.functional as F
-
-            x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
-            _, y, _ = F.svd(x)
-            print(y.numpy().round(decimals=3))
+        >>> import numpy as np
+        >>> x = Tensor(np.random.randn(9, 6))
+        >>> y = Tensor(np.random.randn(2, 7, 8, 3))
 
-        Outputs:
-
-        .. testoutput::
+        Reconstruction based on reduced SVD, 2D case:
+        >>> U, S, Vh = F.svd(x, full_matrices=False)
+        >>> print(U._tuple_shape, S._tuple_shape, Vh._tuple_shape)
+        (9, 6) (6,) (6, 6)
 
-            [7.348 1. ]
+        Reconstruction based on reduced SVD, 4D case:
+        >>> u, s, vh = F.svd(y, full_matrices=False)
+        >>> print(u._tuple_shape, s._tuple_shape, vh._tuple_shape)
+        (2, 7, 8, 3) (2, 7, 3) (2, 7, 3, 3)
     """
     op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
-    U, sigma, V = apply(op, inp)
-    return U, sigma, V
+    U, S, Vh = apply(op, inp)
+    return U, S, Vh
 
 
 def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
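To see that the renamed factors round-trip, a minimal reconstruction check (a sketch assuming a build where F.svd is available; `U * S` broadcasts S across U's columns, which equals `U @ diag(S)`):

import numpy as np
from megengine import Tensor
import megengine.functional as F

x = Tensor(np.random.randn(9, 6).astype("float32"))
U, S, Vh = F.svd(x, full_matrices=False)
recon = F.matmul(U * S, Vh)  # U @ diag(S) @ Vh
print(np.allclose(recon.numpy(), x.numpy(), atol=1e-5))  # expect True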


imperative/python/megengine/module/init.py (+19 -10)

@@ -74,7 +74,7 @@ def calculate_gain(
 ) -> float:
     r"""Returns a recommended gain value (see the table below) for the given nonlinearity
     function.
 
     ================= ====================================================
     nonlinearity      gain
     ================= ====================================================
@@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
     r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes
     input tensor is stored in ``NCHW`` format.
 
+    Note:
+        The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This
+        function calculates ``fan_out = O/G * K * K`` by default, but PyTorch uses
+        ``fan_out = O * K * K``.
+
     Args:
         tensor: weight tensor in ``NCHW`` format.
     """
@@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
         fan_in = shape[1]
         fan_out = shape[0]
     else:
+        if ndim >= 5:
+            # ignore the groups dimension of group conv2d and group conv3d
+            # FIXME: will be wrong for conv3d
+            shape = shape[1:]
         num_input_fmaps = shape[1]
         num_output_fmaps = shape[0]
         receptive_field_size = 1
@@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
 def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
     r"""Calculates fan_in / fan_out value for given weight tensor, depending on given
     ``mode``.
 
     See :func:`calculate_fan_in_and_fan_out` for details.
 
     Args:
@@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
 def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
     r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)`
     where
 
     .. math::
 
         a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}
 
     Also known as Glorot initialization. Detailed information can be retrieved from
     `Understanding the difficulty of training deep feedforward neural networks` -
     Glorot, X. & Bengio, Y. (2010).
@@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
 def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
     r"""Fills tensor with random values sampled from
     :math:`\mathcal{N}(0, \text{std}^2)` where
 
     .. math::
 
         \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}
 
     Also known as Glorot initialization. Detailed information can be retrieved from
     `Understanding the difficulty of training deep feedforward neural networks` -
     Glorot, X. & Bengio, Y. (2010).
@@ -220,11 +229,11 @@ def msra_uniform_(
 ) -> None:
     r"""Fills tensor with random values sampled from
     :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
 
     .. math::
 
         \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
 
     Detailed information can be retrieved from
     `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
     classification`
@@ -251,11 +260,11 @@ def msra_normal_(
 ) -> None:
     r"""Fills tensor with random values sampled from
     :math:`\mathcal{N}(0, \text{std}^2)` where
 
     .. math::
 
         \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}
 
     Detailed information can be retrieved from
     `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
     classification`
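A plain-Python sketch of the fan rule the new ndim >= 5 branch implements (the helper name is ours, not the module's); it reproduces the group-conv asserts in the test file below:

def fan_in_fan_out(shape):
    # group conv kernels are (G, O/G, I/G, K...), so drop the groups dim first
    if len(shape) >= 5:
        shape = shape[1:]
    receptive_field_size = 1
    for k in shape[2:]:
        receptive_field_size *= k
    return shape[1] * receptive_field_size, shape[0] * receptive_field_size

# Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
# has weight shape (2, 2, 1, 5, 7): fan_in = 1*5*7, fan_out = 2*5*7
print(fan_in_fan_out((2, 2, 1, 5, 7)))  # (35, 70)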


imperative/python/test/unit/module/test_init.py (+28 -1)

@@ -10,7 +10,7 @@ import numpy as np
 import pytest
 
 from megengine import tensor
-from megengine.module import Conv2d, Linear
+from megengine.module import Conv1d, Conv2d, Conv3d, Linear
 from megengine.module.init import calculate_fan_in_and_fan_out, fill_
 
 
@@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out():
     with pytest.raises(ValueError):
         calculate_fan_in_and_fan_out(l.bias)
 
+    l = Conv1d(in_channels=2, out_channels=3, kernel_size=5)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 * 5
+    assert fanout == 3 * 5
+
+    # FIXME: will be wrong for group conv1d
+    # l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
+    # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    # assert fanin == 2 // 2 * 5
+    # assert fanout == 4 // 2 * 5
+
+    l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 * 5 * 7
+    assert fanout == 3 * 5 * 7
+
+    l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 // 2 * 5 * 7
+    assert fanout == 4 // 2 * 5 * 7
+
+    # FIXME: will be wrong for conv3d
+    # l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
+    # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    # assert fanin == 2 * 5 * 7 * 9
+    # assert fanout == 3 * 5 * 7 * 9
+
+    l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 // 2 * 5 * 7 * 9
+    assert fanout == 4 // 2 * 5 * 7 * 9

lite/include/lite/global.h (+15 -0)

@@ -154,6 +154,21 @@ LITE_API void set_tensor_rt_cache(std::string tensorrt_cache_path);
  */
 LITE_API void dump_tensor_rt_cache();
 
+/**
+ * register a physical and virtual address pair to mge; some devices
+ * need the map from physical to virtual addresses.
+ */
+LITE_API bool register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend = LiteBackend::LITE_DEFAULT);
+
+/**
+ * clear a physical and virtual address pair registered in mge.
+ */
+LITE_API bool clear_memory_pair(
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device,
+        LiteBackend backend = LiteBackend::LITE_DEFAULT);
+
 } // namespace lite
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/lite-c/include/lite-c/global_c.h (+16 -1)

@@ -160,9 +160,24 @@ LITE_API int LITE_dump_persistent_cache(const char* cache_path);
  * \brief dump the tensorrt policy cache to file
  */
 LITE_API int LITE_dump_tensor_rt_cache();
 #endif
 
+/**
+ * register a physical and virtual address pair to mge; some devices
+ * need the map from physical to virtual addresses.
+ */
+LITE_API int LITE_register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend);
+
+/**
+ * clear a physical and virtual address pair registered in mge.
+ */
+LITE_API int LITE_clear_memory_pair(
+        void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend);
+
 #ifdef __cplusplus
 }
 #endif
 #endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/lite-c/src/global.cpp (+15 -0)

@@ -189,4 +189,19 @@ int LITE_dump_tensor_rt_cache() {
     LITE_CAPI_END();
 }
 
+int LITE_register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend) {
+    LITE_CAPI_BEGIN();
+    lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend);
+    LITE_CAPI_END();
+}
+
+int LITE_clear_memory_pair(
+        void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) {
+    LITE_CAPI_BEGIN();
+    lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
+    LITE_CAPI_END();
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/pylite/megenginelite/global_setting.py (+20 -0)

@@ -42,6 +42,8 @@ class _GlobalAPI(_LiteCObjBase):
         # ('LITE_set_tensor_rt_cache', [c_char_p]),
         ("LITE_dump_persistent_cache", [c_char_p]),
         ("LITE_dump_tensor_rt_cache", [c_char_p]),
+        ("LITE_register_memory_pair", [c_void_p, c_void_p, c_size_t, c_int, c_int]),
+        ("LITE_clear_memory_pair", [c_void_p, c_void_p, c_int, c_int]),
     ]
 
 
@@ -121,3 +123,18 @@ class LiteGlobal(object):
     @staticmethod
     def try_coalesce_all_free_memory():
         LiteGlobal._api.LITE_try_coalesce_all_free_memory()
+
+    @staticmethod
+    def register_memory_pair(
+        vir_ptr, phy_ptr, length, device, backend=LiteBackend.LITE_DEFAULT
+    ):
+        assert isinstance(vir_ptr, c_void_p) and isinstance(
+            phy_ptr, c_void_p
+        ), "register memory pair only accepts c_void_p type."
+        LiteGlobal._api.LITE_register_memory_pair(
+            vir_ptr, phy_ptr, length, device, backend
+        )
+
+    @staticmethod
+    def clear_memory_pair(vir_ptr, phy_ptr, device, backend=LiteBackend.LITE_DEFAULT):
+        assert isinstance(vir_ptr, c_void_p) and isinstance(
+            phy_ptr, c_void_p
+        ), "clear memory pair only accepts c_void_p type."
+        LiteGlobal._api.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend)
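Usage might look like the sketch below; the addresses are dummies, and since the core lite::register_memory_pair in this PR still throws "not implemented yet", the call is expected to fail until a backend provides the hook:

from ctypes import c_void_p
from megenginelite import LiteGlobal, LiteDeviceType

vir = c_void_p(0x1000)  # hypothetical virtual address
phy = c_void_p(0x2000)  # hypothetical physical address
LiteGlobal.register_memory_pair(vir, phy, 4096, LiteDeviceType.LITE_CPU)
# ... run inference against the mapped buffer ...
LiteGlobal.clear_memory_pair(vir, phy, LiteDeviceType.LITE_CPU)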

lite/pylite/test/test_network_cuda.py → lite/pylite/test/test_network_device.py (renamed, +0 -0)


lite/src/global.cpp (+31 -0)

@@ -212,6 +212,26 @@ void lite::dump_tensor_rt_cache() {
 #endif
 }
 
+bool lite::register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend) {
+    LITE_MARK_USED_VAR(vir_ptr);
+    LITE_MARK_USED_VAR(phy_ptr);
+    LITE_MARK_USED_VAR(length);
+    LITE_MARK_USED_VAR(device);
+    LITE_MARK_USED_VAR(backend);
+    LITE_THROW("register_memory_pair is not implemented yet!");
+}
+
+bool lite::clear_memory_pair(
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
+    LITE_MARK_USED_VAR(vir_ptr);
+    LITE_MARK_USED_VAR(phy_ptr);
+    LITE_MARK_USED_VAR(device);
+    LITE_MARK_USED_VAR(backend);
+    LITE_THROW("clear_memory_pair is not implemented yet!");
+}
+
 #else  // LITE_BUILD_WITH_MGE
 void lite::try_coalesce_all_free_memory() {}
 
@@ -235,6 +255,17 @@ void lite::set_tensor_rt_cache(std::string) {
 void lite::dump_tensor_rt_cache() {
     LITE_THROW("mge is disabled at build time, please build with mge");
 }
+
+bool lite::register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend) {
+    LITE_THROW("register_memory_pair is not implemented yet!");
+}
+
+bool lite::clear_memory_pair(
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
+    LITE_THROW("clear_memory_pair is not implemented yet!");
+}
 #endif
 namespace lite {
 REGIST_DECRYPTION_METHOD(


lite/test/test_network.cpp (+1 -0)

@@ -1357,5 +1357,6 @@ TEST(TestNetWork, CambriconDeviceID) {
     load_device_id(LiteDeviceType::LITE_CAMBRICON, 0, "./model_magicmind.mgb");
 }
 #endif
+
 #endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
