@@ -1150,12 +1150,18 @@ def copy(inp, device=None):
 .. testcode::

     import numpy as np
+    import platform
     from megengine import tensor
+    from megengine.distributed.helper import get_device_count_by_fork
     import megengine.functional as F

     x = tensor([1, 2, 3], np.int32)
-    y = F.copy(x, "xpu1")
-    print(y.numpy())
+    if 1 == get_device_count_by_fork("gpu"):
+        y = F.copy(x, "cpu1")
+        print(y.numpy())
+    else:
+        y = F.copy(x, "xpu1")
+        print(y.numpy())

 Outputs:
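
Review note: both branches of the new example print the same values, since `copy` changes only the device and never the contents, so the `Outputs:` block that follows stays valid either way. A minimal sketch of that invariant, assuming only that a `"cpu0"` device is always present; the snippet is illustrative, not part of the patch:

```python
import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor([1, 2, 3], np.int32)
y = F.copy(x, "cpu0")  # copy to another device; values are preserved
np.testing.assert_array_equal(y.numpy(), np.array([1, 2, 3], dtype=np.int32))
```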
@@ -7,3 +7,4 @@ tqdm
 redispy
 deprecated
 mprop
+wheel
@@ -1,24 +1,42 @@ | |||
#!/bin/bash -e | |||
test_dirs="megengine test" | |||
TEST_PLAT=$1 | |||
export MEGENGINE_LOGGING_LEVEL="ERROR" | |||
if [[ "$TEST_PLAT" == cpu ]]; then | |||
echo "only test cpu pytest" | |||
echo "test cpu after Ninja develop" | |||
elif [[ "$TEST_PLAT" == cuda ]]; then | |||
echo "test both cpu and gpu pytest" | |||
echo "test cuda after Ninja develop" | |||
elif [[ "$TEST_PLAT" == cpu_local ]]; then | |||
echo "test cpu after python3 -m pip install xxx" | |||
elif [[ "$TEST_PLAT" == cuda_local ]]; then | |||
echo "test cuda after python3 -m pip install xxx" | |||
else | |||
echo "Argument must cpu or cuda" | |||
echo "ERR args, support list:" | |||
echo "$0 cpu (test cpu after Ninja develop)" | |||
echo "$0 cuda (test cuda after Ninja develop)" | |||
echo "$0 cpu_local (test cpu after python3 -m pip install xxx)" | |||
echo "$0 cuda_local (test cuda after python3 -m pip install xxx)" | |||
exit 1 | |||
fi | |||
export MEGENGINE_LOGGING_LEVEL="ERROR" | |||
pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null | |||
if [[ "$TEST_PLAT" =~ "local" ]]; then | |||
cd $(dirname "${BASH_SOURCE[0]}") | |||
megengine_dir=`python3 -c 'from pathlib import Path;import megengine;print(Path(megengine.__file__).resolve().parent)'` | |||
test_dirs="${megengine_dir} ." | |||
echo "test local env at: ${test_dirs}" | |||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' | |||
if [[ "$TEST_PLAT" =~ "cuda" ]]; then | |||
echo "test GPU pytest now" | |||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' | |||
fi | |||
else | |||
cd $(dirname "${BASH_SOURCE[0]}")/.. | |||
test_dirs="megengine test" | |||
echo "test develop env" | |||
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' | |||
if [[ "$TEST_PLAT" == cuda ]]; then | |||
if [[ "$TEST_PLAT" =~ "cuda" ]]; then | |||
echo "test GPU pytest now" | |||
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' | |||
fi | |||
popd >/dev/null | |||
fi |
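
Review note: the two pytest invocations split the suite on the `isolated_distributed` marker; distributed tests need exclusive access to the devices, so they run in a second pass. A sketch of how a test opts into that pass (the marker name is the one selected by `-m`; the test body is a hypothetical placeholder):

```python
import pytest

@pytest.mark.isolated_distributed  # picked up only by the second pytest run
def test_two_gpu_allreduce():
    # hypothetical body: spawns worker processes, so it must not share
    # devices with the ordinary single-process tests
    ...
```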
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import itertools
+import platform
 from functools import partial

 import numpy as np
@@ -7,9 +7,10 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest

 import megengine
-from megengine import tensor
+from megengine import is_cuda_available, tensor
 from megengine.core._imperative_rt import CompNode
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._imperative_rt.ops import (
@@ -18,10 +19,14 @@ from megengine.core._imperative_rt.ops import (
     new_rng_handle,
 )
 from megengine.core.ops.builtin import GaussianRNG, UniformRNG
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.random import RNG
 from megengine.random.rng import _normal, _uniform


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+)
 def test_gaussian_op():
     shape = (
         8,
@@ -47,6 +52,9 @@ def test_gaussian_op():
     assert str(output.device) == str(cn)


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+)
 def test_uniform_op():
     shape = (
         8,
@@ -70,6 +78,9 @@ def test_uniform_op():
     assert str(output.device) == str(cn)


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+)
 def test_UniformRNG():
     m1 = RNG(seed=111, device="xpu0")
     m2 = RNG(seed=111, device="xpu1")
@@ -95,6 +106,9 @@ def test_UniformRNG():
     assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1


+@pytest.mark.skipif(
+    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+)
 def test_NormalRNG():
     m1 = RNG(seed=111, device="xpu0")
     m2 = RNG(seed=111, device="xpu1")
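
Review note: the same `skipif` guard is now repeated before four tests. If the duplication grows, it could be hoisted into a shared marker; a sketch assuming `get_device_count_by_fork` keeps its current signature, with `test_example` as a hypothetical stand-in:

```python
import pytest
from megengine.distributed.helper import get_device_count_by_fork

# evaluated once at import time and reused by every cross-device test
require_two_xpus = pytest.mark.skipif(
    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
)

@require_two_xpus
def test_example():
    ...  # would exercise both xpu0 and xpu1
```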
@@ -77,11 +77,12 @@ TEST(TestImperative, BatchNorm) {
 }

 TEST(TestImperative, Concat) {
-    OprAttr::Param param;
-    param.write_pod(megdnn::param::Axis(0));
-    OperatorNodeConfig config{CompNode::load("xpu1")};
-    OprChecker(OprAttr::make("Concat", param, config))
-            .run({TensorShape{200, 300}, TensorShape{300, 300}});
+    REQUIRE_XPU(2);
+    OprAttr::Param param;
+    param.write_pod(megdnn::param::Axis(0));
+    OperatorNodeConfig config{CompNode::load("xpu1")};
+    OprChecker(OprAttr::make("Concat", param, config))
+            .run({TensorShape{200, 300}, TensorShape{300, 300}});
 }

 TEST(TestImperative, Split) {
@@ -147,36 +148,36 @@ void run_graph(size_t mem_reserved, bool enable_defrag) {
 }

 TEST(TestImperative, Defragment) {
-    REQUIRE_GPU(1);
-    CompNode::load("gpux").activate();
-    size_t reserve;
-    {
-        size_t free, tot;
-        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
-        reserve = free * 0.92;
-    }
-    auto reserve_setting = ssprintf("b:%zu", reserve);
-
-    auto do_run = [reserve]() {
-        ASSERT_THROW(run_graph(reserve, false), MemAllocError);
-        run_graph(reserve, true);
-    };
-
-    // reserve memory explicitly to avoid uncontrollable factors
-    constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
-    auto old_value = getenv(KEY);
-    setenv(KEY, reserve_setting.c_str(), 1);
-    MGB_TRY {
-        do_run();
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-        CompNode::try_coalesce_all_free_memory();
-        CompNode::finalize();
-    );
+#if WIN32
+    //! FIXME, finalize on CUDA windows will be strip as windows CUDA101 DLL
+    //! issue
+    return;
+#endif
+    REQUIRE_GPU(1);
+    CompNode::load("gpux").activate();
+    size_t reserve;
+    {
+        size_t free, tot;
+        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
+        reserve = free * 0.92;
+    }
+    auto reserve_setting = ssprintf("b:%zu", reserve);
+
+    auto do_run = [reserve]() {
+        ASSERT_THROW(run_graph(reserve, false), MemAllocError);
+        run_graph(reserve, true);
+    };
+
+    // reserve memory explicitly to avoid uncontrollable factors
+    constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
+    auto old_value = getenv(KEY);
+    setenv(KEY, reserve_setting.c_str(), 1);
+    MGB_TRY { do_run(); }
+    MGB_FINALLY(
+            if (old_value) { setenv(KEY, old_value, 1); } else {
+                unsetenv(KEY);
+            } CompNode::try_coalesce_all_free_memory();
+            CompNode::finalize(););
 }
 #endif  // MGB_CUDA && MGB_ENABLE_EXCEPTION
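
Review note: the test overrides `MGB_CUDA_RESERVE_MEMORY` and restores it even when `do_run` throws; `MGB_TRY`/`MGB_FINALLY` play the role of try/finally here. The same save/override/restore pattern, sketched in Python for readers unfamiliar with those macros (the helper name `run_with_env` is hypothetical):

```python
import os

def run_with_env(key, value, fn):
    old = os.environ.get(key)    # remember the previous setting
    os.environ[key] = value      # apply the temporary override
    try:
        fn()
    finally:
        if old is not None:
            os.environ[key] = old      # restore what was there before,
        else:
            os.environ.pop(key, None)  # or remove the key if it was unset
```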
@@ -89,6 +89,7 @@ TEST(TestOprUtility, NopCallback) {
 }

 TEST(TestOprUtility, NopCallbackMixedInput) {
+    REQUIRE_XPU(2);
     auto graph = ComputingGraph::make();
     auto x0 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Int32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu0")));
     auto x1 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Float32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu1")));
@@ -43,10 +43,12 @@ void check_rng_basic(Args&& ...args) {
 }

 TEST(TestImperative, UniformRNGBasic) {
+    REQUIRE_XPU(2);
     check_rng_basic<UniformRNG>(123);
 }

 TEST(TestImperative, GaussianRNGBasic) {
+    REQUIRE_XPU(2);
     check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }
@@ -19,6 +19,9 @@ endif()
 add_executable(megbrain_test ${SOURCES})
 target_link_libraries(megbrain_test gtest gmock)
 target_link_libraries(megbrain_test megbrain megdnn ${MGE_CUDA_LIBS})
+if(MGE_WITH_CUDA)
+    target_include_directories(megbrain_test PRIVATE ${CUDNN_INCLUDE_DIR})
+endif()
 if(CXX_SUPPORT_WCLASS_MEMACCESS)
     if(MGE_WITH_CUDA)
         target_compile_options(megbrain_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-class-memaccess>"
@@ -337,6 +337,14 @@ std::vector<CompNode> mgb::load_multiple_xpus(size_t num) {
     return ret;
 }

+bool mgb::check_xpu_available(size_t num) {
+    if (CompNode::get_device_count(CompNode::DeviceType::UNSPEC) < num) {
+        mgb_log_warn("skip test case that requires %zu XPU(s)", num);
+        return false;
+    }
+    return true;
+}
+
 bool mgb::check_gpu_available(size_t num) {
     if (CompNode::get_device_count(CompNode::DeviceType::CUDA) < num) {
         mgb_log_warn("skip test case that requires %zu GPU(s)", num);
@@ -492,6 +492,9 @@
 //! check whether given number of GPUs is available
 bool check_gpu_available(size_t num);

+//! check whether given number of XPUs is available
+bool check_xpu_available(size_t num);
+
 //! check whether given number of AMD GPUs is available
 bool check_amd_gpu_available(size_t num);
@@ -518,6 +521,12 @@ public:
     PersistentCacheHook(GetHook on_get);
     ~PersistentCacheHook();
 };

+//! skip a testcase if xpu not available
+#define REQUIRE_XPU(n) do { \
+    if (!check_xpu_available(n)) \
+        return; \
+} while(0)
+
 //! skip a testcase if gpu not available
 #define REQUIRE_GPU(n) do { \