|
@@ -16,7 +16,7 @@ private: |
|
|
public: |
|
|
public: |
|
|
OprProxy() = default; |
|
|
OprProxy() = default; |
|
|
OprProxy(int k) : m_k{k} {} |
|
|
OprProxy(int k) : m_k{k} {} |
|
|
void init(TopK*, const TensorLayoutArray&) {} |
|
|
|
|
|
|
|
|
void init(TopK*, const TensorNDArray&) {} |
|
|
|
|
|
|
|
|
void deduce_layout(TopK* opr, TensorLayoutArray& layouts) { |
|
|
void deduce_layout(TopK* opr, TensorLayoutArray& layouts) { |
|
|
if (layouts.size() == 3) { |
|
|
if (layouts.size() == 3) { |
|
|