|
|
@@ -14,6 +14,7 @@ from megengine.core._imperative_rt.core2 import apply |
|
|
|
from megengine.core._wrap import Device |
|
|
|
from megengine.core.ops import builtin |
|
|
|
from megengine.device import is_cuda_available |
|
|
|
from megengine.distributed.helper import get_device_count_by_fork |
|
|
|
from megengine.functional.external import tensorrt_runtime_opr |
|
|
|
from megengine.jit.tracing import trace |
|
|
|
from megengine.tensor import Tensor |
|
|
@@ -265,6 +266,10 @@ def test_deformable_ps_roi_pooling(): |
|
|
|
check_pygraph_dump(fwd, [inp, rois, trans], [result]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif( |
|
|
|
get_device_count_by_fork("gpu") > 0, |
|
|
|
reason="does not support int8 when gpu compute capability less than 6.1", |
|
|
|
) |
|
|
|
def test_convbias(): |
|
|
|
@trace(symbolic=True, capture_as_const=True) |
|
|
|
def fwd(inp, weight, bias): |
|
|
|