@@ -36,7 +36,7 @@ public: | |||
virtual void exec( | |||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType& C); | |||
MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
@@ -73,7 +73,7 @@ public: | |||
virtual void exec( | |||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType& C); | |||
MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
@@ -44,216 +44,6 @@ def _elwise(*args, mode): | |||
return _elwise_apply(args, mode) | |||
@lru_cache(maxsize=None) | |||
def _get_extentedMatrixMulOp( | |||
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, | |||
): | |||
@subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2) | |||
def extentedMatrixMulOp(inputs, f, c): | |||
assert len(inputs) == 2 | |||
inp1, inp2 = inputs | |||
_dim1, _dim2 = dim1, dim2 | |||
def build_shape_head(shape, idx=-1): | |||
# shape[:idx] | |||
return f( | |||
builtin.Subtensor(items=[[0, False, True, False, False]]), | |||
shape, | |||
c(idx, "int32"), | |||
) | |||
def build_shape_tail(shape, idx=-1): | |||
# shape[idx:] | |||
return f( | |||
builtin.Subtensor(items=[[0, True, False, False, False]]), | |||
shape, | |||
c(idx, "int32"), | |||
) | |||
remove_row, remove_col = False, False | |||
if _dim1 == 1: | |||
_dim1 = 2 | |||
remove_row = True | |||
if _dim2 == 1: | |||
_dim2 = 2 | |||
remove_col = True | |||
if remove_row: | |||
inp1 = f(builtin.AddAxis(axis=[0,]), inp1) | |||
if remove_col: | |||
inp2 = f(builtin.AddAxis(axis=[1,]), inp2) | |||
shape1 = f(builtin.GetVarShape(), inp1) | |||
shape2 = f(builtin.GetVarShape(), inp2) | |||
if _dim1 > 2: | |||
inp1 = f( | |||
builtin.Reshape(), | |||
inp1, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)), | |||
build_shape_tail(shape1), | |||
), | |||
) | |||
if _dim2 > 2: | |||
inp2 = f( | |||
builtin.Reshape(), | |||
inp2, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)), | |||
build_shape_tail(shape2), | |||
), | |||
) | |||
op = builtin.MatrixMul( | |||
transposeA=transpose_a, | |||
transposeB=transpose_b, | |||
compute_mode=compute_mode, | |||
format=format, | |||
strategy=strategy.value, | |||
) | |||
result = f(op, inp1, inp2) | |||
result_shape = f(builtin.GetVarShape(), result) | |||
if _dim1 > 2: | |||
result = f( | |||
builtin.Reshape(), | |||
result, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
build_shape_head(shape1), | |||
build_shape_tail(result_shape), | |||
), | |||
) | |||
if _dim2 > 2: | |||
result = f( | |||
builtin.Reshape(), | |||
result, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
build_shape_head(shape2), | |||
build_shape_tail(result_shape), | |||
), | |||
) | |||
maxdim = _dim1 if _dim1 > _dim2 else _dim2 | |||
if remove_row: | |||
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) | |||
if remove_col: | |||
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) | |||
return (result,), (True,) | |||
return extentedMatrixMulOp | |||
@lru_cache(maxsize=None) | |||
def _get_extentedBatchedMatrixMulOp( | |||
device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, | |||
): | |||
@subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2) | |||
def extentedBatchedMatrixMulOp(inputs, f, c): | |||
assert len(inputs) == 2 | |||
inp1, inp2 = inputs | |||
_dim1, _dim2 = dim1, dim2 | |||
def build_shape_head(shape, idx=-2): | |||
# shape[:idx] | |||
return f( | |||
builtin.Subtensor(items=[[0, False, True, False, False]]), | |||
shape, | |||
c(idx, "int32"), | |||
) | |||
def build_shape_tail(shape, idx=-2): | |||
# shape[idx:] | |||
return f( | |||
builtin.Subtensor(items=[[0, True, False, False, False]]), | |||
shape, | |||
c(idx, "int32"), | |||
) | |||
remove_row, remove_col = False, False | |||
if _dim1 == 1: | |||
_dim1 = 2 | |||
remove_row = True | |||
if _dim2 == 1: | |||
_dim2 = 2 | |||
remove_col = True | |||
if remove_row: | |||
inp1 = f(builtin.AddAxis(axis=[0,]), inp1) | |||
if remove_col: | |||
inp2 = f(builtin.AddAxis(axis=[1,]), inp2) | |||
shape1 = f(builtin.GetVarShape(), inp1) | |||
shape2 = f(builtin.GetVarShape(), inp2) | |||
maxdim = _dim1 if _dim1 > _dim2 else _dim2 | |||
if _dim1 > _dim2: | |||
# broadcast | |||
shape2 = f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2] | |||
shape2, | |||
) | |||
inp2 = f(builtin.Broadcast(), inp2, shape2) | |||
batch_shape = build_shape_head(shape1) | |||
if _dim2 > _dim1: | |||
# broadcast | |||
shape1 = f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1] | |||
shape1, | |||
) | |||
inp1 = f(builtin.Broadcast(), inp1, shape1) | |||
batch_shape = build_shape_head(shape2) | |||
if _dim1 == _dim2: | |||
batch_shape = build_shape_head(shape1) | |||
# compress inputs to 3d | |||
if maxdim > 3: | |||
inp1 = f( | |||
builtin.Reshape(), | |||
inp1, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
f(builtin.Reduce(mode="product", axis=0), batch_shape), | |||
build_shape_tail(shape1), | |||
), | |||
) | |||
inp2 = f( | |||
builtin.Reshape(), | |||
inp2, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
f(builtin.Reduce(mode="product", axis=0), batch_shape), | |||
build_shape_tail(shape2), | |||
), | |||
) | |||
op = builtin.BatchedMatrixMul( | |||
transposeA=transpose_a, | |||
transposeB=transpose_b, | |||
compute_mode=compute_mode, | |||
format=format, | |||
strategy=strategy.value, | |||
) | |||
result = f(op, inp1, inp2) | |||
if maxdim > 3: | |||
result = f( | |||
builtin.Reshape(), | |||
result, | |||
f( | |||
builtin.Concat(axis=0, comp_node=device), | |||
batch_shape, | |||
build_shape_tail(f(builtin.GetVarShape(), result)), | |||
), | |||
) | |||
if remove_row: | |||
result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) | |||
if remove_col: | |||
result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) | |||
return (result,), (True,) | |||
return extentedBatchedMatrixMulOp | |||
class _Hashable: | |||
def __init__(self, value) -> None: | |||
self.value = value | |||
@@ -267,42 +57,6 @@ class _Hashable: | |||
return self.value == o.value | |||
def symbolicMatrixMul( | |||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy | |||
): | |||
extentedMatrixMulOp = _get_extentedMatrixMulOp( | |||
inp1.device, | |||
inp1.dtype, | |||
dim1, | |||
dim2, | |||
transpose_a, | |||
transpose_b, | |||
compute_mode, | |||
format, | |||
strategy=_Hashable(strategy), | |||
) | |||
(result,) = apply(extentedMatrixMulOp(), inp1, inp2) | |||
return result | |||
def symbolicBatchedMatrixMul( | |||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy | |||
): | |||
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( | |||
inp1.device, | |||
inp1.dtype, | |||
dim1, | |||
dim2, | |||
transpose_a, | |||
transpose_b, | |||
compute_mode, | |||
format, | |||
strategy=_Hashable(strategy), | |||
) | |||
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) | |||
return result | |||
def _matmul( | |||
inp1, | |||
inp2, | |||
@@ -342,11 +96,8 @@ def _matmul( | |||
transpose_a, | |||
transpose_b, | |||
compute_mode, | |||
format, | |||
_config._benchmark_kernel, | |||
_config._deterministic_kernel, | |||
strategy, | |||
symbolicMatrixMul, | |||
) | |||
else: # dispath to BatchedMatrixMul | |||
# nx1(transpose_a=True), n>=3 | |||
@@ -362,11 +113,8 @@ def _matmul( | |||
transpose_a, | |||
transpose_b, | |||
compute_mode, | |||
format, | |||
_config._benchmark_kernel, | |||
_config._deterministic_kernel, | |||
strategy, | |||
symbolicBatchedMatrixMul, | |||
) | |||
@@ -32,7 +32,7 @@ from ..core.ops.builtin import ( | |||
TypeCvt, | |||
) | |||
from ..core.tensor import amp, megbrain_graph | |||
from ..core.tensor.array_method import _elwise_apply | |||
from ..core.tensor.array_method import _matmul | |||
from ..core.tensor.utils import ( | |||
astensor1d, | |||
cast_tensors, | |||
@@ -49,7 +49,7 @@ from ..utils.deprecation import deprecated_func | |||
from .debug_param import get_execution_strategy | |||
from .distributed import all_reduce_sum | |||
from .elemwise import _elwise, exp, log, log1p, maximum, minimum | |||
from .math import matmul, max, sum | |||
from .math import max, sum | |||
from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros | |||
__all__ = [ | |||
@@ -127,7 +127,7 @@ def linear( | |||
bias: bias with shape `(out_features,)`. Default: None | |||
""" | |||
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode) | |||
ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode) | |||
if bias is not None: | |||
if amp._enabled: | |||
bias = bias.astype("float16") | |||
@@ -1494,73 +1494,61 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||
py::object _matmul_cpp( | |||
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||
py::handle format, py::handle profile, py::handle determistic, | |||
py::handle strategy, py::handle func) { | |||
if (enable_fastpath(inp1)) { | |||
::megdnn::param::MatrixMul::ComputeMode mode = | |||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||
} | |||
::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||
if (profile.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||
} else { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||
} | |||
if (determistic.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||
} | |||
std::shared_ptr<OpDef> op = MatrixMul::make( | |||
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||
py::object Op = py::cast(op); | |||
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||
return ret[0]; | |||
py::handle profile, py::handle determistic) { | |||
::megdnn::param::MatrixMul::ComputeMode mode = | |||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||
} | |||
::megdnn::param::ExecutionPolicy::Strategy cstrategy = | |||
static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0); | |||
if (profile.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||
} else { | |||
// fallback to traceable implementation | |||
return func( | |||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||
strategy); | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||
} | |||
if (determistic.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||
} | |||
std::shared_ptr<OpDef> op = MatrixMul::make( | |||
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX, | |||
dim1.cast<uint32_t>(), dim2.cast<uint32_t>()); | |||
py::object Op = py::cast(op); | |||
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||
return ret[0]; | |||
} | |||
py::object _batched_matmul_cpp( | |||
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||
py::handle format, py::handle profile, py::handle determistic, | |||
py::handle strategy, py::handle func) { | |||
if (enable_fastpath(inp1)) { | |||
::megdnn::param::MatrixMul::ComputeMode mode = | |||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||
} | |||
::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||
if (profile.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||
} else { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||
} | |||
if (determistic.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||
} | |||
std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||
py::object Op = py::cast(op); | |||
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||
return ret[0]; | |||
py::handle profile, py::handle determistic) { | |||
::megdnn::param::MatrixMul::ComputeMode mode = | |||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||
} | |||
::megdnn::param::ExecutionPolicy::Strategy cstrategy = | |||
static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0); | |||
if (profile.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||
} else { | |||
// fallback to traceable implementation | |||
return func( | |||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||
strategy); | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||
} | |||
if (determistic.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||
} | |||
std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX, | |||
dim1.cast<uint32_t>(), dim2.cast<uint32_t>()); | |||
py::object Op = py::cast(op); | |||
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||
return ret[0]; | |||
} | |||
py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) { | |||
@@ -1671,7 +1659,7 @@ PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||
try { | |||
return _matmul_cpp( | |||
args[0], args[1], args[2], args[3], args[4], args[5], args[6], | |||
args[7], args[8], args[9], args[10], args[11]) | |||
args[7], args[8]) | |||
.release() | |||
.ptr(); | |||
} | |||
@@ -1682,7 +1670,7 @@ PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs | |||
try { | |||
return _batched_matmul_cpp( | |||
args[0], args[1], args[2], args[3], args[4], args[5], args[6], | |||
args[7], args[8], args[9], args[10], args[11]) | |||
args[7], args[8]) | |||
.release() | |||
.ptr(); | |||
} | |||
@@ -20,7 +20,6 @@ import megengine.optimizer as optim | |||
from megengine import tensor | |||
from megengine.autodiff import GradManager | |||
from megengine.jit import trace | |||
from megengine.traced_module import trace_module | |||
@contextlib.contextmanager | |||
@@ -2,8 +2,12 @@ | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
#include "megbrain/graph/symbol_var.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "../algo_chooser.h" | |||
@@ -12,12 +16,93 @@ namespace imperative { | |||
namespace { | |||
namespace matrix_mul { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{matmul.make_name()}; | |||
return opr::MatrixMul::make( | |||
inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||
auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; | |||
auto dim1 = matmul.dimA, dim2 = matmul.dimB; | |||
auto cn = inputs[0]->comp_node(); | |||
using Desc = opr::AxisAddRemove::AxisDesc; | |||
using IndexDesc = opr::Subtensor::IndexDesc; | |||
OperatorNodeConfig config{matmul.make_name(), cn}; | |||
DTypeScalar vi{-1}; | |||
auto graph = inputs[0]->owner_graph(); | |||
bool remove_row = false, remove_col = false; | |||
if (dim1 == 1) { | |||
dim1 = 2; | |||
remove_row = true; | |||
inp1 = inp1.add_axis(0); | |||
} | |||
if (dim2 == 1) { | |||
dim2 = 2; | |||
remove_col = true; | |||
inp2 = inp2.add_axis(1); | |||
} | |||
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
if (dim1 > 2) { | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
auto shp1 = inp1.symshape(); | |||
IndexDesc head_desc(1); | |||
head_desc[0].end = idx; | |||
shp1_head = opr::Subtensor::make(shp1, head_desc); | |||
auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0}); | |||
IndexDesc tail_desc(1); | |||
tail_desc[0].begin = idx; | |||
shp1_tail = opr::Subtensor::make(shp1, tail_desc); | |||
auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn); | |||
inp1 = inp1.reshape(tshp); | |||
} | |||
if (dim2 > 2) { | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
auto shp2 = inp2.symshape(); | |||
IndexDesc head_desc(1); | |||
head_desc[0].end = idx; | |||
shp2_head = opr::Subtensor::make(shp2, head_desc); | |||
auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0}); | |||
IndexDesc tail_desc(1); | |||
tail_desc[0].begin = idx; | |||
auto shp2_tail = opr::Subtensor::make(shp2, tail_desc); | |||
auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn); | |||
inp2 = inp2.reshape(tshp); | |||
} | |||
auto result = | |||
opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config); | |||
if (dim1 > 2) { | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
auto result_shape = result.symshape(); | |||
IndexDesc tail_desc(1); | |||
tail_desc[0].begin = idx; | |||
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); | |||
auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn); | |||
result = result.reshape(tshp); | |||
} | |||
if (dim2 > 2) { | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
auto result_shape = result.symshape(); | |||
IndexDesc tail_desc(1); | |||
tail_desc[0].begin = idx; | |||
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); | |||
auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn); | |||
result = result.reshape(tshp); | |||
} | |||
auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||
if (remove_row) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
if (remove_col) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
return result; | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
@@ -27,8 +112,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
auto layout2 = inputs[1].layout; | |||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||
DType dst_dtype; | |||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | |||
dnn_opr.op->param() = matmul.param(); | |||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
if (dim1 == 0 || dim2 == 0) { | |||
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||
return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false}; | |||
} | |||
if (matmul.transposeA) | |||
@@ -37,7 +128,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
std::swap(layout2[0], layout2[1]); | |||
mgb_assert(layout1[dim1 - 1] == layout2[0]); | |||
TensorLayout dst_layout(layout1.dtype); | |||
TensorLayout dst_layout(dst_dtype); | |||
size_t ci = 0; | |||
for (size_t i = 0; i < dim1 - 1; i++) | |||
dst_layout[ci++] = layout1[i]; | |||
@@ -61,6 +153,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
SmallVector<TensorND> inp_tensornds(inputs.size()); | |||
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | |||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||
dnn_opr.op->param() = matmul.param(); | |||
DType dst_dtype; | |||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
// only matters when layout1 has dim 2 | |||
if (matmul.transposeA) | |||
std::swap(layout1.shape[0], layout1.shape[1]); | |||
@@ -69,7 +167,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
std::swap(layout2.shape[0], layout2.shape[1]); | |||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||
TensorLayout real_dst_layout(layout1.dtype); | |||
TensorLayout real_dst_layout(dst_dtype); | |||
if (validated) { | |||
real_dst_layout = output_descs[0].layout; | |||
} else { | |||
@@ -126,12 +224,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
} | |||
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype); | |||
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); | |||
dst_layout.init_contiguous_stride(); | |||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||
dnn_opr.op->param() = matmul.param(); | |||
DeviceTensorND out = | |||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||
size_t sz = setup_algo<megdnn::MatrixMul>( | |||
@@ -167,9 +262,99 @@ namespace batched_matrix_mul { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{matmul.make_name()}; | |||
return opr::BatchedMatrixMul::make( | |||
inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||
auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; | |||
auto dim1 = matmul.dimA, dim2 = matmul.dimB; | |||
auto cn = inputs[0]->comp_node(); | |||
using Desc = opr::AxisAddRemove::AxisDesc; | |||
using IndexDesc = opr::Subtensor::IndexDesc; | |||
OperatorNodeConfig config{matmul.make_name(), cn}; | |||
DTypeScalar vi{-2}; | |||
auto graph = inputs[0]->owner_graph(); | |||
auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
bool remove_row = false, remove_col = false; | |||
if (dim1 == 1) { | |||
dim1 = 2; | |||
remove_row = true; | |||
inp1 = inp1.add_axis(0); | |||
} | |||
if (dim2 == 1) { | |||
dim2 = 2; | |||
remove_col = true; | |||
inp2 = inp2.add_axis(1); | |||
} | |||
auto shp1 = inp1.symshape(); | |||
auto shp2 = inp2.symshape(); | |||
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
SymbolVar batch_shape; | |||
if (dim1 > dim2) { | |||
HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32()); | |||
auto* ptr = hv.ptr<dt_int32>(); | |||
ptr[0] = -dim2; | |||
IndexDesc head_desc(1); | |||
head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config); | |||
shp1_head = opr::Subtensor::make(shp1, head_desc); | |||
shp2 = opr::Concat::make({shp1_head, shp2}, 0, cn); | |||
inp2 = inp2.broadcast(shp2); | |||
head_desc[0].end = idx; | |||
batch_shape = opr::Subtensor::make(shp1, head_desc); | |||
} | |||
if (dim2 > dim1) { | |||
HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32()); | |||
auto* ptr = hv.ptr<dt_int32>(); | |||
ptr[0] = -dim1; | |||
IndexDesc head_desc(1); | |||
head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config); | |||
shp2_head = opr::Subtensor::make(shp2, head_desc); | |||
shp1 = opr::Concat::make({shp2_head, shp1}, 0, cn); | |||
inp1 = inp1.broadcast(shp1); | |||
head_desc[0].end = idx; | |||
batch_shape = opr::Subtensor::make(shp2, head_desc); | |||
} | |||
if (dim1 == dim2) { | |||
IndexDesc head_desc(1); | |||
head_desc[0].end = idx; | |||
batch_shape = opr::Subtensor::make(shp1, head_desc); | |||
} | |||
auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||
if (maxdim > 3) { | |||
IndexDesc tail_desc(1); | |||
tail_desc[0].begin = idx; | |||
shp1_tail = opr::Subtensor::make(shp1, tail_desc); | |||
auto batch = opr::Reduce::make(batch_shape, {Reduce::Mode::PRODUCT, 0}); | |||
shp1 = opr::Concat::make({batch, shp1_tail}, 0, cn); | |||
inp1 = inp1.reshape(shp1); | |||
shp2_tail = opr::Subtensor::make(shp2, tail_desc); | |||
shp2 = opr::Concat::make({batch, shp2_tail}, 0, cn); | |||
inp2 = inp2.reshape(shp2); | |||
} | |||
auto result = opr::BatchedMatrixMul::make( | |||
inp1, inp2, matmul.param(), matmul.policy(), config); | |||
if (maxdim > 3) { | |||
auto result_shp = result.symshape(); | |||
IndexDesc tail_desc(1); | |||
tail_desc[0].begin = idx; | |||
auto shp_tail = opr::Subtensor::make(result_shp, tail_desc); | |||
result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); | |||
result = result.reshape(result_shp); | |||
} | |||
if (remove_row) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
if (remove_col) { | |||
std::vector<Desc> remove_param; | |||
remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||
result = opr::AxisAddRemove::make(result, remove_param); | |||
} | |||
return result; | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
@@ -178,8 +363,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; | |||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||
DType dst_dtype; | |||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | |||
dnn_opr.op->param() = matmul.param(); | |||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
if (dim1 == 0 || dim2 == 0) { | |||
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||
return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false}; | |||
} | |||
if (matmul.transposeA) | |||
@@ -187,7 +378,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
if (matmul.transposeB) | |||
std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); | |||
TensorLayout dst_layout(layout1.dtype); | |||
TensorLayout dst_layout(dst_dtype); | |||
size_t di = 0; | |||
if (dim1 > dim2) { | |||
for (size_t i = 0; i < dim1 - 2; i++) | |||
@@ -217,6 +408,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | |||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn); | |||
dnn_opr.op->param() = matmul.param(); | |||
DType dst_dtype; | |||
dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
bool remove_row = false, remove_col = false; | |||
if (dim1 == 1) { | |||
dim1 = 2; | |||
@@ -234,6 +430,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
TensorShape tshp, batch_shp; | |||
size_t j = 0; | |||
auto inp1 = inputs[0], inp2 = inputs[1]; | |||
if (dim1 > dim2) { | |||
for (size_t i = 0; i < dim1 - 2; i++) | |||
tshp[j++] = layout1.shape[i]; | |||
@@ -266,7 +463,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
shp2.ndim += 2; | |||
size_t maxdim = dim1 > dim2 ? dim1 : dim2; | |||
size_t nbatch = batch_shp[0]; | |||
auto inp1 = inputs[0], inp2 = inputs[1]; | |||
if (maxdim > 3) { | |||
nbatch = std::accumulate( | |||
batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, | |||
@@ -274,29 +470,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
TensorLayout layout_a; | |||
// batched_matmul does not support memory forwarding, so ensure contiguous | |||
// manually | |||
TensorShape nl1 = TensorShape( | |||
{nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); | |||
if (!layout1.try_reshape(layout_a, nl1)) { | |||
inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); | |||
inp1->to_contiguous_inplace(); | |||
layout1 = inp1->layout(); | |||
} | |||
inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); | |||
inp1->to_contiguous_inplace(); | |||
layout1 = inp1->layout(); | |||
layout_a = layout1.reshape(nl1); | |||
layout1 = layout_a; | |||
TensorShape nl2 = TensorShape( | |||
{nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); | |||
if (!layout2.try_reshape(layout_a, nl2)) { | |||
inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); | |||
inp2->to_contiguous_inplace(); | |||
layout2 = inp2->layout(); | |||
} | |||
inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); | |||
inp2->to_contiguous_inplace(); | |||
layout2 = inp2->layout(); | |||
layout_a = layout2.reshape(nl2); | |||
layout2 = layout_a; | |||
} | |||
TensorLayout dst_layout( | |||
{nbatch, matmul.transposeA ? layout1[2] : layout1[1], | |||
matmul.transposeB ? layout2[1] : layout2[2]}, | |||
layout1.dtype); | |||
dst_dtype); | |||
dst_layout.init_contiguous_stride(); | |||
if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | |||
@@ -317,9 +513,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
DeviceTensorND out = | |||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn); | |||
dnn_opr.op->param() = matmul.param(); | |||
size_t sz = setup_algo<megdnn::BatchedMatrixMul>( | |||
{layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | |||
matmul.policy(), false); | |||
@@ -246,7 +246,12 @@ private: | |||
it.name, enumMember.substr(0, d)); | |||
body += " break;\n"; | |||
} | |||
body += " default: break;\n"; | |||
body += " default:\n"; | |||
body += | |||
formatv(" props_.emplace_back(\"{0}\", " | |||
"\"INVALID\");\n", | |||
it.name); | |||
body += " break;\n"; | |||
body += " }\n"; | |||
} else { | |||
auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | |||
@@ -89,19 +89,35 @@ void OpDefEmitter::emit_header() { | |||
gen_ctor("", "", " = default;"); | |||
if (!op.getMgbAttributes().empty()) { | |||
std::string strategy_val = ""; | |||
std::vector<std::string> paramList, initList; | |||
for (auto&& i : op.getMgbAttributes()) { | |||
if (attr_to_ctype(i.attr).compare("Strategy") == 0) { | |||
strategy_val = i.name; | |||
} | |||
paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); | |||
initList.push_back(formatv("{0}({0}_)", i.name)); | |||
} | |||
paramList.push_back("std::string scope_ = {}"); | |||
gen_ctor( | |||
llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), | |||
" { set_scope(scope_); }"); | |||
if (!strategy_val.empty()) { | |||
gen_ctor( | |||
llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), | |||
formatv(" {" | |||
"\n set_scope(scope_);" | |||
"\n mgb_assert(static_cast<uint32_t>({0}) <= " | |||
"uint32_t(8));" | |||
"\n }", | |||
strategy_val)); | |||
} else { | |||
gen_ctor( | |||
llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), | |||
" { set_scope(scope_); }"); | |||
} | |||
} | |||
auto packedParams = op.getPackedParams(); | |||
if (!packedParams.empty()) { | |||
std::string strategy_val = ""; | |||
std::vector<std::string> paramList, initList; | |||
for (auto&& p : packedParams) { | |||
auto&& paramFields = p.getFields(); | |||
@@ -111,6 +127,9 @@ void OpDefEmitter::emit_header() { | |||
paramFields.empty() ? paramType.str() | |||
: formatv("{0} {1}", paramType, paramName)); | |||
for (auto&& i : paramFields) { | |||
if (i.name.compare("strategy") == 0) { | |||
strategy_val = i.name; | |||
} | |||
initList.push_back(formatv("{0}({1}.{0})", i.name, paramName)); | |||
} | |||
} | |||
@@ -118,9 +137,20 @@ void OpDefEmitter::emit_header() { | |||
paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); | |||
initList.push_back(formatv("{0}({0}_)", i.name)); | |||
} | |||
gen_ctor( | |||
llvm::join(paramList, ", "), | |||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}"); | |||
if (!strategy_val.empty()) { | |||
gen_ctor( | |||
llvm::join(paramList, ", "), | |||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||
formatv(" {" | |||
"\n mgb_assert(static_cast<uint32_t>({0}) <= " | |||
"uint32_t(8));" | |||
"\n }", | |||
strategy_val)); | |||
} else { | |||
gen_ctor( | |||
llvm::join(paramList, ", "), | |||
initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}"); | |||
} | |||
} | |||
if (!packedParams.empty()) { | |||
@@ -43,9 +43,19 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { | |||
def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>; | |||
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | |||
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> { | |||
let extraArguments = (ins | |||
MgbUI32Attr:$dimA, | |||
MgbUI32Attr:$dimB | |||
); | |||
} | |||
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | |||
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> { | |||
let extraArguments = (ins | |||
MgbUI32Attr:$dimA, | |||
MgbUI32Attr:$dimB | |||
); | |||
} | |||
def Dot: MgbHashableOp<"Dot", [EmptyParam]>; | |||