GitOrigin-RevId: 05577a8bc8
HuaHua404-patch-4
@@ -37,7 +37,6 @@ _ElwMod = builtin.Elemwise.Mode
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
-    args = convert_inputs(*args)
     (result,) = apply(op, *args)
     return result
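Reviewer note: a minimal sketch of what this deletion means at the Python level. `_elemwise_multi_type` now applies the op directly; mixed-dtype operands are promoted by the C++ dtype-promotion rule added later in this diff instead of by the removed `convert_inputs` call. This assumes a MegEngine build with this patch; printed values are expectations, not captured output.

```python
# Sketch only: assumes MegEngine with this patch applied.
import megengine as mge

x = mge.tensor([1, 2, 3], dtype="int32")

# The int32 tensor and the float scalar are promoted by the new C++
# dtype-promotion rule before ElemwiseMultiType runs, rather than by the
# removed Python-side convert_inputs() call.
mask = x < 2.5
print(mask.dtype)    # expected: bool
print(mask.numpy())  # expected: [ True  True False]
```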
@@ -249,22 +248,22 @@ class ArrayMethodMixin(abc.ABC):
     __hash__ = None  # because __eq__ deviates from the Python convention
     __lt__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="lt", dtype="Bool"
+        self, value, mode="lt", dtype="bool"
     )
     __le__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="leq", dtype="Bool"
+        self, value, mode="leq", dtype="bool"
     )
     __gt__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="lt", dtype="Bool"
+        value, self, mode="lt", dtype="bool"
     )
     __ge__ = lambda self, value: _elemwise_multi_type(
-        value, self, mode="leq", dtype="Bool"
+        value, self, mode="leq", dtype="bool"
     )
     __eq__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="eq", dtype="Bool"
+        self, value, mode="eq", dtype="bool"
     )
     __ne__ = lambda self, value: _elemwise_multi_type(
-        self, value, mode="neq", dtype="Bool"
+        self, value, mode="neq", dtype="bool"
     )
     __neg__ = _unary_elwise(_ElwMod.NEGATE)
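Reviewer note: the comparison dunders now request the numpy-style lowercase `"bool"` dtype, and `__gt__`/`__ge__` reuse the `lt`/`leq` modes with the operands swapped (`a > b` is computed as `b < a`). A hedged usage sketch; outputs are expectations:

```python
import megengine as mge

a = mge.tensor([1.0, 2.0, 3.0])
b = mge.tensor([3.0, 2.0, 1.0])

print((a > b).numpy())  # a > b computed as b < a; expected: [False False  True]
print((a == b).dtype)   # expected: bool

# __hash__ is set to None because __eq__ returns a Tensor rather than a
# Python bool, so instances cannot satisfy the usual hash/eq contract.
```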
@@ -52,7 +52,7 @@ def isnan(inp: Tensor) -> Tensor:
     >>> F.isnan(x).numpy()
     array([False,  True, False])
     """
-    return _elemwise_multi_type(inp, mode="isnan", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isnan", dtype="bool")


 def isinf(inp: Tensor) -> Tensor:
@@ -69,7 +69,7 @@ def isinf(inp: Tensor) -> Tensor:
     >>> F.isinf(x).numpy()
     array([False,  True, False])
     """
-    return _elemwise_multi_type(inp, mode="isinf", dtype="Bool")
+    return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
 def sign(inp: Tensor):

@@ -118,7 +118,7 @@ PyObject* py_apply(
             tensors[i] = tw->m_tensor->data();
         } else if (
                 DTypePromoteCfg::convert_input_enabled &&
-                op->same_type<Elemwise>()) {
+                (op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
             tensors[i] = convert_pyinput_to_tensor(i);
         } else {
             PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
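Reviewer note: widening this condition lets `py_apply` convert raw Python operands (scalars, lists, ndarrays) for `ElemwiseMultiType` ops as well; previously only `Elemwise` got that treatment, so a comparison against a non-tensor could fall into the `TypeError` branch. A hedged sketch of the behavior this enables:

```python
import megengine as mge

x = mge.tensor([1, 2, 3], dtype="int32")

# Non-tensor operands lower to ElemwiseMultiType comparisons and are now
# converted via convert_pyinput_to_tensor instead of raising
# "py_apply expects tensor as inputs".
print((x >= [2, 2, 2]).numpy())  # expected: [False  True  True]
print((x != 2).numpy())          # expected: [ True False  True]
```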
@@ -53,6 +53,41 @@ mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) {
     return ret;
 }

+ValueRefList elemwise_multi_type_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& elem_op = op.cast_final_safe<ElemwiseMultiType>();
+    static std::unordered_set<ElemwiseMultiType::Mode> cast_case = {
+            ElemwiseMultiType::Mode::EQ,
+            ElemwiseMultiType::Mode::NEQ,
+            ElemwiseMultiType::Mode::LT,
+            ElemwiseMultiType::Mode::LEQ,
+    };
+
+    if (cast_case.find(elem_op.mode) == cast_case.end()) {
+        return imperative::apply(op, inputs);
+    }
+
+    SmallVector<DType> dtypes(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        dtypes[i] = *(inputs[i].dtype());
+    }
+
+    ValueRefList converted(inputs.size());
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype &&
+            DTypePromoteCfg::convert_input_enabled) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            dtypes[i] = target_dtype;
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& elem_op = op.cast_final_safe<Elemwise>();
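Reviewer note: a Python paraphrase of the new rule may help. Only the four comparison modes trigger promotion; non-quantized inputs that differ from the promoted dtype get a `TypeCvt`, and everything else passes through unchanged. This sketch uses `np.result_type` as a stand-in for `get_promoted_dtype`; the function and flag names below are illustrative, not MegEngine API.

```python
import numpy as np

CAST_MODES = {"eq", "neq", "lt", "leq"}

def elemwise_multi_type_rule(mode, inputs, convert_input_enabled=True):
    # Non-comparison modes (e.g. isnan/isinf) are applied unchanged.
    if mode not in CAST_MODES:
        return inputs
    target = np.result_type(*inputs)  # stand-in for get_promoted_dtype
    out = []
    for t in inputs:
        # The C++ rule also skips quantized dtypes; numpy has none, so
        # that check is omitted in this sketch.
        if t.dtype != target and convert_input_enabled:
            t = t.astype(target)  # the ApplyOp(TypeCvt::make(...)) step
        out.append(t)
    return out

a = np.array([1, 2], dtype=np.int16)
b = np.array([1.5, 2.0], dtype=np.float32)
print([t.dtype for t in elemwise_multi_type_rule("lt", [a, b])])
# expected: [dtype('float32'), dtype('float32')]
```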
@@ -349,6 +384,7 @@ ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
 struct DTypePromoteRuleRegistry {
     DTypePromoteRuleRegistry() {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
+        register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
         register_dtype_promote_rule<Concat>(naive_promote_rule);
         register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);
@@ -16,52 +16,6 @@
 using namespace mgb;
 using namespace opr;

 namespace {

-//! global operator instance for static inference
-template <class Opr>
-class StaticInferOpr {
-    intl::UniqPtrWithCN<Opr> m_opr;
-    MGB_MUTEX m_mtx;
-
-public:
-    class Lock {
-        friend class StaticInferOpr;
-        StaticInferOpr* m_owner;
-
-        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
-#if !__DEPLOY_ON_XP_SP2__
-            m_owner->m_mtx.lock();
-#endif
-        }
-
-    public:
-        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
-
-        ~Lock() {
-#if !__DEPLOY_ON_XP_SP2__
-            if (m_owner)
-                m_owner->m_mtx.unlock();
-#endif
-        }
-
-        Lock& operator=(const Lock&) = delete;
-        Lock& operator=(Lock&&) = delete;
-
-        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
-    };
-
-    //! lock and acquire the operator
-    Lock lock() {
-        Lock ret{this};
-        if (!m_opr) {
-            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
-        }
-        return ret;
-    }
-};
-
 }  // anonymous namespace

 /* ========================= BatchedDTypePromotion ========================= */
 intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars)
         : m_orig_vars{vars} {
@@ -1,6 +1,6 @@
 #include "megbrain/opr/nn_int.h"
 #include "./internal/megdnn_opr_wrapper.inl"
+#include "megbrain/opr/utility.h"
 #include "megdnn/oprs/general.h"

 using namespace mgb;
@@ -18,6 +18,7 @@ ElemwiseMultiType::ElemwiseMultiType(
     for (auto i : inputs) {
         add_input({i});
     }
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVar ElemwiseMultiType::make(
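Reviewer note: `ALLOW_EMPTY_SHAPE`, together with the empty-input early return in `scn_do_execute` in the next hunk, lets these ops accept zero-sized tensors instead of asserting. A hedged sketch of the resulting behavior, assuming a build with this patch:

```python
import megengine as mge
import megengine.functional as F

empty = mge.tensor([], dtype="float32")

# Zero-sized inputs now yield zero-sized outputs rather than an assertion.
print(F.isnan(empty).shape)   # expected: (0,)
print((empty < 1.0).numpy())  # expected: []
```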
@@ -52,8 +53,13 @@ void ElemwiseMultiType::init_output_dtype() {
 void ElemwiseMultiType::scn_do_execute() {
     megdnn::TensorNDArray inp_arr(input().size());
     for (size_t i = 0; i < input().size(); ++i) {
+        if (input()[i]->dev_tensor().empty()) {
+            mgb_assert(output(0)->dev_tensor().empty());
+            return;
+        }
         inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
     }
+    mgb_assert(!output(0)->dev_tensor().empty());
     megdnn_opr()->exec(inp_arr, output(0)->dev_tensor().as_megdnn());
 }
@@ -75,4 +81,120 @@ void ElemwiseMultiType::add_input_layout_constraint() {
 #endif
 }

+ElemwiseMultiType::NodeProp* ElemwiseMultiType::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    for (auto& inp : input()) {
+        ret->add_dep_type_existing_var(inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
+    return ret;
+}
+
+void ElemwiseMultiType::init_output_static_infer_desc() {
+    Super::init_output_static_infer_desc();
+    static StaticInferOpr<megdnn::ElemwiseMultiType> static_infer_opr;
+
+    using namespace cg::static_infer;
+
+    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
+        SmallVector<DeviceTensorND> inp_vals(inp.val.size());
+        for (size_t i = 0; i < inp_vals.size(); ++i)
+            inp_vals[i] = inp.val[i].value();
+        DType out_dt;
+        auto trait = ModeTrait::from_mode(param().mode);
+        if (trait.need_specify_out_dtype) {
+            auto dtype = config().output_dtype();
+            mgb_assert(dtype.valid());
+            out_dt = dtype;
+        } else {
+            DType dtype;
+            trait.check_out(dtype, false);
+            out_dt = dtype;
+        }
+        auto sopr = static_infer_opr.lock();
+        perform(param().mode, out_dt, dest, inp_vals, sopr());
+        return true;
+    };
+
+    DepVal deps(input().size());
+    for (size_t i = 0; i < input().size(); ++i)
+        deps[i] = {input(i), DepType::VALUE};
+    owner_graph()->static_infer_manager().register_value_infer(
+            output(0), {SourceType::DEP, deps, infer_value});
+}
+
+TensorShape ElemwiseMultiType::get_output_var_shape(
+        Mode mode, const TensorShapeArray& input_shapes) {
+    mgb_assert(input_shapes.size() == ModeTrait::from_mode(mode).arity);
+    TensorShape ret;
+    megdnn::Elemwise::deduce_shape(input_shapes, ret);
+    return ret;
+}
+
+void ElemwiseMultiType::call_megdnn_opr_exec(
+        CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+        megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller) {
+    // Elemwise comparisons on QuantizedS32/QuantizedS8 do not depend on the
+    // quantization scale, but MegDNN cannot compute Elemwise on
+    // QuantizedS32/QuantizedS8 directly, so we reinterpret the data as
+    // Int32/Int8 before passing it to MegDNN.
+    if (inp.size() && inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED) {
+        auto inp_dtype = inp[0].layout.dtype;
+        DType compute_dtype;
+        if (inp_dtype.enumv() == DTypeEnum::QuantizedS32) {
+            compute_dtype = dtype::Int32();
+        } else if (inp_dtype.enumv() == DTypeEnum::QuantizedS8) {
+            compute_dtype = dtype::Int8();
+        } else {
+            mgb_throw(
+                    MegBrainError, "Unsupported Quantized Elemwise Mode %s: %d on %s",
+                    inp[0].layout.dtype.name(), int(opr->param().mode),
+                    comp_node.to_string().c_str());
+        }
+
+        megdnn::TensorNDArray run_inp(inp);
+        for (size_t i = 0; i < inp.size(); i++) {
+            run_inp[i].layout.dtype = compute_dtype;
+        }
+        megdnn::TensorND run_out = out;
+        run_out.layout.dtype = compute_dtype;
+        opr->exec(run_inp, run_out);
+        return;
+    }
+
+    opr->exec(inp, out);
+}
+
+void ElemwiseMultiType::perform(
+        Mode mode, DType out_dt, DeviceTensorND& dest,
+        const SmallVector<DeviceTensorND>& inputs,
+        intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr) {
+    megdnn::TensorNDArray dnn_inputs(inputs.size());
+    TensorShapeArray inp_shapes(inputs.size());
+    CompNode out_cn;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& t = inputs[i];
+        if (!i) {
+            out_cn = t.comp_node();
+        } else {
+            mgb_assert(t.comp_node() == out_cn);
+        }
+        if (t.shape().is_empty()) {
+            mgb_assert(dest.empty());
+            return;
+        }
+        inp_shapes[i] = t.shape();
+    }
+    if (!opr) {
+        opr = intl::create_megdnn_opr<megdnn::ElemwiseMultiType>(out_cn);
+    } else {
+        mgb_assert(out_cn == opr.comp_node());
+    }
+    out_cn.activate();
+    for (size_t i = 0; i < inputs.size(); ++i)
+        dnn_inputs[i] = inputs[i].as_megdnn();
+    dest.comp_node(out_cn).dtype(out_dt).resize(get_output_var_shape(mode, inp_shapes));
+    opr->param() = {mode};
+    call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(), nullptr);
+}

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
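Reviewer note: a numpy illustration of the trick in `call_megdnn_opr_exec` above. For order and equality comparisons, quantized values q = round(x / scale) compare the same way as the raw integers, so the int8 payload can be compared directly without dequantizing (this assumes both operands share one scale, as the code's single `compute_dtype` implies):

```python
import numpy as np

scale = 0.05
x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
y = np.array([0.5, 0.0, -0.5], dtype=np.float32)

qx = np.round(x / scale).astype(np.int8)  # quantized payload of x
qy = np.round(y / scale).astype(np.int8)  # quantized payload of y

# Comparing the integer payloads gives the same mask as comparing the
# original floats, which is why the dtype can be reinterpreted as Int8.
assert np.array_equal(qx < qy, x < y)
```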
@@ -26,6 +26,14 @@ public:
             const VarNodeArrayView& inputs, Param param,
             const OperatorNodeConfig& config = {});

+    MGE_WIN_DECLSPEC_FUC static TensorShape get_output_var_shape(
+            Mode mode, const TensorShapeArray& input_shapes);
+
+    MGE_WIN_DECLSPEC_FUC static void perform(
+            Mode mode, DType out_dt, DeviceTensorND& dest,
+            const SmallVector<DeviceTensorND>& inputs,
+            intl::UniqPtrWithCN<megdnn::ElemwiseMultiType>& opr);
+
 private:
     using ModeTrait = megdnn::ElemwiseMultiType::ModeTrait;
@@ -40,6 +48,14 @@ private:
     void record_execute_deps(ExecDependencyArray& deps) override;
     void add_input_layout_constraint() override;
     NodeProp* do_make_node_prop() const override;
+    void init_output_static_infer_desc() override;
+
+    static void call_megdnn_opr_exec(
+            CompNode comp_node, megdnn::TensorNDArray& inp, const megdnn::TensorND& out,
+            megdnn::ElemwiseMultiType* opr, ElemwiseMultiType* caller);
 };

 //! deprecated; TODO: remove in megbrain 8
@@ -509,6 +509,49 @@ public:
     bool is_const() const { return m_is_const; }
 };

+//! global operator instance for static inference
+template <class Opr>
+class StaticInferOpr {
+    intl::UniqPtrWithCN<Opr> m_opr;
+    MGB_MUTEX m_mtx;
+
+public:
+    class Lock {
+        friend class StaticInferOpr;
+        StaticInferOpr* m_owner;
+
+        explicit Lock(StaticInferOpr* owner) : m_owner{owner} {
+#if !__DEPLOY_ON_XP_SP2__
+            m_owner->m_mtx.lock();
+#endif
+        }
+
+    public:
+        Lock(Lock&& rhs) : m_owner{rhs.m_owner} { rhs.m_owner = nullptr; }
+
+        ~Lock() {
+#if !__DEPLOY_ON_XP_SP2__
+            if (m_owner)
+                m_owner->m_mtx.unlock();
+#endif
+        }
+
+        Lock& operator=(const Lock&) = delete;
+        Lock& operator=(Lock&&) = delete;
+
+        intl::UniqPtrWithCN<Opr>& operator()() { return m_owner->m_opr; }
+    };
+
+    //! lock and acquire the operator
+    Lock lock() {
+        Lock ret{this};
+        if (!m_opr) {
+            m_opr = intl::create_megdnn_opr<Opr>(CompNode::default_cpu());
+        }
+        return ret;
+    }
+};
+
 } // namespace opr
 } // namespace mgb
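Reviewer note: this is the same `StaticInferOpr` removed from the anonymous namespace earlier in the diff, now hoisted into a shared header so `ElemwiseMultiType::init_output_static_infer_desc` can use it. A Python paraphrase of the pattern may help: one shared operator instance, created lazily on first use and guarded by a mutex while a caller holds it. The names below are illustrative only.

```python
import threading
from contextlib import contextmanager

class StaticInferOpr:
    def __init__(self, factory):
        self._factory = factory   # stands in for intl::create_megdnn_opr
        self._opr = None          # lazily created shared instance
        self._mtx = threading.Lock()

    @contextmanager
    def lock(self):
        # Mirrors Lock's RAII behavior: acquire the mutex on construction,
        # release it when the holder goes out of scope.
        with self._mtx:
            if self._opr is None:
                self._opr = self._factory()
            yield self._opr

infer_opr = StaticInferOpr(lambda: object())  # hypothetical factory
with infer_opr.lock() as opr:
    pass  # use the shared operator while holding the mutex
```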