GitOrigin-RevId: bb7ab8fa9d
release-1.10
@@ -70,15 +70,6 @@ def _matmul( | |||
maxdim = dim1 if dim1 > dim2 else dim2 | |||
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
Strategy = builtin.ops.MatrixMul.Strategy | |||
strategy = Strategy(0) | |||
if _config._benchmark_kernel: | |||
strategy |= Strategy.PROFILE | |||
else: | |||
strategy |= Strategy.HEURISTIC | |||
if _config._deterministic_kernel: | |||
strategy |= Strategy.REPRODUCIBLE | |||
if dim1 == 1 and dim2 == 1: # dispatch to Dot | |||
(result,) = apply(builtin.Dot(), inp1, inp2) | |||
return result | |||
@@ -621,6 +621,7 @@ def max_pool2d( | |||
pad_h=padding_h, | |||
pad_w=padding_w, | |||
mode="max", | |||
strategy=get_execution_strategy(), | |||
format=conv_format, | |||
) | |||
(output,) = apply(op, inp) | |||
@@ -665,6 +666,7 @@ def avg_pool2d( | |||
pad_h=padding_h, | |||
pad_w=padding_w, | |||
mode=mode, | |||
strategy=get_execution_strategy(), | |||
format=conv_format, | |||
) | |||
(output,) = apply(op, inp) | |||
@@ -1493,7 +1493,7 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||
py::object _matmul_cpp( | |||
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||
py::handle profile, py::handle determistic) { | |||
py::handle profile, py::handle deterministic) { | |||
::megdnn::param::MatrixMul::ComputeMode mode = | |||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||
@@ -1506,7 +1506,7 @@ py::object _matmul_cpp( | |||
} else { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||
} | |||
if (determistic.cast<bool>()) { | |||
if (deterministic.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||
} | |||
std::shared_ptr<OpDef> op = MatrixMul::make( | |||
@@ -1523,7 +1523,7 @@ py::object _matmul_cpp( | |||
py::object _batched_matmul_cpp( | |||
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||
py::handle profile, py::handle determistic) { | |||
py::handle profile, py::handle deterministic) { | |||
::megdnn::param::MatrixMul::ComputeMode mode = | |||
::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||
@@ -1536,7 +1536,7 @@ py::object _batched_matmul_cpp( | |||
} else { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||
} | |||
if (determistic.cast<bool>()) { | |||
if (deterministic.cast<bool>()) { | |||
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||
} | |||
std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||
@@ -10,6 +10,10 @@ | |||
*/ | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/rdnn/profiler.h" | |||
#include "megbrain/serialization/opr_load_dump.h" | |||
#include "../op_trait.h" | |||
@@ -65,6 +69,42 @@ public: | |||
const serialization::GraphDumpConfig& config() const { mgb_assert(0); } | |||
}; | |||
#define cb(FASTRUN_OPR) \ | |||
megdnn::param::ExecutionPolicy get_strategy_##FASTRUN_OPR( \ | |||
cg::OperatorNodeBase* opr) { \ | |||
auto policy = \ | |||
opr->cast_final<opr::FASTRUN_OPR>().execution_policy_transient(); \ | |||
return policy; \ | |||
} \ | |||
void set_strategy_##FASTRUN_OPR( \ | |||
cg::OperatorNodeBase* opr, megdnn::param::ExecutionPolicy policy) { \ | |||
auto&& p = opr->cast_final<opr::FASTRUN_OPR>(); \ | |||
p.set_execution_policy(policy); \ | |||
} | |||
DNN_FOREACH_FASTRUN_OPR(cb) | |||
#undef cb | |||
typedef thin_function<megdnn::param::ExecutionPolicy(cg::OperatorNodeBase*)> get_func; | |||
typedef thin_function<void(cg::OperatorNodeBase*, megdnn::param::ExecutionPolicy)> | |||
set_func; | |||
static const mgb::thin_hash_table::ThinHashMap< | |||
mgb::Typeinfo*, std::pair<get_func, set_func>>& | |||
get_type2policy() { | |||
static mgb::thin_hash_table::ThinHashMap< | |||
mgb::Typeinfo*, std::pair<get_func, set_func>> | |||
sl_type2policy; | |||
static std::once_flag flag; | |||
std::call_once(flag, [&]() { | |||
#define cb(FASTRUN_OPR) \ | |||
sl_type2policy[opr::FASTRUN_OPR::typeinfo()] = \ | |||
std::make_pair(get_strategy_##FASTRUN_OPR, set_strategy_##FASTRUN_OPR); | |||
DNN_FOREACH_FASTRUN_OPR(cb) | |||
}); | |||
return std::as_const(sl_type2policy); | |||
} | |||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& attr = def.cast_final_safe<OprAttr>(); | |||
auto config = attr.config; | |||
@@ -73,7 +113,12 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto registry = serialization::OprRegistry::find_by_name(attr.type); | |||
mgb_assert(registry, "operator %s not found", attr.type.c_str()); | |||
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | |||
return registry->loader(ctx, inputs, config).usable_output(); | |||
auto opr_with_accessor = registry->loader(ctx, inputs, config); | |||
auto&& opr = opr_with_accessor.opr(); | |||
if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) { | |||
get_type2policy().at(opr->dyn_typeinfo()).second(opr, attr.policy); | |||
} | |||
return opr_with_accessor.usable_output(); | |||
} | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
@@ -84,7 +129,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
registry->dumper, "operator %s cannot be serialized", | |||
opr->dyn_typeinfo()->name); | |||
registry->dumper(ctx, *opr); | |||
return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config()); | |||
megdnn::param::ExecutionPolicy policy; | |||
if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) { | |||
policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr); | |||
} | |||
return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config()); | |||
} | |||
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { | |||
@@ -108,6 +157,8 @@ OP_TRAIT_REG(OprAttr, OprAttr) | |||
bool OprAttr::is_same_st(const Hashable& rhs_) const { | |||
auto&& rhs = static_cast<const OprAttr&>(rhs_); | |||
return type == rhs.type && param == rhs.param && | |||
policy.strategy == rhs.policy.strategy && | |||
policy.workspace_limit == rhs.policy.workspace_limit && | |||
config.comp_node() == rhs.config.comp_node() && | |||
config.output_dtype() == rhs.config.output_dtype(); | |||
} | |||
@@ -115,7 +166,12 @@ bool OprAttr::is_same_st(const Hashable& rhs_) const { | |||
size_t OprAttr::hash() const { | |||
return hash_pair_combine( | |||
hash_pair_combine( | |||
mgb::hash(type), mgb::hash(static_cast<std::vector<char>>(param))), | |||
hash_pair_combine( | |||
mgb::hash(type), | |||
mgb::hash(static_cast<std::vector<char>>(param))), | |||
hash_pair_combine( | |||
static_cast<size_t>(policy.strategy), | |||
policy.workspace_limit)), | |||
config.hash()); | |||
} | |||
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/opr/param_defs.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -38,12 +39,16 @@ public: | |||
Type type; | |||
Param param; | |||
megdnn::param::ExecutionPolicy policy; | |||
cg::OperatorNodeConfig config; | |||
OprAttr() = default; | |||
OprAttr(const Type& t) : type(t) {} | |||
OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c) | |||
: type(t), param(p), config(c) {} | |||
OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps, | |||
const cg::OperatorNodeConfig& c) | |||
: type(t), param(p), policy(ps), config(c) {} | |||
std::string repr() const; | |||
@@ -157,6 +157,51 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
} | |||
} | |||
TEST(TestImperative, ProfileBackward) { | |||
auto cn = CompNode::load("xpux"); | |||
using Policy = megdnn::param::ExecutionPolicy; | |||
using S = Policy::Strategy; | |||
Policy policy; | |||
policy.strategy = S::PROFILE; | |||
{ | |||
megdnn::param::Convolution param; | |||
auto op = std::shared_ptr<OpDef>(Convolution::make(param, policy)); | |||
LogicalTensorDesc inp_desc = { | |||
TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn}; | |||
LogicalTensorDesc weight_desc = { | |||
TensorLayout({16, 3, 5, 5}, dtype::Float32()), cn}; | |||
auto bg = OpDef::make_backward_graph( | |||
*op, {inp_desc, weight_desc}, {true, false}, {true}); | |||
auto&& bop = (bg.graph.exprs.at(0)).op; | |||
auto&& attr = bop->cast_final_safe<OprAttr>(); | |||
// attr.type = ConvolutionBackwardDataV2 | |||
mgb_assert(attr.policy.strategy == S::PROFILE); | |||
} | |||
{ | |||
megdnn::param::Pooling param; | |||
auto op = std::shared_ptr<OpDef>(Pooling::make(param, policy)); | |||
LogicalTensorDesc inp_desc = { | |||
TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn}; | |||
auto bg = OpDef::make_backward_graph(*op, {inp_desc}, {true}, {true}); | |||
auto&& bop = (bg.graph.exprs.at(0)).op; | |||
auto&& attr = bop->cast_final_safe<OprAttr>(); | |||
// attr.type = PoolingBackwardV1 | |||
mgb_assert(attr.policy.strategy == S::PROFILE); | |||
} | |||
{ | |||
megdnn::param::MatrixMul param; | |||
auto op = std::shared_ptr<OpDef>(MatrixMul::make(param, policy, 2, 2)); | |||
LogicalTensorDesc inp1_desc = {TensorLayout({12, 16}, dtype::Float32()), cn}; | |||
LogicalTensorDesc inp2_desc = {TensorLayout({16, 20}, dtype::Float32()), cn}; | |||
auto bg = OpDef::make_backward_graph( | |||
*op, {inp1_desc, inp2_desc}, {true, false}, {true}); | |||
auto&& bop = (bg.graph.exprs.at(0)).op; | |||
auto&& attr = bop->cast_final_safe<OprAttr>(); | |||
// attr.type = MatrixMulV2 | |||
mgb_assert(attr.policy.strategy == S::PROFILE); | |||
} | |||
} | |||
TEST(TestImperative, BackwardGraphIdentity) { | |||
HostTensorGenerator<> gen; | |||
auto host_a = gen({42}), host_dc = gen({42}); | |||
@@ -185,17 +185,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) { | |||
if (wrt_idx == 0) { | |||
// A * B = C, A' = C' * Bt | |||
if (opr.param().transposeA) { | |||
grad = MatrixMul::make(i1, og, {opr.param().transposeB, true}); | |||
grad = MatrixMul::make( | |||
i1, og, {opr.param().transposeB, true}, opr.execution_policy()); | |||
} else { | |||
grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB}); | |||
grad = MatrixMul::make( | |||
og, i1, {false, !opr.param().transposeB}, opr.execution_policy()); | |||
} | |||
} else { | |||
mgb_assert(wrt_idx == 1); | |||
// A * B = C, B' = At * C' | |||
if (opr.param().transposeB) { | |||
grad = MatrixMul::make(og, i0, {true, opr.param().transposeA}); | |||
grad = MatrixMul::make( | |||
og, i0, {true, opr.param().transposeA}, opr.execution_policy()); | |||
} else { | |||
grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false}); | |||
grad = MatrixMul::make( | |||
i0, og, {!opr.param().transposeA, false}, opr.execution_policy()); | |||
} | |||
} | |||
return grad.node(); | |||
@@ -358,17 +362,21 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) { | |||
if (wrt_idx == 0) { | |||
// A * B = C, A' = C' * Bt | |||
if (opr.param().transposeA) { | |||
grad = BatchedMatrixMul::make(i1, og, {opr.param().transposeB, true}); | |||
grad = BatchedMatrixMul::make( | |||
i1, og, {opr.param().transposeB, true}, opr.execution_policy()); | |||
} else { | |||
grad = BatchedMatrixMul::make(og, i1, {false, !opr.param().transposeB}); | |||
grad = BatchedMatrixMul::make( | |||
og, i1, {false, !opr.param().transposeB}, opr.execution_policy()); | |||
} | |||
} else { | |||
mgb_assert(wrt_idx == 1); | |||
// A * B = C, B' = At * C' | |||
if (opr.param().transposeB) { | |||
grad = BatchedMatrixMul::make(og, i0, {true, opr.param().transposeA}); | |||
grad = BatchedMatrixMul::make( | |||
og, i0, {true, opr.param().transposeA}, opr.execution_policy()); | |||
} else { | |||
grad = BatchedMatrixMul::make(i0, og, {!opr.param().transposeA, false}); | |||
grad = BatchedMatrixMul::make( | |||
i0, og, {!opr.param().transposeA, false}, opr.execution_policy()); | |||
} | |||
} | |||
return grad.node(); | |||
@@ -59,7 +59,8 @@ size_t PoolingForward::get_workspace_size_bytes( | |||
MGB_IMPL_OPR_GRAD(PoolingForward) { | |||
mgb_assert(wrt_idx == 0); | |||
SymbolVar grad = PoolingBackward::make( | |||
opr.input(0), opr.output(0), out_grad[0], opr.param()); | |||
opr.input(0), opr.output(0), out_grad[0], opr.param(), | |||
opr.execution_policy()); | |||
return grad.node(); | |||
} | |||
#endif | |||
@@ -26,7 +26,7 @@ namespace opr { | |||
/*! | |||
* \brief matrix_mul(trans0(opr0), trans1(opr1)) | |||
*/ | |||
MGB_DEFINE_OPR_CLASS( | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>, | |||
public mixin::AlgoChooserHelper) // { | |||
public: | |||
@@ -57,7 +57,7 @@ private: | |||
/*! | |||
* \brief batched matrix multiplication on 3D inputs | |||
*/ | |||
MGB_DEFINE_OPR_CLASS( | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
BatchedMatrixMul, intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>, | |||
public mixin::AlgoChooserHelper) // { | |||
public: | |||
@@ -18,7 +18,7 @@ | |||
namespace mgb { | |||
namespace opr { | |||
MGB_DEFINE_OPR_CLASS( | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>, | |||
public mixin::AlgoChooserHelper) // { | |||
public: | |||
@@ -37,7 +37,7 @@ public: | |||
}; | |||
using Pooling = PoolingForward; | |||
MGB_DEFINE_OPR_CLASS( | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>, | |||
public mixin::AlgoChooserHelper) // { | |||
public: | |||
@@ -51,7 +51,7 @@ public: | |||
* Exception would be thrown if execution_policy() has been accessed, | |||
* since it would influence cache and many other decisions. | |||
*/ | |||
void set_execution_policy(const ExecutionPolicy& policy); | |||
MGE_WIN_DECLSPEC_FUC void set_execution_policy(const ExecutionPolicy& policy); | |||
/*! | |||
* \brief register a hook to implement custom algo chooser | |||