Run a model that combines OpenCL with ExternCOprRunner, for example the
graph: part_a(OpenCL) --> part_b(ExternCOprRunner) --> part_c(OpenCL)
GitOrigin-RevId: f754b559a2
HuaHua404-patch-4
@@ -381,7 +381,6 @@ void ExternCOprRunner::check_param() { | |||||
void ExternCOprRunner::scn_do_execute() { | void ExternCOprRunner::scn_do_execute() { | ||||
SmallVector<MGBTensor> c_inp(input().size()), c_out(output().size()); | SmallVector<MGBTensor> c_inp(input().size()), c_out(output().size()); | ||||
SmallVector<HostTensorND> cpu_inp, cpu_out; | |||||
check_param(); | check_param(); | ||||
bool need_copy = false; | bool need_copy = false; | ||||
@@ -16,6 +16,9 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
//! store dynamic store param | //! store dynamic store param | ||||
std::shared_ptr<ExternCOprParam> m_param; | std::shared_ptr<ExternCOprParam> m_param; | ||||
//! HostTensorND holder for scn_do_execute | |||||
SmallVector<HostTensorND> cpu_inp, cpu_out; | |||||
void get_output_var_shape( | void get_output_var_shape( | ||||
const TensorShapeArray& inp_shape, | const TensorShapeArray& inp_shape, | ||||
TensorShapeArray& out_shape) const override; | TensorShapeArray& out_shape) const override; | ||||
@@ -445,6 +445,22 @@ TEST(TestExternCOpr, GPUCompute) { | |||||
run_compute_test(CompNode::load("gpux"), MGB_DTYPE_FLOAT32); | run_compute_test(CompNode::load("gpux"), MGB_DTYPE_FLOAT32); | ||||
} | } | ||||
#if MGB_OPENCL | |||||
#include "megcore_opencl.h" | |||||
#define REQUIRE_OPENCL() \ | |||||
do { \ | |||||
if (!CompNode::get_device_count(CompNode::DeviceType::OPENCL)) { \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | |||||
TEST(TestExternCOpr, OPENCLCompute) { | |||||
REQUIRE_OPENCL(); | |||||
run_compute_test(CompNode::load("openclx"), MGB_DTYPE_FLOAT32); | |||||
} | |||||
#endif | |||||
TEST(TestExternCOpr, CPUComputeMultiDtype) { | TEST(TestExternCOpr, CPUComputeMultiDtype) { | ||||
run_compute_test(CompNode::load("cpux"), MGB_DTYPE_INT32); | run_compute_test(CompNode::load("cpux"), MGB_DTYPE_INT32); | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||