|
@@ -240,11 +240,30 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
&readers](OperatorNodeBase* opr) { |
|
|
&readers](OperatorNodeBase* opr) { |
|
|
ThinHashSet<OperatorNodeBase*> opr_set; |
|
|
ThinHashSet<OperatorNodeBase*> opr_set; |
|
|
ThinHashSet<OperatorNodeBase*> reader_set; |
|
|
ThinHashSet<OperatorNodeBase*> reader_set; |
|
|
|
|
|
// check typecvt |
|
|
|
|
|
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr); |
|
|
|
|
|
if (typecvt == nullptr) |
|
|
|
|
|
return false; |
|
|
|
|
|
auto in_dtype = typecvt->input(0)->dtype(), |
|
|
|
|
|
out_dtype = typecvt->output(0)->dtype(); |
|
|
|
|
|
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
|
|
(out_dtype.enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
|
|
out_dtype.enumv() == DTypeEnum::Quantized4Asymm); |
|
|
|
|
|
if (!is_s82s4) |
|
|
|
|
|
return false; |
|
|
|
|
|
opr_set.insert(typecvt); |
|
|
|
|
|
|
|
|
// check reshape |
|
|
// check reshape |
|
|
auto reshape = try_cast_as_op<opr::Reshape>(opr); |
|
|
|
|
|
|
|
|
auto reshape = |
|
|
|
|
|
try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr()); |
|
|
if (reshape == nullptr) |
|
|
if (reshape == nullptr) |
|
|
return false; |
|
|
return false; |
|
|
opr_set.insert(opr); |
|
|
|
|
|
|
|
|
opr_set.insert(reshape); |
|
|
|
|
|
for (auto&& i : readers[reshape]) { |
|
|
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
|
|
reader_set.insert(i.first); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// check dimshuffle |
|
|
// check dimshuffle |
|
|
auto shuffle = |
|
|
auto shuffle = |
|
@@ -267,27 +286,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto typecvt = |
|
|
|
|
|
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr()); |
|
|
|
|
|
if (typecvt == nullptr) |
|
|
|
|
|
return false; |
|
|
|
|
|
auto in_dtype = typecvt->input(0)->dtype(), |
|
|
|
|
|
out_dtype = typecvt->output(0)->dtype(); |
|
|
|
|
|
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
|
|
(out_dtype.enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
|
|
out_dtype.enumv() == DTypeEnum::Quantized4Asymm); |
|
|
|
|
|
if (!is_s82s4) |
|
|
|
|
|
return false; |
|
|
|
|
|
opr_set.insert(typecvt); |
|
|
|
|
|
for (auto&& i : readers[typecvt]) { |
|
|
|
|
|
if (i.second & DepType::DEV_VALUE) { |
|
|
|
|
|
reader_set.insert(i.first); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// check conv bias |
|
|
// check conv bias |
|
|
auto conv_bias = |
|
|
auto conv_bias = |
|
|
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr()); |
|
|
|
|
|
|
|
|
try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr()); |
|
|
if (conv_bias == nullptr) |
|
|
if (conv_bias == nullptr) |
|
|
return false; |
|
|
return false; |
|
|
auto inp_dtype = conv_bias->input(0)->dtype(); |
|
|
auto inp_dtype = conv_bias->input(0)->dtype(); |
|
|