|
|
@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { |
|
|
|
} |
|
|
|
mgb_assert(tup.size() == 7); |
|
|
|
if (auto* t = try_cast(tup[0].ptr())) { |
|
|
|
m_tensor = t->m_tensor->copy(); |
|
|
|
m_tensor = t->m_tensor; |
|
|
|
// TODO: merge two path in arg parse |
|
|
|
if (!tup[1].is_none()) { |
|
|
|
auto dtype = tup[1].cast<DType>(); |
|
|
|
mgb_assert( |
|
|
|
dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s", |
|
|
|
dtype.name(), m_tensor->dtype().name()); |
|
|
|
} |
|
|
|
if (!tup[2].is_none()) { |
|
|
|
auto device = as_comp_node(tup[2]); |
|
|
|
mgb_assert( |
|
|
|
device == m_tensor->comp_node(), "device mismatch: %s vs %s", |
|
|
|
device.to_string().c_str(), |
|
|
|
m_tensor->comp_node().to_string().c_str()); |
|
|
|
} |
|
|
|
mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True"); |
|
|
|
bool no_cache = tup[4].cast<bool>(); |
|
|
|
if (no_cache) { |
|
|
|
// always copy because it's hard to tell whether this tensor is cached |
|
|
|
m_tensor = m_tensor->copy(); |
|
|
|
} |
|
|
|
// ignore name |
|
|
|
if (!tup[6].is_none()) { |
|
|
|
Format format = tup[6].cast<std::string>(); |
|
|
|
mgb_assert( |
|
|
|
format == m_tensor->format(), "format mismatch: %s vs %s", |
|
|
|
format.to_string().c_str(), m_tensor->format().to_string().c_str()); |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto data = tup[0]; |
|
|
|
DType dtype = tup[1].cast<DType>(); |
|
|
@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) { |
|
|
|
try { |
|
|
|
self.compiled->compile(); |
|
|
|
} catch (const std::exception& e) { |
|
|
|
mgb_log_error(e.what()); |
|
|
|
mgb_log_error("error in trace: %s", e.what()); |
|
|
|
} |
|
|
|
} |
|
|
|
// register transformations |
|
|
|