|
|
@@ -14,6 +14,10 @@ from megengine.core._imperative_rt.core2 import apply |
|
|
|
from megengine.core._wrap import Device |
|
|
|
from megengine.core.ops import builtin |
|
|
|
from megengine.device import get_device_count, is_cuda_available |
|
|
|
from megengine.functional.debug_param import ( |
|
|
|
get_execution_strategy, |
|
|
|
set_execution_strategy, |
|
|
|
) |
|
|
|
from megengine.functional.external import tensorrt_runtime_opr |
|
|
|
from megengine.jit.tracing import trace |
|
|
|
from megengine.tensor import Tensor |
|
|
@@ -106,10 +110,13 @@ def test_matmul(): |
|
|
|
def fwd(data1, data2): |
|
|
|
return F.matmul(data1, data2) |
|
|
|
|
|
|
|
old = get_execution_strategy() |
|
|
|
set_execution_strategy("HEURISTIC_REPRODUCIBLE") |
|
|
|
data1 = Tensor(np.random.random((32, 64))) |
|
|
|
data2 = Tensor(np.random.random((64, 16))) |
|
|
|
result = fwd(data1, data2) |
|
|
|
check_pygraph_dump(fwd, [data1, data2], [result]) |
|
|
|
set_execution_strategy(old) |
|
|
|
|
|
|
|
|
|
|
|
def test_batchmatmul(): |
|
|
|