|
|
@@ -10,8 +10,8 @@ |
|
|
|
* implied. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <numeric> |
|
|
|
#include "megbrain/gopt/reformat_emitter.h" |
|
|
|
#include <numeric> |
|
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
|
|
|
|
|
using namespace mgb; |
|
|
@@ -19,34 +19,7 @@ using namespace gopt; |
|
|
|
using Dimension = megdnn::Dimension; |
|
|
|
using NamedTensorShape = megdnn::NamedTensorShape; |
|
|
|
|
|
|
|
ReshapeEmitter::Operator ReshapeEmitter::emit() const { |
|
|
|
auto pattern = analyze(); |
|
|
|
auto op = [pattern](VarNode* var) { |
|
|
|
auto sym_var = SymbolVar(var); |
|
|
|
auto shp = opr::GetVarShape::make(sym_var); |
|
|
|
auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); }; |
|
|
|
auto sub = [&shp, &cv](int ax) { |
|
|
|
return opr::IndexAt::make(shp, {{0, cv(ax)}}); |
|
|
|
}; |
|
|
|
SymbolVarArray axs; |
|
|
|
for (auto i : pattern) { |
|
|
|
if (std::get<0>(i) >= 0) { |
|
|
|
if (std::get<2>(i)) |
|
|
|
axs.emplace_back(sub(std::get<0>(i)) * std::get<1>(i)); |
|
|
|
else |
|
|
|
axs.emplace_back(sub(std::get<0>(i)) / std::get<1>(i)); |
|
|
|
} else { |
|
|
|
axs.emplace_back(cv(std::get<1>(i))); |
|
|
|
} |
|
|
|
} |
|
|
|
auto tshp = opr::Concat::make(axs, 0); |
|
|
|
auto ovar = opr::Reshape::make(sym_var, tshp); |
|
|
|
return ovar.node(); |
|
|
|
}; |
|
|
|
return op; |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const { |
|
|
|
ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const { |
|
|
|
static constexpr uint32_t UNDETERMINED_EXTENT = |
|
|
|
Dimension::UNDETERMINED_EXTENT; |
|
|
|
ThinHashMap<Dimension::Name, int> name2dominant; |
|
|
@@ -58,7 +31,7 @@ SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<std::tuple<int, int, bool>> pattern(m_dest.ndim); |
|
|
|
Pattern pattern(m_dest.ndim); |
|
|
|
for (size_t i = 0; i < m_dest.ndim; ++i) { |
|
|
|
auto name = m_dest[i].name(); |
|
|
|
if (m_dest[i].extent() == UNDETERMINED_EXTENT) { |
|
|
@@ -74,28 +47,90 @@ SmallVector<std::tuple<int, int, bool>> ReshapeEmitter::analyze() const { |
|
|
|
return pattern; |
|
|
|
} |
|
|
|
|
|
|
|
DimshuffleEmitter::Operator DimshuffleEmitter::emit() const { |
|
|
|
auto pattern = m_pattern; |
|
|
|
auto op = [pattern](VarNode* var) { |
|
|
|
ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker( |
|
|
|
const Pattern& pattern) const { |
|
|
|
auto src = m_src; |
|
|
|
auto checker = [src, pattern](VarNode* var) { |
|
|
|
const auto& shp = var->shape(); |
|
|
|
if (shp.ndim != src.ndim) |
|
|
|
return false; |
|
|
|
bool available = true; |
|
|
|
for (size_t i = 0; i < shp.ndim; ++i) { |
|
|
|
if (src[i].extent() != Dimension::UNDETERMINED_EXTENT) { |
|
|
|
available &= (shp[i] == src[i].extent()); |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto&& i : pattern) { |
|
|
|
int axis, factor; |
|
|
|
bool mul; |
|
|
|
std::tie(axis, factor, mul) = i; |
|
|
|
if (axis >= 0 && !mul) { |
|
|
|
available &= (shp[axis] % factor == 0); |
|
|
|
} |
|
|
|
} |
|
|
|
return available; |
|
|
|
}; |
|
|
|
return checker; |
|
|
|
} |
|
|
|
|
|
|
|
ReshapeEmitter::EmitResult ReshapeEmitter::emit() const { |
|
|
|
auto pattern = mixin_analyze(); |
|
|
|
auto builder = [pattern](VarNode* var) { |
|
|
|
auto sym_var = SymbolVar(var); |
|
|
|
auto shp = opr::GetVarShape::make(sym_var); |
|
|
|
auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); }; |
|
|
|
auto sub = [&shp, &cv](int ax) { |
|
|
|
return opr::IndexAt::make(shp, {{0, cv(ax)}}); |
|
|
|
}; |
|
|
|
SymbolVarArray axs; |
|
|
|
for (auto&& i : pattern) { |
|
|
|
int axis, factor; |
|
|
|
bool mul; |
|
|
|
std::tie(axis, factor, mul) = i; |
|
|
|
if (axis >= 0) { |
|
|
|
if (mul) |
|
|
|
axs.emplace_back(sub(axis) * factor); |
|
|
|
else |
|
|
|
axs.emplace_back(sub(axis) / factor); |
|
|
|
} else { |
|
|
|
axs.emplace_back(cv(factor)); |
|
|
|
} |
|
|
|
} |
|
|
|
auto tshp = opr::Concat::make(axs, 0); |
|
|
|
auto ovar = opr::Reshape::make(sym_var, tshp); |
|
|
|
return ovar.node(); |
|
|
|
}; |
|
|
|
auto checker = mixin_emit_checker(pattern); |
|
|
|
return std::make_tuple(builder, checker); |
|
|
|
} |
|
|
|
|
|
|
|
DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const { |
|
|
|
auto&& pattern = m_pattern; |
|
|
|
auto builder = [pattern](VarNode* var) { |
|
|
|
auto sym_var = SymbolVar(var); |
|
|
|
return opr::Dimshuffle::make(sym_var, pattern).node(); |
|
|
|
}; |
|
|
|
return op; |
|
|
|
auto checker = [pattern](VarNode* var) { |
|
|
|
return var->shape().ndim == pattern.size(); |
|
|
|
}; |
|
|
|
return std::make_tuple(builder, checker); |
|
|
|
} |
|
|
|
|
|
|
|
ReformatEmitter::Operator ReformatEmitter::emit() const { |
|
|
|
ReformatEmitter::EmitResult ReformatEmitter::emit() const { |
|
|
|
auto ops = analyze(); |
|
|
|
auto op = [ops](VarNode* var) { |
|
|
|
auto builder = [ops](VarNode* var) { |
|
|
|
VarNode* ovar = var; |
|
|
|
for (const auto& o : ops) { |
|
|
|
ovar = o(ovar); |
|
|
|
for (const auto& i : ops) { |
|
|
|
ovar = i(ovar); |
|
|
|
} |
|
|
|
return ovar; |
|
|
|
}; |
|
|
|
return op; |
|
|
|
auto pattern = mixin_analyze(); |
|
|
|
auto checker = mixin_emit_checker(pattern); |
|
|
|
return std::make_tuple(builder, checker); |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<ReformatEmitter::Operator> ReformatEmitter::analyze() const { |
|
|
|
SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const { |
|
|
|
struct Dim { |
|
|
|
Dimension dim; |
|
|
|
int index; |
|
|
@@ -161,12 +196,12 @@ SmallVector<ReformatEmitter::Operator> ReformatEmitter::analyze() const { |
|
|
|
i1[i] = src_dims[src_perm[i]].dim; |
|
|
|
i2[i] = src_dims[src_perm[permute[i]]].dim; |
|
|
|
} |
|
|
|
SmallVector<Operator> ops; |
|
|
|
SmallVector<Builder> ops; |
|
|
|
if (!m_src.eq_shape(i1)) |
|
|
|
ops.emplace_back(ReshapeEmitter(m_src, i1).emit()); |
|
|
|
ops.emplace_back(DimshuffleEmitter(permute).emit()); |
|
|
|
ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit())); |
|
|
|
ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit())); |
|
|
|
if (!m_dest.eq_shape(i2)) |
|
|
|
ops.emplace_back(ReshapeEmitter(i2, m_dest).emit()); |
|
|
|
ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit())); |
|
|
|
return ops; |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen |