GitOrigin-RevId: 10a3c5b106
release-1.6
@@ -142,6 +142,26 @@ def test_matmul():
     )


+@pytest.mark.parametrize(
+    "shape_a, shape_b", [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
+)
+@pytest.mark.parametrize("is_symbolic", [None, True, False])
+def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
+    def func(a, b):
+        return F.matmul(a, b)
+
+    if is_symbolic is not None:
+        func = jit.trace(symbolic=is_symbolic)(func)
+
+    a = tensor(np.random.randn(*shape_a))
+    b = tensor(np.random.randn(*shape_b))
+    for _ in range(3):
+        out = func(a, b)
+        assert np.all(out.numpy() == 0)
+        if is_symbolic is None:
+            break
+
+
 def test_interpolate():
     def linear_interpolate():
         inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
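
For context, a minimal sketch of the behaviour the new test above locks in, assuming the test module's usual imports (numpy as np, megengine.functional as F, tensor from megengine): matmul on inputs with a zero-sized dimension now yields a zero-filled output of the mathematically expected shape instead of failing.

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # (10, 0) @ (0, 10) -> (10, 10); the reduction axis is empty, so every entry is 0
    a = tensor(np.random.randn(10, 0).astype("float32"))
    b = tensor(np.random.randn(0, 10).astype("float32"))
    out = F.matmul(a, b)
    assert out.numpy().shape == (10, 10)
    assert np.all(out.numpy() == 0)
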
@@ -45,6 +45,7 @@ MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
     init_megdnn_opr(*this, param);
     m_policy = policy;
     add_input({a, b});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }
 
 SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
@@ -61,6 +62,15 @@ void MatrixMul::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
 
+MatrixMul::NodeProp* MatrixMul::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
     mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
                layout.to_string().c_str());
@@ -138,6 +148,17 @@ void MatrixMul::scn_do_execute() {
     auto inp0 = input(0)->dev_tensor().as_megdnn(),
          inp1 = input(1)->dev_tensor().as_megdnn(),
          out = output(0)->dev_tensor().as_megdnn();
+    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
+        if (!out.layout.is_empty()) {
+            if (!m_fill_opr) {
+                m_fill_opr = intl::get_megdnn_handle(comp_node())->
+                        create_operator<megdnn::Fill>();
+            }
+            m_fill_opr->param() = 0;
+            m_fill_opr->exec(out, {});
+        }
+        return;
+    }
     auto transpose = [](TensorLayout& layout, bool& trans) {
         if (!check_layout(layout, 0)) {
             mgb_assert(check_layout(layout, 1));
@@ -193,6 +214,7 @@ BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
     init_megdnn_opr(*this, param);
     m_policy = policy;
     add_input({a, b});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }
 
 SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
@@ -229,6 +251,15 @@ void BatchedMatrixMul::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
 
+BatchedMatrixMul::NodeProp* BatchedMatrixMul::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 bool BatchedMatrixMul::check_layout(const TensorLayout& layout,
                                     bool transpose) {
     int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2;
@@ -294,6 +325,17 @@ void BatchedMatrixMul::scn_do_execute() {
     auto inp0 = input(0)->dev_tensor().as_megdnn(),
          inp1 = input(1)->dev_tensor().as_megdnn(),
          out = output(0)->dev_tensor().as_megdnn();
+    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
+        if (!out.layout.is_empty()) {
+            if (!m_fill_opr) {
+                m_fill_opr = intl::get_megdnn_handle(comp_node())->
+                        create_operator<megdnn::Fill>();
+            }
+            m_fill_opr->param() = 0;
+            m_fill_opr->exec(out, {});
+        }
+        return;
+    }
     auto transpose = [](TensorLayout& layout, bool& trans) {
         if (!check_layout(layout, false)) {
             mgb_assert(check_layout(layout, true));
@@ -354,6 +396,7 @@ Dot::Dot(VarNode *opr0, VarNode *opr1, const OperatorNodeConfig &config):
 {
     init_megdnn_opr(*this, {});
     add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     static_assert(std::is_empty<Param>::value, "Dot param should be empty");
     mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED &&
                opr1->dtype().category() != DTypeCategory::QUANTIZED,
@@ -406,10 +449,28 @@ void Dot::scn_do_execute() {
             i1.layout.stride[0] = 0;
         }
     }
+    if ((i0.layout.is_empty() || i1.layout.is_empty())) {
+        if (!m_fill_opr) {
+            m_fill_opr = intl::get_megdnn_handle(comp_node())->
+                    create_operator<megdnn::Fill>();
+        }
+        m_fill_opr->param() = 0;
+        m_fill_opr->exec(output(0)->dev_tensor().as_megdnn(), {});
+        return;
+    }
     megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(),
                        intl::get_megdnn_workspace_from_var(output(1)));
 }
 
+Dot::NodeProp* Dot::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 void Dot::add_input_layout_constraint() {
     auto check = [](const TensorLayout &ly) {
         mgb_throw_if(ly.ndim != 1, GraphError,
@@ -17,6 +17,7 @@
 #include "megbrain/graph.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs/general.h"
 #include "megdnn/oprs/linalg.h"
 
 namespace mgb {
@@ -40,6 +41,7 @@ private:
     void add_input_layout_constraint() override;
     void scn_do_execute() override;
    void init_output_dtype() override;
+    NodeProp* do_make_node_prop() const override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes)
             const override;
@@ -47,6 +49,7 @@ private:
     //! store the policy of all transpose situations
     megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };
 
 /*!
@@ -70,6 +73,7 @@ private:
     void add_input_layout_constraint() override;
     void init_output_dtype() override;
     void scn_do_execute() override;
+    NodeProp* do_make_node_prop() const override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes)
             const override;
@@ -77,6 +81,7 @@ private:
     static bool check_layout(const TensorLayout& layout, bool transpose);
     //! store the policy of all transpose situations
     megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };
 
 /*!
@@ -101,7 +106,9 @@ MGB_DEFINE_OPR_CLASS(Dot, cg::SingleCNOperatorNodeBaseT<
         void add_input_layout_constraint() override;
         void scn_do_execute() override;
         void init_output_static_infer_desc() override;
+        NodeProp* do_make_node_prop() const override;
         void record_execute_deps(ExecDependencyArray &deps) override;
+        std::unique_ptr<megdnn::Fill> m_fill_opr;
 };
 
 MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(MatrixInverse);
@@ -94,7 +94,9 @@ void run_sgemm_test(bool transa, bool transb) {
     Checker(make_graph, fwd)
             .run({mkx(4, 6), mky(6, 2)}, opt)
             .run({mkx(2, 3), mky(3, 100)}, opt)
-            .run({mkx(20, 3), mky(3, 20)}, opt);
+            .run({mkx(20, 3), mky(3, 20)}, opt)
+            .run({mkx(10, 0), mky(0, 10)}, opt)
+            .run({mkx(0, 0), mky(0, 0)}, opt);
 }
 
 #define FWD_BATCH_GEMM(dt_src, dt_dst) \
@@ -143,7 +145,9 @@ void run_batched_sgemm_test(bool transa, bool transb) {
     Checker(make_graph, fwd)
             .run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
-            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
+            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt)
+            .run({mkx(3, 0, 2), mky(3, 2, 0)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt);
 }
 
 auto gen_fp16 = [](HostTensorND& dest) {
@@ -198,6 +202,7 @@ void run_batched_hgemm_test(bool transa, bool transb) {
     checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
             .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
 }
@@ -236,6 +241,7 @@ void run_batched_igemm_test(bool transa, bool transb) {
     checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
            .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
             .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
 }
@@ -650,7 +656,8 @@ TEST(TestOprBlas, Dot) {
             .run({TensorShape{15}, TensorShape{1}})
             .run({TensorShape{1}, TensorShape{16}})
             .run({TensorShape{23}, TensorShape{23}})
-            .run({TensorShape{1000}, TensorShape{1000}});
+            .run({TensorShape{1000}, TensorShape{1000}})
+            .run({TensorShape{0}, TensorShape{0}});
 }
 
 TEST(TestOprBlas, TransMatMul) {
@@ -250,7 +250,6 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) {
     for (size_t i = 0; i < nr_out; ++i) {
         if (m_outputs_allow_grad[i]) {
             auto nr = m_outputs_truth[i].shape().total_nr_elems();
-            mgb_assert(nr, "got empty output");
             if (opt.cont_loss_p) {
                 m_loss_p[i]->resize({nr});
                 auto ptr = m_loss_p[i]->template ptr<float>();
@@ -36,7 +36,7 @@ std::vector<HostTensorND> mgb::numerical_diff_pt2(
                 resize(cur_inp->shape());
         auto dptr = dest.ptr<float>();
-        mgb_assert(cur_inp->layout().is_contiguous());
+        mgb_assert(cur_inp->layout().is_contiguous() || cur_inp->layout().is_empty());
         auto cur_inp_ptr = cur_inp->ptr<float>();
         mgb::RealTimer timer;
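
The Dot changes above follow the same pattern: when either input vector is empty, the scalar output is zero-filled via megdnn::Fill instead of being dispatched to the Dot kernel. A hedged Python-side sketch, assuming F.dot is the functional entry point for this operator and the same imports as in the earlier example:

    # The dot product over an empty range is the empty sum, i.e. 0.
    x = tensor(np.random.randn(0).astype("float32"))
    y = tensor(np.random.randn(0).astype("float32"))
    assert F.dot(x, y).numpy() == 0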