|
|
@@ -548,6 +548,19 @@ TEST(TestNormalizeArithChainPass, PowcCExpand2) { |
|
|
|
graph->compile({make_callback_copy(grad, host_g)})); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_PASS(NormalizeArithChainPass, SubDiv) { |
|
|
|
auto x = mkvar("x"), y = mkvar("y"), z = mkvar("z"), |
|
|
|
a0_ = x - y / 2.f, |
|
|
|
a1 = x + (-0.5f) * y, |
|
|
|
b0_ = x - ((y - (z / 5.f)) / 2.f), |
|
|
|
b1 = x + (-0.5f) * y + 0.1f * z; |
|
|
|
|
|
|
|
SymbolVar a0, b0; |
|
|
|
unpack_vector(run_opt({a0_, b0_}), a0, b0); |
|
|
|
EXPECT_EQ(a1, a0); |
|
|
|
EXPECT_EQ(b1, b0); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_PASS(ReorderArithChainPass, 0) { |
|
|
|
auto chk = [this](SymbolVar inp, SymbolVar expect) { |
|
|
|
check(expect, inp, gopt::ConstVarType::IMMUTABLE_AND_PARAM); |
|
|
|