
fix(ci/windows): add windows cuda test

GitOrigin-RevId: 706be83032
release-1.5
Megvii Engine Team 4 years ago
parent commit d2673c5abf
11 changed files with 112 additions and 48 deletions
1. imperative/python/megengine/functional/tensor.py (+8, -2)
2. imperative/python/requires.txt (+1, -0)
3. imperative/python/test/run.sh (+28, -10)
4. imperative/python/test/unit/functional/test_functional.py (+1, -0)
5. imperative/python/test/unit/random/test_rng.py (+15, -1)
6. imperative/src/test/imperative.cpp (+36, -35)
7. imperative/src/test/opr_utility.cpp (+1, -0)
8. imperative/src/test/rng.cpp (+2, -0)
9. test/CMakeLists.txt (+3, -0)
10. test/src/helper.cpp (+8, -0)
11. test/src/include/megbrain/test/helper.h (+9, -0)

imperative/python/megengine/functional/tensor.py (+8, -2)

@@ -1150,12 +1150,18 @@ def copy(inp, device=None):
     .. testcode::

         import numpy as np
+        import platform
         from megengine import tensor
+        from megengine.distributed.helper import get_device_count_by_fork
         import megengine.functional as F

         x = tensor([1, 2, 3], np.int32)
-        y = F.copy(x, "xpu1")
-        print(y.numpy())
+        if 1 == get_device_count_by_fork("gpu"):
+            y = F.copy(x, "cpu1")
+            print(y.numpy())
+        else:
+            y = F.copy(x, "xpu1")
+            print(y.numpy())


     Outputs:
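The rewritten doctest picks its copy target from the GPU count: on a host with exactly one GPU, "xpu1" would resolve to a second GPU that does not exist, so the example copies to the second CPU stream "cpu1" instead. A condensed sketch of the same gating logic (the "target" variable is illustrative, not part of the source):

    import numpy as np

    import megengine.functional as F
    from megengine import tensor
    from megengine.distributed.helper import get_device_count_by_fork

    x = tensor([1, 2, 3], np.int32)
    # With exactly one GPU, "xpu1" would map to a missing second GPU;
    # fall back to a second CPU stream, which always exists. The printed
    # values are [1 2 3] either way.
    target = "cpu1" if get_device_count_by_fork("gpu") == 1 else "xpu1"
    y = F.copy(x, target)
    print(y.numpy())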




imperative/python/requires.txt (+1, -0)

@@ -7,3 +7,4 @@ tqdm
 redispy
 deprecated
 mprop
+wheel

imperative/python/test/run.sh (+28, -10)

@@ -1,24 +1,42 @@
 #!/bin/bash -e

-test_dirs="megengine test"
-
 TEST_PLAT=$1
+export MEGENGINE_LOGGING_LEVEL="ERROR"

 if [[ "$TEST_PLAT" == cpu ]]; then
-    echo "only test cpu pytest"
+    echo "test cpu after Ninja develop"
 elif [[ "$TEST_PLAT" == cuda ]]; then
-    echo "test both cpu and gpu pytest"
+    echo "test cuda after Ninja develop"
+elif [[ "$TEST_PLAT" == cpu_local ]]; then
+    echo "test cpu after python3 -m pip install xxx"
+elif [[ "$TEST_PLAT" == cuda_local ]]; then
+    echo "test cuda after python3 -m pip install xxx"
 else
-    echo "Argument must cpu or cuda"
+    echo "ERR args, support list:"
+    echo "$0 cpu (test cpu after Ninja develop)"
+    echo "$0 cuda (test cuda after Ninja develop)"
+    echo "$0 cpu_local (test cpu after python3 -m pip install xxx)"
+    echo "$0 cuda_local (test cuda after python3 -m pip install xxx)"
     exit 1
 fi

-export MEGENGINE_LOGGING_LEVEL="ERROR"
-
-pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null
-PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed'
-if [[ "$TEST_PLAT" == cuda ]]; then
-    echo "test GPU pytest now"
-    PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed'
-fi
-popd >/dev/null
+if [[ "$TEST_PLAT" =~ "local" ]]; then
+    cd $(dirname "${BASH_SOURCE[0]}")
+    megengine_dir=`python3 -c 'from pathlib import Path;import megengine;print(Path(megengine.__file__).resolve().parent)'`
+    test_dirs="${megengine_dir} ."
+    echo "test local env at: ${test_dirs}"
+    PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed'
+    if [[ "$TEST_PLAT" =~ "cuda" ]]; then
+        echo "test GPU pytest now"
+        PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed'
+    fi
+else
+    cd $(dirname "${BASH_SOURCE[0]}")/..
+    test_dirs="megengine test"
+    echo "test develop env"
+    PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed'
+    if [[ "$TEST_PLAT" =~ "cuda" ]]; then
+        echo "test GPU pytest now"
+        PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed'
+    fi
+fi
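The new *_local modes test a pip-installed megengine instead of the in-tree develop build, so the script locates the installed package with a Python one-liner. The same lookup, unrolled for readability:

    from pathlib import Path

    import megengine

    # Directory of the installed megengine package (e.g. under
    # site-packages); in *_local mode run.sh points pytest here so the
    # tests exercise the wheel that was actually installed.
    megengine_dir = Path(megengine.__file__).resolve().parent
    print(megengine_dir)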

imperative/python/test/unit/functional/test_functional.py (+1, -0)

@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import itertools
+import platform
 from functools import partial

 import numpy as np
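This hunk only adds the import; the platform checks that use it are in parts of the file not shown here. A hypothetical sketch of the usual pattern such an import supports (test name and reason are illustrative, not from this commit):

    import platform

    import pytest

    # Skip rather than fail on Windows runners when a case is not yet
    # supported there.
    @pytest.mark.skipif(
        platform.system() == "Windows", reason="not supported on Windows",
    )
    def test_unix_only_case():
        pass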


imperative/python/test/unit/random/test_rng.py (+15, -1)

@@ -7,9 +7,10 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest

 import megengine
-from megengine import tensor
+from megengine import is_cuda_available, tensor
 from megengine.core._imperative_rt import CompNode
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._imperative_rt.ops import (
@@ -18,10 +19,14 @@ from megengine.core._imperative_rt.ops import (
     new_rng_handle,
 )
 from megengine.core.ops.builtin import GaussianRNG, UniformRNG
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.random import RNG
 from megengine.random.rng import _normal, _uniform


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+)
 def test_gaussian_op():
     shape = (
         8,
@@ -47,6 +52,9 @@ def test_gaussian_op():
     assert str(output.device) == str(cn)


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+)
 def test_uniform_op():
     shape = (
         8,
@@ -70,6 +78,9 @@ def test_uniform_op():
     assert str(output.device) == str(cn)


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+)
 def test_UniformRNG():
     m1 = RNG(seed=111, device="xpu0")
     m2 = RNG(seed=111, device="xpu1")
@@ -95,6 +106,9 @@ def test_UniformRNG():
     assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+)
 def test_NormalRNG():
     m1 = RNG(seed=111, device="xpu0")
     m2 = RNG(seed=111, device="xpu1")
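All four new decorators follow one pattern: get_device_count_by_fork counts devices in a forked helper process, so collecting the tests never initializes CUDA in the main pytest process, and multi-device RNG tests are skipped rather than failed on machines with too few devices. A minimal sketch of the pattern with a hypothetical test body (the real tests assert much more):

    import pytest

    from megengine.distributed.helper import get_device_count_by_fork
    from megengine.random import RNG

    # Skip when there is no second device, e.g. on a single-GPU CI runner.
    @pytest.mark.skipif(
        get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
    )
    def test_rng_on_second_device():  # hypothetical name
        m = RNG(seed=111, device="xpu1")
        out = m.uniform(size=(16,))
        assert out.numpy().shape == (16,)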


imperative/src/test/imperative.cpp (+36, -35)

@@ -77,11 +77,12 @@ TEST(TestImperative, BatchNorm) {
 }

 TEST(TestImperative, Concat) {
-    OprAttr::Param param;
-    param.write_pod(megdnn::param::Axis(0));
-    OperatorNodeConfig config{CompNode::load("xpu1")};
-    OprChecker(OprAttr::make("Concat", param, config))
-            .run({TensorShape{200, 300}, TensorShape{300, 300}});
+    REQUIRE_XPU(2);
+    OprAttr::Param param;
+    param.write_pod(megdnn::param::Axis(0));
+    OperatorNodeConfig config{CompNode::load("xpu1")};
+    OprChecker(OprAttr::make("Concat", param, config))
+            .run({TensorShape{200, 300}, TensorShape{300, 300}});
 }

 TEST(TestImperative, Split) {
@@ -147,36 +148,36 @@ void run_graph(size_t mem_reserved, bool enable_defrag) {
 }

 TEST(TestImperative, Defragment) {
-    REQUIRE_GPU(1);
-    CompNode::load("gpux").activate();
-    size_t reserve;
-    {
-        size_t free, tot;
-        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
-        reserve = free * 0.92;
-    }
-    auto reserve_setting = ssprintf("b:%zu", reserve);
-    auto do_run = [reserve]() {
-        ASSERT_THROW(run_graph(reserve, false), MemAllocError);
-        run_graph(reserve, true);
-    };
-    // reserve memory explicitly to avoid uncontrollable factors
-    constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
-    auto old_value = getenv(KEY);
-    setenv(KEY, reserve_setting.c_str(), 1);
-    MGB_TRY {
-        do_run();
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-        CompNode::try_coalesce_all_free_memory();
-        CompNode::finalize();
-    );
+#if WIN32
+    //! FIXME: on Windows, CompNode::finalize is skipped because of the
+    //! Windows CUDA 10.1 DLL unload issue, so skip this test there
+    return;
+#endif
+    REQUIRE_GPU(1);
+    CompNode::load("gpux").activate();
+    size_t reserve;
+    {
+        size_t free, tot;
+        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
+        reserve = free * 0.92;
+    }
+    auto reserve_setting = ssprintf("b:%zu", reserve);
+    auto do_run = [reserve]() {
+        ASSERT_THROW(run_graph(reserve, false), MemAllocError);
+        run_graph(reserve, true);
+    };
+    // reserve memory explicitly to avoid uncontrollable factors
+    constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
+    auto old_value = getenv(KEY);
+    setenv(KEY, reserve_setting.c_str(), 1);
+    MGB_TRY { do_run(); }
+    MGB_FINALLY(
+        if (old_value) { setenv(KEY, old_value, 1); } else {
+            unsetenv(KEY);
+        } CompNode::try_coalesce_all_free_memory();
+        CompNode::finalize(););
 }
 #endif // MGB_CUDA && MGB_ENABLE_EXCEPTION




imperative/src/test/opr_utility.cpp (+1, -0)

@@ -89,6 +89,7 @@ TEST(TestOprUtility, NopCallback) {
 }

 TEST(TestOprUtility, NopCallbackMixedInput) {
+    REQUIRE_XPU(2);
     auto graph = ComputingGraph::make();
     auto x0 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Int32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu0")));
     auto x1 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Float32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu1")));


imperative/src/test/rng.cpp (+2, -0)

@@ -43,10 +43,12 @@ void check_rng_basic(Args&& ...args) {
 }

 TEST(TestImperative, UniformRNGBasic) {
+    REQUIRE_XPU(2);
     check_rng_basic<UniformRNG>(123);
 }

 TEST(TestImperative, GaussianRNGBasic) {
+    REQUIRE_XPU(2);
     check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }




test/CMakeLists.txt (+3, -0)

@@ -19,6 +19,9 @@ endif()
 add_executable(megbrain_test ${SOURCES})
 target_link_libraries(megbrain_test gtest gmock)
 target_link_libraries(megbrain_test megbrain megdnn ${MGE_CUDA_LIBS})
+if(MGE_WITH_CUDA)
+    target_include_directories(megbrain_test PRIVATE ${CUDNN_INCLUDE_DIR})
+endif()
 if(CXX_SUPPORT_WCLASS_MEMACCESS)
     if(MGE_WITH_CUDA)
         target_compile_options(megbrain_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-class-memaccess>"


test/src/helper.cpp (+8, -0)

@@ -337,6 +337,14 @@ std::vector<CompNode> mgb::load_multiple_xpus(size_t num) {
     return ret;
 }

+bool mgb::check_xpu_available(size_t num) {
+    if (CompNode::get_device_count(CompNode::DeviceType::UNSPEC) < num) {
+        mgb_log_warn("skip test case that requires %zu XPU(s)", num);
+        return false;
+    }
+    return true;
+}
+
 bool mgb::check_gpu_available(size_t num) {
     if (CompNode::get_device_count(CompNode::DeviceType::CUDA) < num) {
         mgb_log_warn("skip test case that requires %zu GPU(s)", num);


test/src/include/megbrain/test/helper.h (+9, -0)

@@ -492,6 +492,9 @@ std::vector<CompNode> load_multiple_xpus(size_t num);
 //! check whether given number of GPUs is available
 bool check_gpu_available(size_t num);

+//! check whether given number of XPUs is available
+bool check_xpu_available(size_t num);
+
 //! check whether given number of AMD GPUs is available
 bool check_amd_gpu_available(size_t num);

@@ -518,6 +521,12 @@ public:
     PersistentCacheHook(GetHook on_get);
     ~PersistentCacheHook();
 };
+//! skip a testcase if xpu not available
+#define REQUIRE_XPU(n) do { \
+    if (!check_xpu_available(n)) \
+        return; \
+} while(0)
+

 //! skip a testcase if gpu not available
 #define REQUIRE_GPU(n) do { \
