|
|
@@ -416,4 +416,19 @@ TEST(TestOprMisc, TopKSortedIdxOnly) { |
|
|
|
MGB_ASSERT_TENSOR_EQ(host_gx, *host_y); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprMisc, TopKGrad) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
std::shared_ptr<HostTensorND> host_x = gen({2, 5}); |
|
|
|
std::shared_ptr<HostTensorND> host_k = gen({1}); |
|
|
|
host_k->ptr<float>()[0] = 3; |
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, host_x), |
|
|
|
k = opr::Host2DeviceCopy::make(*graph, host_k), |
|
|
|
ki = opr::TypeCvt::make(k, dtype::Int32{}), |
|
|
|
val = opr::TopK::make(x, ki, |
|
|
|
opr::TopK::Param::Mode::VALUE_IDX_SORTED)[0], |
|
|
|
gk = cg::grad(opr::reduce_sum(val, val.make_scalar(1)), ki, true, false); |
|
|
|
EXPECT_TRUE(gk == nullptr); |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |