|
|
@@ -269,7 +269,20 @@ void ChannelImpl::dispatch_default_cpu( |
|
|
|
|
|
|
|
uint64_t op_id = Profiler::next_id(); |
|
|
|
|
|
|
|
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); |
|
|
|
if (op->trait()->apply_on_device_tensornd) { |
|
|
|
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); |
|
|
|
} else { |
|
|
|
// proxy to apply_on_physical_tensor |
|
|
|
SmallVector<TensorPtr> input_tensors; |
|
|
|
for (auto&& input_tensornd : input_tensornds) { |
|
|
|
input_tensors.push_back(Tensor::make( |
|
|
|
input_tensornd, HostTensorND::make_proxy(input_tensornd))); |
|
|
|
} |
|
|
|
auto output_tensors = OpDef::apply_on_physical_tensor(*op, input_tensors); |
|
|
|
for (size_t i = 0; i < output_tensors.size(); ++i) { |
|
|
|
output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
SmallVector<TensorInfo*> output_infos; |
|
|
|
output_infos.reserve(output_descs.size()); |
|
|
@@ -357,6 +370,17 @@ SmallVector<Handle> ChannelImpl::apply_op( |
|
|
|
std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) { |
|
|
|
MGB_LOCK_GUARD(m_spin); |
|
|
|
mgb_assert(check_available(), "Channel already closed"); |
|
|
|
auto* input = reinterpret_cast<TensorInfo*>(inputs[0]); |
|
|
|
if (op->same_type<GetVarShape>() && input->desc.layout.ndim) { |
|
|
|
size_t ndim = input->desc.layout.ndim; |
|
|
|
auto& gvs = op->cast_final_safe<GetVarShape>(); |
|
|
|
if (gvs.axis == MEGDNN_MAX_NDIM) { |
|
|
|
HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()}; |
|
|
|
DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu(); |
|
|
|
cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout); |
|
|
|
return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))}; |
|
|
|
} |
|
|
|
} |
|
|
|
return apply_op_impl(std::move(op), inputs); |
|
|
|
} |
|
|
|
|
|
|
@@ -621,6 +645,12 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { |
|
|
|
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), |
|
|
|
ptr->dev_tensor().raw_ptr()); |
|
|
|
// update tensor desc for static infer |
|
|
|
if (dest->desc.layout.ndim) { |
|
|
|
mgb_assert( |
|
|
|
dest->desc.layout.eq_shape(ptr->layout()), |
|
|
|
"shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(), |
|
|
|
ptr->layout().to_string().c_str()); |
|
|
|
} |
|
|
|
dest->desc.layout = ptr->layout(); |
|
|
|
dest->desc.comp_node = ptr->comp_node(); |
|
|
|
dest->memory = ptr->blob()->size(); |
|
|
|