From d2673c5abf5cfce742fba5063228f7354a679394 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 17 May 2021 20:57:57 +0800
Subject: [PATCH] fix(ci/windows): add windows cuda test

GitOrigin-RevId: 706be83032bc08c0d2a9cc00eb8dce23e58925bf
---
 imperative/python/megengine/functional/tensor.py   | 10 ++-
 imperative/python/requires.txt                     |  1 +
 imperative/python/test/run.sh                      | 38 +++++++++---
 .../python/test/unit/functional/test_functional.py |  1 +
 imperative/python/test/unit/random/test_rng.py     | 16 ++++-
 imperative/src/test/imperative.cpp                 | 71 +++++++++++-----------
 imperative/src/test/opr_utility.cpp                |  1 +
 imperative/src/test/rng.cpp                        |  2 +
 test/CMakeLists.txt                                |  3 +
 test/src/helper.cpp                                |  8 +++
 test/src/include/megbrain/test/helper.h            |  9 +++
 11 files changed, 112 insertions(+), 48 deletions(-)

diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 7cf0472b..3c71e82c 100644
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -1150,12 +1150,18 @@ def copy(inp, device=None):
     .. testcode::
 
         import numpy as np
+        import platform
         from megengine import tensor
+        from megengine.distributed.helper import get_device_count_by_fork
         import megengine.functional as F
 
         x = tensor([1, 2, 3], np.int32)
-        y = F.copy(x, "xpu1")
-        print(y.numpy())
+        if 1 == get_device_count_by_fork("gpu"):
+            y = F.copy(x, "cpu1")
+            print(y.numpy())
+        else:
+            y = F.copy(x, "xpu1")
+            print(y.numpy())
 
     Outputs:
 
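For readers skimming the hunk above: the docstring example now picks its copy
target from the visible device count instead of hard-coding "xpu1", presumably
so the doctest also passes on single-GPU CI runners. A minimal standalone
sketch of the same fallback (all names come from the imports in the hunk;
device strings follow MegEngine's "<type><index>" convention):

    import numpy as np

    import megengine.functional as F
    from megengine import tensor
    from megengine.distributed.helper import get_device_count_by_fork

    x = tensor([1, 2, 3], np.int32)
    # a single-GPU host has no "xpu1", so copy to a second CPU node instead
    device = "cpu1" if get_device_count_by_fork("gpu") == 1 else "xpu1"
    y = F.copy(x, device)
    print(y.numpy())  # [1 2 3]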
diff --git a/imperative/python/requires.txt b/imperative/python/requires.txt
index 4921869a..670193dc 100644
--- a/imperative/python/requires.txt
+++ b/imperative/python/requires.txt
@@ -7,3 +7,4 @@ tqdm
 redispy
 deprecated
 mprop
+wheel
diff --git a/imperative/python/test/run.sh b/imperative/python/test/run.sh
index f70c121f..ed71de53 100755
--- a/imperative/python/test/run.sh
+++ b/imperative/python/test/run.sh
@@ -1,24 +1,42 @@
 #!/bin/bash -e
 
-test_dirs="megengine test"
-
 TEST_PLAT=$1
+export MEGENGINE_LOGGING_LEVEL="ERROR"
 
 if [[ "$TEST_PLAT" == cpu ]]; then
-    echo "only test cpu pytest"
+    echo "test cpu after Ninja develop"
 elif [[ "$TEST_PLAT" == cuda ]]; then
-    echo "test both cpu and gpu pytest"
+    echo "test cuda after Ninja develop"
+elif [[ "$TEST_PLAT" == cpu_local ]]; then
+    echo "test cpu after python3 -m pip install xxx"
+elif [[ "$TEST_PLAT" == cuda_local ]]; then
+    echo "test cuda after python3 -m pip install xxx"
 else
-    echo "Argument must cpu or cuda"
+    echo "ERR args, support list:"
+    echo "$0 cpu (test cpu after Ninja develop)"
+    echo "$0 cuda (test cuda after Ninja develop)"
+    echo "$0 cpu_local (test cpu after python3 -m pip install xxx)"
+    echo "$0 cuda_local (test cuda after python3 -m pip install xxx)"
     exit 1
 fi
 
-export MEGENGINE_LOGGING_LEVEL="ERROR"
-
-pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null
+if [[ "$TEST_PLAT" =~ "local" ]]; then
+    cd $(dirname "${BASH_SOURCE[0]}")
+    megengine_dir=`python3 -c 'from pathlib import Path;import megengine;print(Path(megengine.__file__).resolve().parent)'`
+    test_dirs="${megengine_dir} ."
+    echo "test local env at: ${test_dirs}"
+    PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed'
+    if [[ "$TEST_PLAT" =~ "cuda" ]]; then
+        echo "test GPU pytest now"
+        PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed'
+    fi
+else
+    cd $(dirname "${BASH_SOURCE[0]}")/..
+    test_dirs="megengine test"
+    echo "test develop env"
 PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed'
-if [[ "$TEST_PLAT" == cuda ]]; then
+if [[ "$TEST_PLAT" =~ "cuda" ]]; then
     echo "test GPU pytest now"
     PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed'
 fi
-popd >/dev/null
+fi
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 46486bfd..28f2254a 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import itertools
+import platform
 from functools import partial
 
 import numpy as np
diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py
index 1979eebf..b3bf7c20 100644
--- a/imperative/python/test/unit/random/test_rng.py
+++ b/imperative/python/test/unit/random/test_rng.py
@@ -7,9 +7,10 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest
 
 import megengine
-from megengine import tensor
+from megengine import is_cuda_available, tensor
 from megengine.core._imperative_rt import CompNode
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._imperative_rt.ops import (
@@ -18,10 +19,14 @@ from megengine.core._imperative_rt.ops import (
     new_rng_handle,
 )
 from megengine.core.ops.builtin import GaussianRNG, UniformRNG
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.random import RNG
 from megengine.random.rng import _normal, _uniform
 
 
+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+)
 def test_gaussian_op():
     shape = (
         8,
@@ -47,6 +52,9 @@ def test_gaussian_op():
     assert str(output.device) == str(cn)
 
 
+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+)
 def test_uniform_op():
     shape = (
         8,
@@ -70,6 +78,9 @@ def test_uniform_op():
     assert str(output.device) == str(cn)
 
 
+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+)
 def test_UniformRNG():
     m1 = RNG(seed=111, device="xpu0")
     m2 = RNG(seed=111, device="xpu1")
@@ -95,6 +106,9 @@ def test_UniformRNG():
     assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
 
 
+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+)
 def test_NormalRNG():
     m1 = RNG(seed=111, device="xpu0")
     m2 = RNG(seed=111, device="xpu1")
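The same skip marker can guard any future test that needs more devices than a
CI host may expose. A hypothetical sketch (the test name and body are invented
for illustration; the marker expression is exactly the one used above):

    import numpy as np
    import pytest

    import megengine.functional as F
    from megengine import tensor
    from megengine.distributed.helper import get_device_count_by_fork

    @pytest.mark.skipif(
        get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
    )
    def test_copy_to_second_device():
        # collected but skipped on single-device hosts, so CI stays green
        x = tensor([1, 2, 3], np.int32)
        y = F.copy(x, "xpu1")
        np.testing.assert_equal(y.numpy(), [1, 2, 3])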
diff --git a/imperative/src/test/imperative.cpp b/imperative/src/test/imperative.cpp
index d29ef19e..bcf081d4 100644
--- a/imperative/src/test/imperative.cpp
+++ b/imperative/src/test/imperative.cpp
@@ -77,11 +77,12 @@ TEST(TestImperative, BatchNorm) {
 }
 
 TEST(TestImperative, Concat) {
-    OprAttr::Param param;
-    param.write_pod(megdnn::param::Axis(0));
-    OperatorNodeConfig config{CompNode::load("xpu1")};
-    OprChecker(OprAttr::make("Concat", param, config))
-            .run({TensorShape{200, 300}, TensorShape{300, 300}});
+    REQUIRE_XPU(2);
+    OprAttr::Param param;
+    param.write_pod(megdnn::param::Axis(0));
+    OperatorNodeConfig config{CompNode::load("xpu1")};
+    OprChecker(OprAttr::make("Concat", param, config))
+            .run({TensorShape{200, 300}, TensorShape{300, 300}});
 }
 
 TEST(TestImperative, Split) {
@@ -147,36 +148,36 @@ void run_graph(size_t mem_reserved, bool enable_defrag) {
 }
 
 TEST(TestImperative, Defragment) {
-    REQUIRE_GPU(1);
-    CompNode::load("gpux").activate();
-    size_t reserve;
-    {
-        size_t free, tot;
-        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
-        reserve = free * 0.92;
-    }
-    auto reserve_setting = ssprintf("b:%zu", reserve);
-
-    auto do_run = [reserve]() {
-        ASSERT_THROW(run_graph(reserve, false), MemAllocError);
-        run_graph(reserve, true);
-    };
-
-    // reserve memory explicitly to avoid uncontrollable factors
-    constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
-    auto old_value = getenv(KEY);
-    setenv(KEY, reserve_setting.c_str(), 1);
-    MGB_TRY {
-        do_run();
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-        CompNode::try_coalesce_all_free_memory();
-        CompNode::finalize();
-    );
+#if WIN32
+    //! FIXME: finalize on Windows CUDA is skipped because of a Windows
+    //! CUDA101 DLL issue
+    return;
+#endif
+    REQUIRE_GPU(1);
+    CompNode::load("gpux").activate();
+    size_t reserve;
+    {
+        size_t free, tot;
+        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
+        reserve = free * 0.92;
+    }
+    auto reserve_setting = ssprintf("b:%zu", reserve);
+
+    auto do_run = [reserve]() {
+        ASSERT_THROW(run_graph(reserve, false), MemAllocError);
+        run_graph(reserve, true);
+    };
+
+    // reserve memory explicitly to avoid uncontrollable factors
+    constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
+    auto old_value = getenv(KEY);
+    setenv(KEY, reserve_setting.c_str(), 1);
+    MGB_TRY { do_run(); }
+    MGB_FINALLY(
+        if (old_value) { setenv(KEY, old_value, 1); } else {
+            unsetenv(KEY);
+        } CompNode::try_coalesce_all_free_memory();
+        CompNode::finalize(););
 }
 
 #endif  // MGB_CUDA && MGB_ENABLE_EXCEPTION
diff --git a/imperative/src/test/opr_utility.cpp b/imperative/src/test/opr_utility.cpp
index 10c1f84f..8f3c9405 100644
--- a/imperative/src/test/opr_utility.cpp
+++ b/imperative/src/test/opr_utility.cpp
@@ -89,6 +89,7 @@ TEST(TestOprUtility, NopCallback) {
 }
 
 TEST(TestOprUtility, NopCallbackMixedInput) {
+    REQUIRE_XPU(2);
     auto graph = ComputingGraph::make();
     auto x0 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator()({2, 3}), OperatorNodeConfig(CompNode::load("xpu0")));
     auto x1 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator()({2, 3}), OperatorNodeConfig(CompNode::load("xpu1")));
diff --git a/imperative/src/test/rng.cpp b/imperative/src/test/rng.cpp
index 9604df98..d03b53b8 100644
--- a/imperative/src/test/rng.cpp
+++ b/imperative/src/test/rng.cpp
@@ -43,10 +43,12 @@ void check_rng_basic(Args&&... args) {
 }
 
 TEST(TestImperative, UniformRNGBasic) {
+    REQUIRE_XPU(2);
     check_rng_basic<UniformRNG>(123);
 }
 
 TEST(TestImperative, GaussianRNGBasic) {
+    REQUIRE_XPU(2);
     check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }
 
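The REQUIRE_XPU(n) guard these C++ tests now start with is a macro introduced
at the end of this patch, in test/src/include/megbrain/test/helper.h. As an
illustrative sketch only, not code from the patch, TestImperative.Concat reads
roughly like this after preprocessing:

    // the do { ... } while (0) wrapper makes "REQUIRE_XPU(2);" parse as a
    // single statement, so it is safe even inside a braceless if/else
    TEST(TestImperative, Concat) {
        do {
            if (!check_xpu_available(2))  // logs "skip test case that requires 2 XPU(s)"
                return;                   // early return: gtest records the test as passed
        } while (0);
        OprAttr::Param param;
        // ... original body, reached only when at least two XPUs exist ...
    }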
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 2387a109..fd4dde50 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -19,6 +19,9 @@ endif()
 add_executable(megbrain_test ${SOURCES})
 target_link_libraries(megbrain_test gtest gmock)
 target_link_libraries(megbrain_test megbrain megdnn ${MGE_CUDA_LIBS})
+if (MGE_WITH_CUDA)
+    target_include_directories(megbrain_test PRIVATE ${CUDNN_INCLUDE_DIR})
+endif()
 if(CXX_SUPPORT_WCLASS_MEMACCESS)
     if(MGE_WITH_CUDA)
         target_compile_options(megbrain_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-class-memaccess>"
diff --git a/test/src/helper.cpp b/test/src/helper.cpp
index 82f63681..ccdd2bbe 100644
--- a/test/src/helper.cpp
+++ b/test/src/helper.cpp
@@ -337,6 +337,14 @@ std::vector<CompNode> mgb::load_multiple_xpus(size_t num) {
     return ret;
 }
 
+bool mgb::check_xpu_available(size_t num) {
+    if (CompNode::get_device_count(CompNode::DeviceType::UNSPEC) < num) {
+        mgb_log_warn("skip test case that requires %zu XPU(s)", num);
+        return false;
+    }
+    return true;
+}
+
 bool mgb::check_gpu_available(size_t num) {
     if (CompNode::get_device_count(CompNode::DeviceType::CUDA) < num) {
         mgb_log_warn("skip test case that requires %zu GPU(s)", num);
diff --git a/test/src/include/megbrain/test/helper.h b/test/src/include/megbrain/test/helper.h
index 92cc29be..5bfee7e5 100644
--- a/test/src/include/megbrain/test/helper.h
+++ b/test/src/include/megbrain/test/helper.h
@@ -492,6 +492,9 @@ std::vector<CompNode> load_multiple_xpus(size_t num);
 //! check whether given number of GPUs is available
 bool check_gpu_available(size_t num);
 
+//! check whether given number of XPUs is available
+bool check_xpu_available(size_t num);
+
 //! check whether given number of AMD GPUs is available
 bool check_amd_gpu_available(size_t num);
 
@@ -518,6 +521,12 @@ public:
     PersistentCacheHook(GetHook on_get);
     ~PersistentCacheHook();
 };
 
+//! skip a testcase if xpu not available
+#define REQUIRE_XPU(n) do { \
+    if (!check_xpu_available(n)) \
+        return; \
+} while(0)
+
 //! skip a testcase if gpu not available
 #define REQUIRE_GPU(n) do { \
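Taken together, check_xpu_available() and REQUIRE_XPU give device-hungry gtest
cases a one-line opt-in skip, mirroring the existing REQUIRE_GPU. A
hypothetical usage sketch (the test itself is invented for illustration; the
macro, helper, and load_multiple_xpus are the ones declared above):

    #include "megbrain/test/helper.h"

    using namespace mgb;

    TEST(TestExample, CrossDeviceCopy) {
        REQUIRE_XPU(2);  // warns and returns unless at least 2 XPUs are visible
        auto cns = load_multiple_xpus(2);  // presumably xpu0 and xpu1
        ASSERT_NE(cns[0], cns[1]);
    }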