GitOrigin-RevId: 4ceb2eb601
release-1.10
@@ -20,9 +20,10 @@ from .._imperative_rt.core2 import ( | |||||
Tensor, | Tensor, | ||||
apply, | apply, | ||||
astype_cpp, | astype_cpp, | ||||
batched_matmul_cpp, | |||||
broadcast_cpp, | broadcast_cpp, | ||||
dtype_promotion, | |||||
getitem_cpp, | getitem_cpp, | ||||
matmul_cpp, | |||||
) | ) | ||||
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | ||||
from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp | from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp | ||||
@@ -266,6 +267,42 @@ class _Hashable: | |||||
return self.value == o.value | return self.value == o.value | ||||
def symbolicMatrixMul( | |||||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy | |||||
): | |||||
extentedMatrixMulOp = _get_extentedMatrixMulOp( | |||||
inp1.device, | |||||
inp1.dtype, | |||||
dim1, | |||||
dim2, | |||||
transpose_a, | |||||
transpose_b, | |||||
compute_mode, | |||||
format, | |||||
strategy=_Hashable(strategy), | |||||
) | |||||
(result,) = apply(extentedMatrixMulOp(), inp1, inp2) | |||||
return result | |||||
def symbolicBatchedMatrixMul( | |||||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy | |||||
): | |||||
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( | |||||
inp1.device, | |||||
inp1.dtype, | |||||
dim1, | |||||
dim2, | |||||
transpose_a, | |||||
transpose_b, | |||||
compute_mode, | |||||
format, | |||||
strategy=_Hashable(strategy), | |||||
) | |||||
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) | |||||
return result | |||||
def _matmul( | def _matmul( | ||||
inp1, | inp1, | ||||
inp2, | inp2, | ||||
@@ -274,16 +311,6 @@ def _matmul( | |||||
compute_mode="default", | compute_mode="default", | ||||
format="default", | format="default", | ||||
): | ): | ||||
if amp._enabled: | |||||
compute_mode = "float32" | |||||
inp1, inp2 = cast_tensors(inp1, inp2) | |||||
else: | |||||
dtype = dtype_promotion(inp1, inp2) | |||||
if inp1.dtype != dtype: | |||||
inp1 = inp1.astype(dtype) | |||||
if inp2.dtype != dtype: | |||||
inp2 = inp2.astype(dtype) | |||||
dim1, dim2 = inp1.ndim, inp2.ndim | dim1, dim2 = inp1.ndim, inp2.ndim | ||||
assert dim1 > 0 and dim2 > 0 | assert dim1 > 0 and dim2 > 0 | ||||
maxdim = dim1 if dim1 > dim2 else dim2 | maxdim = dim1 if dim1 > dim2 else dim2 | ||||
@@ -301,34 +328,46 @@ def _matmul( | |||||
if dim1 == 1 and dim2 == 1: # dispatch to Dot | if dim1 == 1 and dim2 == 1: # dispatch to Dot | ||||
(result,) = apply(builtin.Dot(), inp1, inp2) | (result,) = apply(builtin.Dot(), inp1, inp2) | ||||
return result | return result | ||||
elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul | |||||
extentedMatrixMulOp = _get_extentedMatrixMulOp( | |||||
inp1.device, | |||||
inp1.dtype, | |||||
elif maxdim <= 2 or (dim2 <= 2 and not transpose_a): # dispath to MatrixMul | |||||
# 2x1 | |||||
# 1x2 | |||||
# 2x2 | |||||
# nx1(transpose_a=False), n>=3 | |||||
# nx2(transpose_a=False), n>=3 | |||||
return matmul_cpp( | |||||
inp1, | |||||
inp2, | |||||
dim1, | dim1, | ||||
dim2, | dim2, | ||||
transpose_a, | transpose_a, | ||||
transpose_b, | transpose_b, | ||||
compute_mode, | compute_mode, | ||||
format, | format, | ||||
strategy=_Hashable(strategy), | |||||
_config._benchmark_kernel, | |||||
_config._deterministic_kernel, | |||||
strategy, | |||||
symbolicMatrixMul, | |||||
) | ) | ||||
(result,) = apply(extentedMatrixMulOp(), inp1, inp2) | |||||
return result | |||||
else: # dispath to BatchedMatrixMul | else: # dispath to BatchedMatrixMul | ||||
extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( | |||||
inp1.device, | |||||
inp1.dtype, | |||||
# nx1(transpose_a=True), n>=3 | |||||
# nx2(transpose_a=True), n>=3 | |||||
# nxm,n>=3,m>=3 | |||||
# 1xm,m>=3 | |||||
# 2xm,m>=3 | |||||
return batched_matmul_cpp( | |||||
inp1, | |||||
inp2, | |||||
dim1, | dim1, | ||||
dim2, | dim2, | ||||
transpose_a, | transpose_a, | ||||
transpose_b, | transpose_b, | ||||
compute_mode, | compute_mode, | ||||
format, | format, | ||||
strategy=_Hashable(strategy), | |||||
_config._benchmark_kernel, | |||||
_config._deterministic_kernel, | |||||
strategy, | |||||
symbolicBatchedMatrixMul, | |||||
) | ) | ||||
(result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) | |||||
return result | |||||
def _unary_elwise(mode): | def _unary_elwise(mode): | ||||
@@ -10,7 +10,7 @@ import collections | |||||
import math | import math | ||||
from typing import Iterable, Optional, Sequence, Tuple, Union | from typing import Iterable, Optional, Sequence, Tuple, Union | ||||
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion | |||||
from ..core._imperative_rt.core2 import Const, apply | |||||
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.tensor.array_method import _matmul | from ..core.tensor.array_method import _matmul | ||||
@@ -17,7 +17,6 @@ from ..core._imperative_rt.core2 import ( | |||||
apply, | apply, | ||||
dtype_promotion, | dtype_promotion, | ||||
) | ) | ||||
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | |||||
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.ops.builtin import ( | from ..core.ops.builtin import ( | ||||
@@ -177,16 +176,6 @@ def conv1d( | |||||
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | ||||
assert inp.ndim == 3, "the input dimension of conv1d should be 3" | assert inp.ndim == 3, "the input dimension of conv1d should be 3" | ||||
assert weight.ndim == 3, "the weight dimension of conv1d should be 3" | assert weight.ndim == 3, "the weight dimension of conv1d should be 3" | ||||
if amp._enabled: | |||||
compute_mode = "float32" | |||||
inp, weight, bias = cast_tensors(inp, weight, bias) | |||||
else: | |||||
dtype = dtype_promotion(inp, weight) | |||||
if inp.dtype != dtype: | |||||
inp = inp.astype(dtype) | |||||
if weight.dtype != dtype: | |||||
weight = weight.astype(dtype) | |||||
if bias is not None: | if bias is not None: | ||||
assert bias.ndim == 3, "the bias dimension of conv1d should be 3" | assert bias.ndim == 3, "the bias dimension of conv1d should be 3" | ||||
@@ -522,12 +511,6 @@ def local_conv2d( | |||||
pad_h, pad_w = expand_hw(padding) | pad_h, pad_w = expand_hw(padding) | ||||
dilate_h, dilate_w = expand_hw(dilation) | dilate_h, dilate_w = expand_hw(dilation) | ||||
dtype = dtype_promotion(inp, weight) | |||||
if inp.dtype != dtype: | |||||
inp = inp.astype(dtype) | |||||
if weight.dtype != dtype: | |||||
weight = weight.astype(dtype) | |||||
# local conv only support "dense" mode, but weight could contain group dimension. | # local conv only support "dense" mode, but weight could contain group dimension. | ||||
op = builtin.GroupLocal( | op = builtin.GroupLocal( | ||||
stride_h=stride_h, | stride_h=stride_h, | ||||
@@ -433,6 +433,8 @@ WRAP_FUNC_PY35(reshape_cpp); | |||||
WRAP_FUNC_PY35(adaptive_pool2d_cpp); | WRAP_FUNC_PY35(adaptive_pool2d_cpp); | ||||
WRAP_FUNC_PY35(Const); | WRAP_FUNC_PY35(Const); | ||||
WRAP_FUNC_PY35(astype_cpp); | WRAP_FUNC_PY35(astype_cpp); | ||||
WRAP_FUNC_PY35(matmul_cpp); | |||||
WRAP_FUNC_PY35(batched_matmul_cpp); | |||||
WRAP_FUNC_PY35(convert_single_value_cpp); | WRAP_FUNC_PY35(convert_single_value_cpp); | ||||
WRAP_FUNC_PY35(convert_inputs_cpp); | WRAP_FUNC_PY35(convert_inputs_cpp); | ||||
WRAP_FUNC_PY35(astensor1d_cpp); | WRAP_FUNC_PY35(astensor1d_cpp); | ||||
@@ -588,6 +590,8 @@ void init_tensor(py::module m) { | |||||
MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp), | MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp), | ||||
MGE_PY_INTERFACE(Const, Const), | MGE_PY_INTERFACE(Const, Const), | ||||
MGE_PY_INTERFACE(astype_cpp, astype_cpp), | MGE_PY_INTERFACE(astype_cpp, astype_cpp), | ||||
MGE_PY_INTERFACE(matmul_cpp, matmul_cpp), | |||||
MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp), | |||||
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), | MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), | ||||
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), | MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), | ||||
MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp), | MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp), | ||||
@@ -1490,6 +1490,78 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||||
return ret[0]; | return ret[0]; | ||||
} | } | ||||
py::object _matmul_cpp( | |||||
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||||
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||||
py::handle format, py::handle profile, py::handle determistic, | |||||
py::handle strategy, py::handle func) { | |||||
if (enable_fastpath(inp1)) { | |||||
::megdnn::param::MatrixMul::ComputeMode mode = | |||||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
} | |||||
::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||||
if (profile.cast<bool>()) { | |||||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
} else { | |||||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
} | |||||
if (determistic.cast<bool>()) { | |||||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
} | |||||
std::shared_ptr<OpDef> op = MatrixMul::make( | |||||
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||||
py::object Op = py::cast(op); | |||||
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
return ret[0]; | |||||
} else { | |||||
// fallback to traceable implementation | |||||
return func( | |||||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||||
strategy); | |||||
} | |||||
} | |||||
py::object _batched_matmul_cpp( | |||||
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||||
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||||
py::handle format, py::handle profile, py::handle determistic, | |||||
py::handle strategy, py::handle func) { | |||||
if (enable_fastpath(inp1)) { | |||||
::megdnn::param::MatrixMul::ComputeMode mode = | |||||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
} | |||||
::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||||
if (profile.cast<bool>()) { | |||||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
} else { | |||||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
} | |||||
if (determistic.cast<bool>()) { | |||||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
} | |||||
std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||||
transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||||
py::object Op = py::cast(op); | |||||
PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
return ret[0]; | |||||
} else { | |||||
// fallback to traceable implementation | |||||
return func( | |||||
inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||||
strategy); | |||||
} | |||||
} | |||||
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
try { | try { | ||||
return _make_shape_tuple(args[0]).release().ptr(); | return _make_shape_tuple(args[0]).release().ptr(); | ||||
@@ -1574,6 +1646,28 @@ PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
} | } | ||||
PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
try { | |||||
return _matmul_cpp( | |||||
args[0], args[1], args[2], args[3], args[4], args[5], args[6], | |||||
args[7], args[8], args[9], args[10], args[11]) | |||||
.release() | |||||
.ptr(); | |||||
} | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | |||||
PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
try { | |||||
return _batched_matmul_cpp( | |||||
args[0], args[1], args[2], args[3], args[4], args[5], args[6], | |||||
args[7], args[8], args[9], args[10], args[11]) | |||||
.release() | |||||
.ptr(); | |||||
} | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
} | |||||
PyObject* convert_single_value_cpp( | PyObject* convert_single_value_cpp( | ||||
PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* self, PyObject* const* args, size_t nargs) { | ||||
try { | try { | ||||
@@ -30,6 +30,10 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
@@ -1,87 +0,0 @@ | |||||
#include "megbrain/imperative/opr_utility.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/imperative/utils/stats.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/blas.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "../blob_manager_impl.h" | |||||
#include "../dnn_op_helper.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { | |||||
namespace dot { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = def.cast_final_safe<Dot>(); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{op.make_name()}; | |||||
return opr::Dot::make(inputs[0], inputs[1], config); | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto comp_node = inputs[0]->comp_node(); | |||||
using TensorND = megdnn::TensorND; | |||||
SmallVector<TensorND> inp_tensornds; | |||||
inp_tensornds.reserve(inputs.size()); | |||||
DnnOprCaller<megdnn::Dot> dnn_opr(comp_node); | |||||
for (unsigned i = 0; i < inputs.size(); ++i) { | |||||
auto dnn_ten = inputs[i]->dnn_tensor(); | |||||
inp_tensornds.push_back(dnn_ten); | |||||
} | |||||
TensorLayout oup_layout{inputs[0]->dtype()}; | |||||
auto inp1_tensor = inputs[0]->dnn_tensor(); | |||||
auto inp2_tensor = inputs[1]->dnn_tensor(); | |||||
dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); | |||||
if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { | |||||
DnnOprCaller<megdnn::Fill> fill_opr(comp_node); | |||||
DeviceTensorND out = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
fill_opr.op->param() = 0; | |||||
fill_opr.op->exec(out.as_megdnn(), {}); | |||||
return {Tensor::make(out)}; | |||||
} | |||||
auto sz = dnn_opr.op->get_workspace_in_bytes( | |||||
inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); | |||||
DeviceTensorND out_devtensor = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
TensorLayout w_layout({sz}, dtype::Byte()); | |||||
auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
dnn_opr.op->exec( | |||||
inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); | |||||
return {Tensor::make(out_devtensor)}; | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
mgb_assert( | |||||
inputs.size() == 2, "Dot expects 2 inputs; got %lu actually", | |||||
inputs.size()); | |||||
SmallVector<LogicalTensorDesc> dests(1); | |||||
dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); | |||||
dests[0].comp_node = inputs[0].comp_node; | |||||
bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0; | |||||
return {dests, validated}; | |||||
} | |||||
OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.fallback(); | |||||
} // namespace dot | |||||
} // anonymous namespace | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -0,0 +1,435 @@ | |||||
#include <numeric> | |||||
#include "../blob_manager_impl.h" | |||||
#include "../dnn_op_helper.h" | |||||
#include "../op_trait.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/blas.h" | |||||
#include "../algo_chooser.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { | |||||
namespace matrix_mul { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{matmul.make_name()}; | |||||
return opr::MatrixMul::make( | |||||
inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||||
auto layout1 = inputs[0].layout; | |||||
auto layout2 = inputs[1].layout; | |||||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
if (dim1 == 0 || dim2 == 0) { | |||||
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||||
} | |||||
if (matmul.transposeA) | |||||
std::swap(layout1[0], layout1[1]); | |||||
if (matmul.transposeB) | |||||
std::swap(layout2[0], layout2[1]); | |||||
mgb_assert(layout1[dim1 - 1] == layout2[0]); | |||||
TensorLayout dst_layout(layout1.dtype); | |||||
size_t ci = 0; | |||||
for (size_t i = 0; i < dim1 - 1; i++) | |||||
dst_layout[ci++] = layout1[i]; | |||||
if (dim2 == 2) | |||||
dst_layout[ci++] = layout2[1]; | |||||
dst_layout.ndim = ci; | |||||
dst_layout.init_contiguous_stride(); | |||||
SmallVector<LogicalTensorDesc> out_descs(1u); | |||||
out_descs[0] = {dst_layout, inputs[0].comp_node}; | |||||
return {out_descs, true}; | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||||
auto&& cn = inputs[0]->comp_node(); | |||||
using TensorND = megdnn::TensorND; | |||||
SmallVector<TensorND> inp_tensornds(inputs.size()); | |||||
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | |||||
// only matters when layout1 has dim 2 | |||||
if (matmul.transposeA) | |||||
std::swap(layout1.shape[0], layout1.shape[1]); | |||||
// only matters when layout2 has dim 2 | |||||
if (matmul.transposeB) | |||||
std::swap(layout2.shape[0], layout2.shape[1]); | |||||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
TensorLayout real_dst_layout(layout1.dtype); | |||||
if (validated) { | |||||
real_dst_layout = output_descs[0].layout; | |||||
} else { | |||||
size_t ri = 0; | |||||
for (size_t i = 0; i < dim1 - 2; i++) | |||||
real_dst_layout[ri++] = layout1[i]; | |||||
real_dst_layout[ri++] = layout1[dim1 - 2]; | |||||
if (dim2 == 2) | |||||
real_dst_layout[ri++] = layout2[dim2 - 1]; | |||||
real_dst_layout.ndim = ri; | |||||
real_dst_layout.init_contiguous_stride(); | |||||
} | |||||
if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | |||||
DeviceTensorND out = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(cn, real_dst_layout); | |||||
if (!out.empty()) { | |||||
dev_tensor_memset(out, 0); | |||||
} | |||||
return {Tensor::make(out)}; | |||||
} | |||||
TensorLayout layout_a = layout1, layout_b = layout2; | |||||
if (dim1 == 1) { | |||||
layout_a.add_axis_cont_inplace(0); | |||||
inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||||
inp_tensornds[0].layout = layout_a; | |||||
} else if (dim1 > 2) { | |||||
size_t batch = std::accumulate( | |||||
layout1.shape, layout1.shape + dim1 - 1, (size_t)1, | |||||
std::multiplies<size_t>()); | |||||
TensorShape na = TensorShape{batch, layout1[dim1 - 1]}; | |||||
auto inp1 = inputs[0]; | |||||
if (!layout1.try_reshape(layout_a, na)) { | |||||
inp1 = Tensor::make(inp1->blob(), inp1->offset(), layout1); | |||||
inp1->to_contiguous_inplace(); | |||||
layout1 = inp1->layout(); | |||||
layout_a = TensorLayout{{batch, layout1[dim1 - 1]}, layout1.dtype}; | |||||
} | |||||
layout_a.init_contiguous_stride(); | |||||
inp_tensornds[0] = inp1->dnn_tensor(); | |||||
inp_tensornds[0].layout = layout_a; | |||||
} else { | |||||
inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||||
} | |||||
if (dim2 == 1) { | |||||
layout_b.add_axis_inplace(1, 1, 1); | |||||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
inp_tensornds[1].layout = layout_b; | |||||
} else { | |||||
inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
} | |||||
TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype); | |||||
dst_layout.init_contiguous_stride(); | |||||
DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||||
dnn_opr.op->param() = matmul.param(); | |||||
DeviceTensorND out = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||||
size_t sz = setup_algo<megdnn::MatrixMul>( | |||||
{layout_a, layout_b, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | |||||
matmul.policy(), false); | |||||
TensorLayout w_layout({sz}, dtype::Byte()); | |||||
auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk); | |||||
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(real_dst_layout)))}; | |||||
} | |||||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||||
layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { | |||||
return layout.is_contiguous(); | |||||
}; | |||||
return layout_checker; | |||||
} | |||||
OP_TRAIT_REG(MatrixMul, MatrixMul) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.fallback(); | |||||
} // namespace matrix_mul | |||||
} // namespace | |||||
namespace { | |||||
namespace batched_matrix_mul { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{matmul.make_name()}; | |||||
return opr::BatchedMatrixMul::make( | |||||
inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||||
TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; | |||||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
if (dim1 == 0 || dim2 == 0) { | |||||
return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||||
} | |||||
if (matmul.transposeA) | |||||
std::swap(layout1[dim1 - 1], layout1[dim1 - 2]); | |||||
if (matmul.transposeB) | |||||
std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); | |||||
TensorLayout dst_layout(layout1.dtype); | |||||
size_t di = 0; | |||||
if (dim1 > dim2) { | |||||
for (size_t i = 0; i < dim1 - 2; i++) | |||||
dst_layout[di++] = layout1[i]; | |||||
} else { | |||||
for (size_t i = 0; i < dim2 - 2; i++) | |||||
dst_layout[di++] = layout2[i]; | |||||
} | |||||
if (dim1 > 1) | |||||
dst_layout[di++] = layout1[dim1 - 2]; | |||||
if (dim2 > 1) | |||||
dst_layout[di++] = layout2[dim2 - 1]; | |||||
dst_layout.ndim = di; | |||||
dst_layout.init_contiguous_stride(); | |||||
SmallVector<LogicalTensorDesc> out_descs(1u); | |||||
out_descs[0] = {dst_layout, inputs[0].comp_node}; | |||||
return {out_descs, true}; | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||||
auto&& cn = inputs[0]->comp_node(); | |||||
TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | |||||
size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
bool remove_row = false, remove_col = false; | |||||
if (dim1 == 1) { | |||||
dim1 = 2; | |||||
remove_row = true; | |||||
} | |||||
if (dim2 == 1) { | |||||
dim2 = 2; | |||||
remove_col = true; | |||||
} | |||||
if (remove_row) | |||||
layout1.add_axis_cont_inplace(0); | |||||
if (remove_col) | |||||
layout2.add_axis_inplace(1, 1, 1); | |||||
TensorShape tshp, batch_shp; | |||||
size_t j = 0; | |||||
if (dim1 > dim2) { | |||||
for (size_t i = 0; i < dim1 - 2; i++) | |||||
tshp[j++] = layout1.shape[i]; | |||||
batch_shp = tshp; | |||||
batch_shp.ndim = dim1 - 2; | |||||
tshp[j++] = layout2[layout2.ndim - 2]; | |||||
tshp[j++] = layout2[layout2.ndim - 1]; | |||||
tshp.ndim = j; | |||||
layout2 = layout2.broadcast(tshp); | |||||
} | |||||
if (dim2 > dim1) { | |||||
for (size_t i = 0; i < dim2 - 2; i++) | |||||
tshp[j++] = layout2.shape[i]; | |||||
batch_shp = tshp; | |||||
batch_shp.ndim = dim2 - 2; | |||||
tshp[j++] = layout1[layout1.ndim - 2]; | |||||
tshp[j++] = layout1[layout1.ndim - 1]; | |||||
tshp.ndim = j; | |||||
layout1 = layout1.broadcast(tshp); | |||||
} | |||||
if (dim1 == dim2) { | |||||
for (size_t i = 0; i < dim1 - 2; i++) | |||||
tshp[j++] = layout1.shape[i]; | |||||
batch_shp = tshp; | |||||
batch_shp.ndim = dim1 - 2; | |||||
} | |||||
TensorShape shp1 = batch_shp, shp2 = batch_shp; | |||||
shp1.ndim += 2; | |||||
shp2.ndim += 2; | |||||
size_t maxdim = dim1 > dim2 ? dim1 : dim2; | |||||
size_t nbatch = batch_shp[0]; | |||||
auto inp1 = inputs[0], inp2 = inputs[1]; | |||||
if (maxdim > 3) { | |||||
nbatch = std::accumulate( | |||||
batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, | |||||
std::multiplies<size_t>()); | |||||
TensorLayout layout_a; | |||||
TensorShape nl1 = TensorShape( | |||||
{nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); | |||||
if (!layout1.try_reshape(layout_a, nl1)) { | |||||
inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); | |||||
inp1->to_contiguous_inplace(); | |||||
layout1 = inp1->layout(); | |||||
} | |||||
layout1 = layout_a; | |||||
TensorShape nl2 = TensorShape( | |||||
{nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); | |||||
if (!layout2.try_reshape(layout_a, nl2)) { | |||||
inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); | |||||
inp2->to_contiguous_inplace(); | |||||
layout2 = inp2->layout(); | |||||
} | |||||
layout2 = layout_a; | |||||
} | |||||
TensorLayout dst_layout( | |||||
{nbatch, matmul.transposeA ? layout1[2] : layout1[1], | |||||
matmul.transposeB ? layout2[1] : layout2[2]}, | |||||
layout1.dtype); | |||||
dst_layout.init_contiguous_stride(); | |||||
if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | |||||
DeviceTensorND out = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||||
if (!out.empty()) { | |||||
dev_tensor_memset(out, 0); | |||||
} | |||||
return {Tensor::make(out)}; | |||||
} | |||||
using TensorND = megdnn::TensorND; | |||||
TensorND inp_nd1 = inp1->dnn_tensor(); | |||||
inp_nd1.layout = layout1; | |||||
TensorND inp_nd2 = inp2->dnn_tensor(); | |||||
inp_nd2.layout = layout2; | |||||
DeviceTensorND out = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||||
DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn); | |||||
dnn_opr.op->param() = matmul.param(); | |||||
size_t sz = setup_algo<megdnn::BatchedMatrixMul>( | |||||
{layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | |||||
matmul.policy(), false); | |||||
TensorLayout w_layout({sz}, dtype::Byte()); | |||||
auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
dnn_opr.op->exec(inp_nd1, inp_nd2, out.as_megdnn(), dnn_wk); | |||||
shp1[shp1.ndim - 2] = dst_layout[dst_layout.ndim - 2]; | |||||
shp1[shp1.ndim - 1] = dst_layout[dst_layout.ndim - 1]; | |||||
if (maxdim > 3) { | |||||
dst_layout = dst_layout.reshape(shp1); | |||||
} | |||||
if (remove_row) { | |||||
dst_layout = dst_layout.remove_axis(maxdim - 2); | |||||
} | |||||
if (remove_col) { | |||||
dst_layout = dst_layout.remove_axis(maxdim - 1); | |||||
} | |||||
return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; | |||||
} | |||||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||||
layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { | |||||
return layout.is_contiguous(); | |||||
}; | |||||
return layout_checker; | |||||
} | |||||
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.fallback(); | |||||
} // namespace batched_matrix_mul | |||||
} // namespace | |||||
namespace { | |||||
namespace dot { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = def.cast_final_safe<Dot>(); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{op.make_name()}; | |||||
return opr::Dot::make(inputs[0], inputs[1], config); | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto comp_node = inputs[0]->comp_node(); | |||||
using TensorND = megdnn::TensorND; | |||||
SmallVector<TensorND> inp_tensornds; | |||||
inp_tensornds.reserve(inputs.size()); | |||||
DnnOprCaller<megdnn::Dot> dnn_opr(comp_node); | |||||
for (unsigned i = 0; i < inputs.size(); ++i) { | |||||
auto dnn_ten = inputs[i]->dnn_tensor(); | |||||
inp_tensornds.push_back(dnn_ten); | |||||
} | |||||
TensorLayout oup_layout{inputs[0]->dtype()}; | |||||
auto inp1_tensor = inputs[0]->dnn_tensor(); | |||||
auto inp2_tensor = inputs[1]->dnn_tensor(); | |||||
dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); | |||||
if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { | |||||
DeviceTensorND out = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
if (!out.empty()) { | |||||
dev_tensor_memset(out, 0); | |||||
} | |||||
return {Tensor::make(out)}; | |||||
} | |||||
auto sz = dnn_opr.op->get_workspace_in_bytes( | |||||
inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); | |||||
DeviceTensorND out_devtensor = | |||||
BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
TensorLayout w_layout({sz}, dtype::Byte()); | |||||
auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
dnn_opr.op->exec( | |||||
inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); | |||||
return {Tensor::make(out_devtensor)}; | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
mgb_assert( | |||||
inputs.size() == 2, "Dot expects 2 inputs; got %lu actually", | |||||
inputs.size()); | |||||
SmallVector<LogicalTensorDesc> dests(1); | |||||
dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); | |||||
dests[0].comp_node = inputs[0].comp_node; | |||||
bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0; | |||||
return {dests, validated}; | |||||
} | |||||
OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.fallback(); | |||||
} // namespace dot | |||||
} // anonymous namespace | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -123,7 +123,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src); | inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src); | ||||
auto mode = op_def.param().mode; | auto mode = op_def.param().mode; | ||||
DnnOprCaller<megdnn::Fill> fill_op(comp_node); | |||||
if (!keepdim && src.ndim > 1) { | if (!keepdim && src.ndim > 1) { | ||||
layout.remove_axis_inplace(axis); | layout.remove_axis_inplace(axis); | ||||
@@ -135,12 +134,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
switch (mode) { | switch (mode) { | ||||
case Reduce::Mode::SUM: | case Reduce::Mode::SUM: | ||||
if (!out.empty()) { | if (!out.empty()) { | ||||
fill_op.op->param() = 0; | |||||
fill_op.op->exec(out.as_megdnn(), {}); | |||||
dev_tensor_memset(out, 0); | |||||
} | } | ||||
break; | break; | ||||
case Reduce::Mode::PRODUCT: | case Reduce::Mode::PRODUCT: | ||||
if (!out.empty()) { | if (!out.empty()) { | ||||
DnnOprCaller<megdnn::Fill> fill_op(comp_node); | |||||
fill_op.op->param() = 1; | fill_op.op->param() = 1; | ||||
fill_op.op->exec(out.as_megdnn(), {}); | fill_op.op->exec(out.as_megdnn(), {}); | ||||
} | } | ||||
@@ -320,34 +320,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias) | |||||
} // namespace | } // namespace | ||||
namespace { | namespace { | ||||
namespace matrix_mul { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& matmul = static_cast<const MatrixMul&>(def); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{matmul.make_name()}; | |||||
return opr::MatrixMul::make( | |||||
inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
} | |||||
OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace matrix_mul | |||||
} // namespace | |||||
namespace { | |||||
namespace batched_matrix_mul { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{matmul.make_name()}; | |||||
return opr::BatchedMatrixMul::make( | |||||
inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
} | |||||
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // namespace batched_matrix_mul | |||||
} // namespace | |||||
namespace { | |||||
namespace argsort { | namespace argsort { | ||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& argsort = static_cast<const Argsort&>(def); | auto&& argsort = static_cast<const Argsort&>(def); | ||||
@@ -183,6 +183,57 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
return imperative::apply(op, converted); | return imperative::apply(op, converted); | ||||
} | } | ||||
ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>()); | |||||
SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||||
mgb::DType target_dtype; | |||||
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||||
conv_op.compute_mode = MatrixMul::ComputeMode::FLOAT32; | |||||
target_dtype = DTypePromoteCfg::amp_low_prec_dtype; | |||||
} else { | |||||
target_dtype = get_promoted_dtype(dtypes); | |||||
} | |||||
ValueRefList converted(inputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (dtypes[i] != target_dtype) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
} else { | |||||
converted[i] = inputs[i]; | |||||
} | |||||
} | |||||
return imperative::apply(op, converted); | |||||
} | |||||
ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
auto&& conv_op = | |||||
const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>()); | |||||
SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||||
mgb::DType target_dtype; | |||||
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||||
conv_op.compute_mode = BatchedMatrixMul::ComputeMode::FLOAT32; | |||||
target_dtype = DTypePromoteCfg::amp_low_prec_dtype; | |||||
} else { | |||||
target_dtype = get_promoted_dtype(dtypes); | |||||
} | |||||
ValueRefList converted(inputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (dtypes[i] != target_dtype) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
} else { | |||||
converted[i] = inputs[i]; | |||||
} | |||||
} | |||||
return imperative::apply(op, converted); | |||||
} | |||||
// differ from Convolution, ConvolutionBackwardData is used in both | // differ from Convolution, ConvolutionBackwardData is used in both | ||||
// functional.conv_transpose2d and quantize.conv_transpose2d | // functional.conv_transpose2d and quantize.conv_transpose2d | ||||
ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { | ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { | ||||
@@ -259,8 +310,11 @@ struct DTypePromoteRuleRegistry { | |||||
DTypePromoteRuleRegistry() { | DTypePromoteRuleRegistry() { | ||||
register_dtype_promote_rule<Elemwise>(elemwise_rule); | register_dtype_promote_rule<Elemwise>(elemwise_rule); | ||||
register_dtype_promote_rule<Concat>(naive_promote_rule); | register_dtype_promote_rule<Concat>(naive_promote_rule); | ||||
register_dtype_promote_rule<GroupLocal>(naive_promote_rule); | |||||
register_dtype_promote_rule<Reduce>(reduce_rule); | register_dtype_promote_rule<Reduce>(reduce_rule); | ||||
register_dtype_promote_rule<Convolution>(convolution_rule); | register_dtype_promote_rule<Convolution>(convolution_rule); | ||||
register_dtype_promote_rule<MatrixMul>(matmul_rule); | |||||
register_dtype_promote_rule<BatchedMatrixMul>(batch_matmul_rule); | |||||
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); | register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); | ||||
register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | ||||
register_dtype_promote_rule<Convolution3D>(naive_promote_rule); | register_dtype_promote_rule<Convolution3D>(naive_promote_rule); | ||||