@@ -337,6 +337,24 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
SmallVector<Cut> cuts;
size_t cur = 0;
/* \note In the layout selection problem, different operator layout configurations
 * may produce tensors with the same layout (i.e. the same state in the DP problem).
 * Among identical states produced by different layout configs, only the one with
 * the smallest time should be kept. The helper function below is used to
 * deduplicate the states accordingly.
 */
auto add_state = [](StateTable& states, const State& state, const Value& value) {
auto iter = states.find(state);
if (iter == states.end()) {
states[state] = value;
} else {
float time = iter->second.time;
if (value.time < time) {
iter->second = value;
}
}
};
/// initialize states
auto init = [&, this](OperatorNodeBase* opr) {
auto it = rst.opr_record.find(opr);
@@ -388,7 +406,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
ivar_time += min_time;
}
value.time = opr_time + ivar_time + ovar_time;
states[state] = value ;
add_state(states, state, value) ;
}
};
@@ -455,15 +473,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
ivar_time += min_time;
}
value.time = prev_time + opr_time + ivar_time + ovar_time;
auto iter = states.find(state);
if (iter == states.end()) {
states[state] = value;
} else {
float time = iter->second.time;
if (value.time < time) {
iter->second = value;
}
}
add_state(states, state, value);
}
}
cuts.emplace_back(Cut{});
@@ -481,6 +491,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
auto& states = cuts.back().states;
prune(states, edges[cur], ctx);
force_prune(states);
}
cur++;
}