GitOrigin-RevId: 05577a8bc8
HuaHua404-patch-4
@@ -37,7 +37,6 @@ _ElwMod = builtin.Elemwise.Mode | |||||
def _elemwise_multi_type(*args, mode, **kwargs): | def _elemwise_multi_type(*args, mode, **kwargs): | ||||
op = builtin.ElemwiseMultiType(mode=mode, **kwargs) | op = builtin.ElemwiseMultiType(mode=mode, **kwargs) | ||||
args = convert_inputs(*args) | |||||
(result,) = apply(op, *args) | (result,) = apply(op, *args) | ||||
return result | return result | ||||
@@ -249,22 +248,22 @@ class ArrayMethodMixin(abc.ABC): | |||||
__hash__ = None # due to __eq__ diviates from python convention | __hash__ = None # due to __eq__ diviates from python convention | ||||
__lt__ = lambda self, value: _elemwise_multi_type( | __lt__ = lambda self, value: _elemwise_multi_type( | ||||
self, value, mode="lt", dtype="Bool" | |||||
self, value, mode="lt", dtype="bool" | |||||
) | ) | ||||
__le__ = lambda self, value: _elemwise_multi_type( | __le__ = lambda self, value: _elemwise_multi_type( | ||||
self, value, mode="leq", dtype="Bool" | |||||
self, value, mode="leq", dtype="bool" | |||||
) | ) | ||||
__gt__ = lambda self, value: _elemwise_multi_type( | __gt__ = lambda self, value: _elemwise_multi_type( | ||||
value, self, mode="lt", dtype="Bool" | |||||
value, self, mode="lt", dtype="bool" | |||||
) | ) | ||||
__ge__ = lambda self, value: _elemwise_multi_type( | __ge__ = lambda self, value: _elemwise_multi_type( | ||||
value, self, mode="leq", dtype="Bool" | |||||
value, self, mode="leq", dtype="bool" | |||||
) | ) | ||||
__eq__ = lambda self, value: _elemwise_multi_type( | __eq__ = lambda self, value: _elemwise_multi_type( | ||||
self, value, mode="eq", dtype="Bool" | |||||
self, value, mode="eq", dtype="bool" | |||||
) | ) | ||||
__ne__ = lambda self, value: _elemwise_multi_type( | __ne__ = lambda self, value: _elemwise_multi_type( | ||||
self, value, mode="neq", dtype="Bool" | |||||
self, value, mode="neq", dtype="bool" | |||||
) | ) | ||||
__neg__ = _unary_elwise(_ElwMod.NEGATE) | __neg__ = _unary_elwise(_ElwMod.NEGATE) | ||||
@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor: | |||||
>>> F.isnan(x).numpy() | >>> F.isnan(x).numpy() | ||||
array([False, True, False]) | array([False, True, False]) | ||||
""" | """ | ||||
return _elemwise_multi_type(inp, mode="isnan", dtype="Bool") | |||||
return _elemwise_multi_type(inp, mode="isnan", dtype="bool") | |||||
def isinf(inp: Tensor) -> Tensor: | def isinf(inp: Tensor) -> Tensor: | ||||
@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor: | |||||
>>> F.isinf(x).numpy() | >>> F.isinf(x).numpy() | ||||
array([False, True, False]) | array([False, True, False]) | ||||
""" | """ | ||||
return _elemwise_multi_type(inp, mode="isinf", dtype="Bool") | |||||
return _elemwise_multi_type(inp, mode="isinf", dtype="bool") | |||||
def sign(inp: Tensor): | def sign(inp: Tensor): | ||||
@@ -118,7 +118,7 @@ PyObject* py_apply( | |||||
tensors[i] = tw->m_tensor->data(); | tensors[i] = tw->m_tensor->data(); | ||||
} else if ( | } else if ( | ||||
DTypePromoteCfg::convert_input_enabled && | DTypePromoteCfg::convert_input_enabled && | ||||
op->same_type<Elemwise>()) { | |||||
(op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) { | |||||
tensors[i] = convert_pyinput_to_tensor(i); | tensors[i] = convert_pyinput_to_tensor(i); | ||||
} else { | } else { | ||||
PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs"); | PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs"); | ||||
@@ -53,6 +53,41 @@ mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) { | |||||
return ret; | return ret; | ||||
} | } | ||||
ValueRefList elemwise_multi_type_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
auto&& elem_op = op.cast_final_safe<ElemwiseMultiType>(); | |||||
static std::unordered_set<ElemwiseMultiType::Mode> cast_case = { | |||||
ElemwiseMultiType::Mode::EQ, | |||||
ElemwiseMultiType::Mode::NEQ, | |||||
ElemwiseMultiType::Mode::LT, | |||||
ElemwiseMultiType::Mode::LEQ, | |||||
}; | |||||
if (cast_case.find(elem_op.mode) == cast_case.end()) { | |||||
return imperative::apply(op, inputs); | |||||
} | |||||
SmallVector<DType> dtypes(inputs.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
dtypes[i] = *(inputs[i].dtype()); | |||||
} | |||||
ValueRefList converted(inputs.size()); | |||||
mgb::DType target_dtype = get_promoted_dtype(dtypes); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype && | |||||
DTypePromoteCfg::convert_input_enabled) { | |||||
converted[i] = imperative::apply( | |||||
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
dtypes[i] = target_dtype; | |||||
} else { | |||||
converted[i] = inputs[i]; | |||||
} | |||||
} | |||||
return imperative::apply(op, converted); | |||||
} | |||||
ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) { | ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) { | ||||
auto&& elem_op = op.cast_final_safe<Elemwise>(); | auto&& elem_op = op.cast_final_safe<Elemwise>(); | ||||
@@ -349,6 +384,7 @@ ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
struct DTypePromoteRuleRegistry { | struct DTypePromoteRuleRegistry { | ||||
DTypePromoteRuleRegistry() { | DTypePromoteRuleRegistry() { | ||||
register_dtype_promote_rule<Elemwise>(elemwise_rule); | register_dtype_promote_rule<Elemwise>(elemwise_rule); | ||||
register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule); | |||||
register_dtype_promote_rule<Concat>(naive_promote_rule); | register_dtype_promote_rule<Concat>(naive_promote_rule); | ||||
register_dtype_promote_rule<GroupLocal>(naive_promote_rule); | register_dtype_promote_rule<GroupLocal>(naive_promote_rule); | ||||
register_dtype_promote_rule<Reduce>(reduce_rule); | register_dtype_promote_rule<Reduce>(reduce_rule); | ||||
@@ -16,52 +16,6 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace opr; | using namespace opr; | ||||
namespace { | |||||
//! global operator instance for static inference | |||||
template <class Opr> | |||||
class StaticInferOpr { | |||||
intl::UniqPtrWithCN<Opr> m_opr; | |||||
MGB_MUTEX m_mtx; | |||||
public: | |||||
class Lock { | |||||
friend class StaticInferOpr; | |||||
StaticInferOpr* m_owner; | |||||
explicit Lock(StaticInferOpr* owner) : m_owner{owner} { | |||||
#if !__DEPLOY_ON_XP_SP2__ | |||||
m_owner->m_mtx.lock(); | |||||
#endif | |||||
} | |||||
public: | |||||
Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; } | |||||
~Lock() { | |||||
#if !__DEPLOY_ON_XP_SP2__ | |||||
if (m_owner) | |||||
m_owner->m_mtx.unlock(); | |||||
#endif | |||||
} | |||||
Lock& operator=(const Lock&) = delete; | |||||
Lock& operator=(Lock&&) = delete; | |||||
intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; } | |||||
}; | |||||
//! lock and acquire the operator | |||||
Lock lock() { | |||||
Lock ret{this}; | |||||
if (!m_opr) { | |||||
m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu()); | |||||
} | |||||
return ret; | |||||
} | |||||
}; | |||||
} // anonymous namespace | |||||
/* ========================= BatchedDTypePromotion ========================= */ | /* ========================= BatchedDTypePromotion ========================= */ | ||||
intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars) | intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars) | ||||
: m_orig_vars{vars} { | : m_orig_vars{vars} { | ||||
@@ -1,6 +1,6 @@ | |||||
#include "megbrain/opr/nn_int.h" | #include "megbrain/opr/nn_int.h" | ||||
#include "./internal/megdnn_opr_wrapper.inl" | #include "./internal/megdnn_opr_wrapper.inl" | ||||
#include "megbrain/opr/utility.h" | |||||
#include "megdnn/oprs/general.h" | #include "megdnn/oprs/general.h" | ||||
using namespace mgb; | using namespace mgb; | ||||
@@ -18,6 +18,7 @@ ElemwiseMultiType::ElemwiseMultiType( | |||||
for (auto i : inputs) { | for (auto i : inputs) { | ||||
add_input({i}); | add_input({i}); | ||||
} | } | ||||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
} | } | ||||
SymbolVar ElemwiseMultiType::make( | SymbolVar ElemwiseMultiType::make( | ||||
@@ -52,8 +53,13 @@ void ElemwiseMultiType::init_output_dtype() { | |||||
void ElemwiseMultiType::scn_do_execute() { | void ElemwiseMultiType::scn_do_execute() { | ||||
megdnn::TensorNDArray inp_arr(input().size()); | megdnn::TensorNDArray inp_arr(input().size()); | ||||
for (size_t i = 0; i < input().size(); ++i) { | for (size_t i = 0; i < input().size(); ++i) { | ||||
if (input()[i]->dev_tensor().empty()) { | |||||
mgb_assert(output(0)->dev_tensor().empty()); | |||||
return; | |||||
} | |||||
inp_arr[i] = input()[i]->dev_tensor().as_megdnn(); | inp_arr[i] = input()[i]->dev_tensor().as_megdnn(); | ||||
} | } | ||||
mgb_assert(!output(0)->dev_tensor().empty()); | |||||
megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn()); | megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn()); | ||||
} | } | ||||
@@ -75,4 +81,120 @@ void ElemwiseMultiType::add_input_layout_constraint() { | |||||
#endif | #endif | ||||
} | } | ||||
ElemwiseMultiType::NodeProp* ElemwiseMultiType::do_make_node_prop() const { | |||||
auto ret = Super::do_make_node_prop(); | |||||
for (auto& inp : input()) { | |||||
ret->add_dep_type_existing_var(inp, NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
} | |||||
return ret; | |||||
} | |||||
void ElemwiseMultiType::init_output_static_infer_desc() { | |||||
Super::init_output_static_infer_desc(); | |||||
static StaticInferOpr<megdnn::ElemwiseMultiType> static_infer_opr; | |||||
using namespace cg::static_infer; | |||||
auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) { | |||||
SmallVector<DeviceTensorND> inp_vals(inp.val.size()); | |||||
for (size_t i = 0; i < inp_vals.size(); ++i) | |||||
inp_vals[i] = inp.val[i].value(); | |||||
DType out_dt; | |||||
auto trait = ModeTrait::from_mode(param().mode); | |||||
if (trait.need_specify_out_dtype) { | |||||
auto dtype = config().output_dtype(); | |||||
mgb_assert(dtype.valid()); | |||||
out_dt = dtype; | |||||
} else { | |||||
DType dtype; | |||||
trait.check_out(dtype, false); | |||||
out_dt = dtype; | |||||
} | |||||
auto sopr = static_infer_opr.lock(); | |||||
perform(param().mode, out_dt, dest, inp_vals, sopr()); | |||||
return true; | |||||
}; | |||||
DepVal deps(input().size()); | |||||
for (size_t i = 0; i < input().size(); ++i) | |||||
deps[i] = {input(i), DepType::VALUE}; | |||||
owner_graph()->static_infer_manager().register_value_infer( | |||||
output(0), {SourceType::DEP, deps, infer_value}); | |||||
} | |||||
TensorShape ElemwiseMultiType::get_output_var_shape( | |||||
Mode mode, const TensorShapeArray& input_shapes) { | |||||
mgb_assert(input_shapes.size() == ModeTrait::from_mode(mode).arity); | |||||
TensorShape ret; | |||||
megdnn::Elemwise::deduce_shape(input_shapes, ret); | |||||
return ret; | |||||
} | |||||
void ElemwiseMultiType::call_megdnn_opr_exec( | |||||
CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out, | |||||
megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller) { | |||||
// All Elemwise operations on QuantizedS32/QuantizedS8 are not related to | |||||
// scale. MegDNN does not support computing Elemwise for | |||||
// QuantizedS32/QuantizedS8, we translate the data type to Int32/Int8 before | |||||
// passing to MegDNN. | |||||
if (inp.size() && inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED) { | |||||
auto inp_dtype = inp[0].layout.dtype; | |||||
DType compute_dtype; | |||||
if (inp_dtype.enumv() == DTypeEnum::QuantizedS32) { | |||||
compute_dtype = dtype::Int32(); | |||||
} else if (inp_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
compute_dtype = dtype::Int8(); | |||||
} else { | |||||
mgb_throw( | |||||
MegBrainError, "Unsupported Quantized Elemwise Mode %s: %d on %s", | |||||
inp[0].layout.dtype.name(), int(opr->param().mode), | |||||
comp_node.to_string().c_str()); | |||||
} | |||||
megdnn::TensorNDArray run_inp(inp); | |||||
for (size_t i = 0; i < inp.size(); i++) { | |||||
run_inp[i].layout.dtype = compute_dtype; | |||||
} | |||||
megdnn::TensorND run_out = out; | |||||
run_out.layout.dtype = compute_dtype; | |||||
opr->exec(run_inp, run_out); | |||||
return; | |||||
} | |||||
opr->exec(inp, out); | |||||
} | |||||
void ElemwiseMultiType::perform( | |||||
Mode mode, DType out_dt, DeviceTensorND& dest, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr) { | |||||
megdnn::TensorNDArray dnn_inputs(inputs.size()); | |||||
TensorShapeArray inp_shapes(inputs.size()); | |||||
CompNode out_cn; | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
auto&& t = inputs[i]; | |||||
if (!i) { | |||||
out_cn = t.comp_node(); | |||||
} else { | |||||
mgb_assert(t.comp_node() == out_cn); | |||||
} | |||||
if (t.shape().is_empty()) { | |||||
mgb_assert(dest.empty()); | |||||
return; | |||||
} | |||||
inp_shapes[i] = t.shape(); | |||||
} | |||||
if (!opr) { | |||||
opr = intl::create_megdnn_opr<megdnn::ElemwiseMultiType>(out_cn); | |||||
} else { | |||||
mgb_assert(out_cn == opr.comp_node()); | |||||
} | |||||
out_cn.activate(); | |||||
for (size_t i = 0; i < inputs.size(); ++i) | |||||
dnn_inputs[i] = inputs[i].as_megdnn(); | |||||
dest.comp_node(out_cn).dtype(out_dt).resize(get_output_var_shape(mode, inp_shapes)); | |||||
opr->param() = {mode}; | |||||
call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -26,6 +26,14 @@ public: | |||||
const VarNodeArrayView& inputs, Param param, | const VarNodeArrayView& inputs, Param param, | ||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape( | |||||
Mode mode, const TensorShapeArray& input_shapes); | |||||
MGE_WIN_DECLSPEC_FUC static void perform( | |||||
Mode mode, DType out_dt, DeviceTensorND& dest, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr); | |||||
private: | private: | ||||
using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait; | using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait; | ||||
@@ -40,6 +48,14 @@ private: | |||||
void record_execute_deps(ExecDependencyArray& deps) override; | void record_execute_deps(ExecDependencyArray& deps) override; | ||||
void add_input_layout_constraint() override; | void add_input_layout_constraint() override; | ||||
NodeProp* do_make_node_prop() const override; | |||||
void init_output_static_infer_desc() override; | |||||
static void call_megdnn_opr_exec( | |||||
CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out, | |||||
megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller); | |||||
}; | }; | ||||
//! deprecated; TODO: remove in megbrain 8 | //! deprecated; TODO: remove in megbrain 8 | ||||
@@ -509,6 +509,49 @@ public: | |||||
bool is_const() const { return m_is_const; } | bool is_const() const { return m_is_const; } | ||||
}; | }; | ||||
//! global operator instance for static inference | |||||
template <class Opr> | |||||
class StaticInferOpr { | |||||
intl::UniqPtrWithCN<Opr> m_opr; | |||||
MGB_MUTEX m_mtx; | |||||
public: | |||||
class Lock { | |||||
friend class StaticInferOpr; | |||||
StaticInferOpr* m_owner; | |||||
explicit Lock(StaticInferOpr* owner) : m_owner{owner} { | |||||
#if !__DEPLOY_ON_XP_SP2__ | |||||
m_owner->m_mtx.lock(); | |||||
#endif | |||||
} | |||||
public: | |||||
Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; } | |||||
~Lock() { | |||||
#if !__DEPLOY_ON_XP_SP2__ | |||||
if (m_owner) | |||||
m_owner->m_mtx.unlock(); | |||||
#endif | |||||
} | |||||
Lock& operator=(const Lock&) = delete; | |||||
Lock& operator=(Lock&&) = delete; | |||||
intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; } | |||||
}; | |||||
//! lock and acquire the operator | |||||
Lock lock() { | |||||
Lock ret{this}; | |||||
if (!m_opr) { | |||||
m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu()); | |||||
} | |||||
return ret; | |||||
} | |||||
}; | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||