@@ -1150,12 +1150,18 @@ def copy(inp, device=None): | |||||
.. testcode:: | .. testcode:: | ||||
import numpy as np | import numpy as np | ||||
import platform | |||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.distributed.helper import get_device_count_by_fork | |||||
import megengine.functional as F | import megengine.functional as F | ||||
x = tensor([1, 2, 3], np.int32) | x = tensor([1, 2, 3], np.int32) | ||||
y = F.copy(x, "xpu1") | |||||
print(y.numpy()) | |||||
if 1 == get_device_count_by_fork("gpu"): | |||||
y = F.copy(x, "cpu1") | |||||
print(y.numpy()) | |||||
else: | |||||
y = F.copy(x, "xpu1") | |||||
print(y.numpy()) | |||||
Outputs: | Outputs: | ||||
@@ -7,3 +7,4 @@ tqdm | |||||
redispy | redispy | ||||
deprecated | deprecated | ||||
mprop | mprop | ||||
wheel |
@@ -1,24 +1,42 @@ | |||||
#!/bin/bash -e | #!/bin/bash -e | ||||
test_dirs="megengine test" | |||||
TEST_PLAT=$1 | TEST_PLAT=$1 | ||||
export MEGENGINE_LOGGING_LEVEL="ERROR" | |||||
if [[ "$TEST_PLAT" == cpu ]]; then | if [[ "$TEST_PLAT" == cpu ]]; then | ||||
echo "only test cpu pytest" | |||||
echo "test cpu after Ninja develop" | |||||
elif [[ "$TEST_PLAT" == cuda ]]; then | elif [[ "$TEST_PLAT" == cuda ]]; then | ||||
echo "test both cpu and gpu pytest" | |||||
echo "test cuda after Ninja develop" | |||||
elif [[ "$TEST_PLAT" == cpu_local ]]; then | |||||
echo "test cpu after python3 -m pip install xxx" | |||||
elif [[ "$TEST_PLAT" == cuda_local ]]; then | |||||
echo "test cuda after python3 -m pip install xxx" | |||||
else | else | ||||
echo "Argument must cpu or cuda" | |||||
echo "ERR args, support list:" | |||||
echo "$0 cpu (test cpu after Ninja develop)" | |||||
echo "$0 cuda (test cuda after Ninja develop)" | |||||
echo "$0 cpu_local (test cpu after python3 -m pip install xxx)" | |||||
echo "$0 cuda_local (test cuda after python3 -m pip install xxx)" | |||||
exit 1 | exit 1 | ||||
fi | fi | ||||
export MEGENGINE_LOGGING_LEVEL="ERROR" | |||||
pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null | |||||
if [[ "$TEST_PLAT" =~ "local" ]]; then | |||||
cd $(dirname "${BASH_SOURCE[0]}") | |||||
megengine_dir=`python3 -c 'from pathlib import Path;import megengine;print(Path(megengine.__file__).resolve().parent)'` | |||||
test_dirs="${megengine_dir} ." | |||||
echo "test local env at: ${test_dirs}" | |||||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' | |||||
if [[ "$TEST_PLAT" =~ "cuda" ]]; then | |||||
echo "test GPU pytest now" | |||||
PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' | |||||
fi | |||||
else | |||||
cd $(dirname "${BASH_SOURCE[0]}")/.. | |||||
test_dirs="megengine test" | |||||
echo "test develop env" | |||||
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' | PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' | ||||
if [[ "$TEST_PLAT" == cuda ]]; then | |||||
if [[ "$TEST_PLAT" =~ "cuda" ]]; then | |||||
echo "test GPU pytest now" | echo "test GPU pytest now" | ||||
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' | PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' | ||||
fi | fi | ||||
popd >/dev/null | |||||
fi |
@@ -7,6 +7,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import itertools | import itertools | ||||
import platform | |||||
from functools import partial | from functools import partial | ||||
import numpy as np | import numpy as np | ||||
@@ -7,9 +7,10 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import numpy as np | import numpy as np | ||||
import pytest | |||||
import megengine | import megengine | ||||
from megengine import tensor | |||||
from megengine import is_cuda_available, tensor | |||||
from megengine.core._imperative_rt import CompNode | from megengine.core._imperative_rt import CompNode | ||||
from megengine.core._imperative_rt.core2 import apply | from megengine.core._imperative_rt.core2 import apply | ||||
from megengine.core._imperative_rt.ops import ( | from megengine.core._imperative_rt.ops import ( | ||||
@@ -18,10 +19,14 @@ from megengine.core._imperative_rt.ops import ( | |||||
new_rng_handle, | new_rng_handle, | ||||
) | ) | ||||
from megengine.core.ops.builtin import GaussianRNG, UniformRNG | from megengine.core.ops.builtin import GaussianRNG, UniformRNG | ||||
from megengine.distributed.helper import get_device_count_by_fork | |||||
from megengine.random import RNG | from megengine.random import RNG | ||||
from megengine.random.rng import _normal, _uniform | from megengine.random.rng import _normal, _uniform | ||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||||
) | |||||
def test_gaussian_op(): | def test_gaussian_op(): | ||||
shape = ( | shape = ( | ||||
8, | 8, | ||||
@@ -47,6 +52,9 @@ def test_gaussian_op(): | |||||
assert str(output.device) == str(cn) | assert str(output.device) == str(cn) | ||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||||
) | |||||
def test_uniform_op(): | def test_uniform_op(): | ||||
shape = ( | shape = ( | ||||
8, | 8, | ||||
@@ -70,6 +78,9 @@ def test_uniform_op(): | |||||
assert str(output.device) == str(cn) | assert str(output.device) == str(cn) | ||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||||
) | |||||
def test_UniformRNG(): | def test_UniformRNG(): | ||||
m1 = RNG(seed=111, device="xpu0") | m1 = RNG(seed=111, device="xpu0") | ||||
m2 = RNG(seed=111, device="xpu1") | m2 = RNG(seed=111, device="xpu1") | ||||
@@ -95,6 +106,9 @@ def test_UniformRNG(): | |||||
assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1 | assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1 | ||||
@pytest.mark.skipif( | |||||
get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||||
) | |||||
def test_NormalRNG(): | def test_NormalRNG(): | ||||
m1 = RNG(seed=111, device="xpu0") | m1 = RNG(seed=111, device="xpu0") | ||||
m2 = RNG(seed=111, device="xpu1") | m2 = RNG(seed=111, device="xpu1") | ||||
@@ -77,11 +77,12 @@ TEST(TestImperative, BatchNorm) { | |||||
} | } | ||||
TEST(TestImperative, Concat) { | TEST(TestImperative, Concat) { | ||||
OprAttr::Param param; | |||||
param.write_pod(megdnn::param::Axis(0)); | |||||
OperatorNodeConfig config{CompNode::load("xpu1")}; | |||||
OprChecker(OprAttr::make("Concat", param, config)) | |||||
.run({TensorShape{200, 300}, TensorShape{300, 300}}); | |||||
REQUIRE_XPU(2); | |||||
OprAttr::Param param; | |||||
param.write_pod(megdnn::param::Axis(0)); | |||||
OperatorNodeConfig config{CompNode::load("xpu1")}; | |||||
OprChecker(OprAttr::make("Concat", param, config)) | |||||
.run({TensorShape{200, 300}, TensorShape{300, 300}}); | |||||
} | } | ||||
TEST(TestImperative, Split) { | TEST(TestImperative, Split) { | ||||
@@ -147,36 +148,36 @@ void run_graph(size_t mem_reserved, bool enable_defrag) { | |||||
} | } | ||||
TEST(TestImperative, Defragment) { | TEST(TestImperative, Defragment) { | ||||
REQUIRE_GPU(1); | |||||
CompNode::load("gpux").activate(); | |||||
size_t reserve; | |||||
{ | |||||
size_t free, tot; | |||||
MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot)); | |||||
reserve = free * 0.92; | |||||
} | |||||
auto reserve_setting = ssprintf("b:%zu", reserve); | |||||
auto do_run = [reserve]() { | |||||
ASSERT_THROW(run_graph(reserve, false), MemAllocError); | |||||
run_graph(reserve, true); | |||||
}; | |||||
// reserve memory explicitly to avoid uncontrollable factors | |||||
constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY"; | |||||
auto old_value = getenv(KEY); | |||||
setenv(KEY, reserve_setting.c_str(), 1); | |||||
MGB_TRY { | |||||
do_run(); | |||||
} MGB_FINALLY( | |||||
if (old_value) { | |||||
setenv(KEY, old_value, 1); | |||||
} else { | |||||
unsetenv(KEY); | |||||
} | |||||
CompNode::try_coalesce_all_free_memory(); | |||||
CompNode::finalize(); | |||||
); | |||||
#if WIN32 | |||||
//! FIXME, finalize on CUDA windows will be strip as windows CUDA101 DLL | |||||
//! issue | |||||
return; | |||||
#endif | |||||
REQUIRE_GPU(1); | |||||
CompNode::load("gpux").activate(); | |||||
size_t reserve; | |||||
{ | |||||
size_t free, tot; | |||||
MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot)); | |||||
reserve = free * 0.92; | |||||
} | |||||
auto reserve_setting = ssprintf("b:%zu", reserve); | |||||
auto do_run = [reserve]() { | |||||
ASSERT_THROW(run_graph(reserve, false), MemAllocError); | |||||
run_graph(reserve, true); | |||||
}; | |||||
// reserve memory explicitly to avoid uncontrollable factors | |||||
constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY"; | |||||
auto old_value = getenv(KEY); | |||||
setenv(KEY, reserve_setting.c_str(), 1); | |||||
MGB_TRY { do_run(); } | |||||
MGB_FINALLY( | |||||
if (old_value) { setenv(KEY, old_value, 1); } else { | |||||
unsetenv(KEY); | |||||
} CompNode::try_coalesce_all_free_memory(); | |||||
CompNode::finalize();); | |||||
} | } | ||||
#endif // MGB_CUDA && MGB_ENABLE_EXCEPTION | #endif // MGB_CUDA && MGB_ENABLE_EXCEPTION | ||||
@@ -89,6 +89,7 @@ TEST(TestOprUtility, NopCallback) { | |||||
} | } | ||||
TEST(TestOprUtility, NopCallbackMixedInput) { | TEST(TestOprUtility, NopCallbackMixedInput) { | ||||
REQUIRE_XPU(2); | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x0 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Int32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu0"))); | auto x0 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Int32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu0"))); | ||||
auto x1 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Float32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu1"))); | auto x1 = opr::Host2DeviceCopy::make(*graph, HostTensorGenerator<dtype::Float32>()({2, 3}), OperatorNodeConfig(CompNode::load("xpu1"))); | ||||
@@ -43,10 +43,12 @@ void check_rng_basic(Args&& ...args) { | |||||
} | } | ||||
TEST(TestImperative, UniformRNGBasic) { | TEST(TestImperative, UniformRNGBasic) { | ||||
REQUIRE_XPU(2); | |||||
check_rng_basic<UniformRNG>(123); | check_rng_basic<UniformRNG>(123); | ||||
} | } | ||||
TEST(TestImperative, GaussianRNGBasic) { | TEST(TestImperative, GaussianRNGBasic) { | ||||
REQUIRE_XPU(2); | |||||
check_rng_basic<GaussianRNG>(123, 2.f, 3.f); | check_rng_basic<GaussianRNG>(123, 2.f, 3.f); | ||||
} | } | ||||
@@ -19,6 +19,9 @@ endif() | |||||
add_executable(megbrain_test ${SOURCES}) | add_executable(megbrain_test ${SOURCES}) | ||||
target_link_libraries(megbrain_test gtest gmock) | target_link_libraries(megbrain_test gtest gmock) | ||||
target_link_libraries(megbrain_test megbrain megdnn ${MGE_CUDA_LIBS}) | target_link_libraries(megbrain_test megbrain megdnn ${MGE_CUDA_LIBS}) | ||||
if (MGE_WITH_CUDA) | |||||
target_include_directories(megbrain_test PRIVATE ${CUDNN_INCLUDE_DIR}) | |||||
endif() | |||||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | if(CXX_SUPPORT_WCLASS_MEMACCESS) | ||||
if(MGE_WITH_CUDA) | if(MGE_WITH_CUDA) | ||||
target_compile_options(megbrain_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-class-memaccess>" | target_compile_options(megbrain_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-class-memaccess>" | ||||
@@ -337,6 +337,14 @@ std::vector<CompNode> mgb::load_multiple_xpus(size_t num) { | |||||
return ret; | return ret; | ||||
} | } | ||||
bool mgb::check_xpu_available(size_t num) { | |||||
if (CompNode::get_device_count(CompNode::DeviceType::UNSPEC) < num) { | |||||
mgb_log_warn("skip test case that requires %zu XPU(s)", num); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
bool mgb::check_gpu_available(size_t num) { | bool mgb::check_gpu_available(size_t num) { | ||||
if (CompNode::get_device_count(CompNode::DeviceType::CUDA) < num) { | if (CompNode::get_device_count(CompNode::DeviceType::CUDA) < num) { | ||||
mgb_log_warn("skip test case that requires %zu GPU(s)", num); | mgb_log_warn("skip test case that requires %zu GPU(s)", num); | ||||
@@ -492,6 +492,9 @@ std::vector<CompNode> load_multiple_xpus(size_t num); | |||||
//! check whether given number of GPUs is available | //! check whether given number of GPUs is available | ||||
bool check_gpu_available(size_t num); | bool check_gpu_available(size_t num); | ||||
//! check whether given number of XPUs is available | |||||
bool check_xpu_available(size_t num); | |||||
//! check whether given number of AMD GPUs is available | //! check whether given number of AMD GPUs is available | ||||
bool check_amd_gpu_available(size_t num); | bool check_amd_gpu_available(size_t num); | ||||
@@ -518,6 +521,12 @@ public: | |||||
PersistentCacheHook(GetHook on_get); | PersistentCacheHook(GetHook on_get); | ||||
~PersistentCacheHook(); | ~PersistentCacheHook(); | ||||
}; | }; | ||||
//! skip a testcase if xpu not available | |||||
#define REQUIRE_XPU(n) do { \ | |||||
if (!check_xpu_available(n)) \ | |||||
return; \ | |||||
} while(0) | |||||
//! skip a testcase if gpu not available | //! skip a testcase if gpu not available | ||||
#define REQUIRE_GPU(n) do { \ | #define REQUIRE_GPU(n) do { \ | ||||