|
|
@@ -16,39 +16,74 @@ from ..ops.special import Const |
|
|
|
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply |
|
|
|
|
|
|
|
|
|
|
|
def dtype_promotion(raw_inputs): |
|
|
|
def add_dtype(i): |
|
|
|
if type(i) == int: |
|
|
|
return np.array(i, dtype=np.int32) |
|
|
|
if type(i) == float: |
|
|
|
return np.array(i, dtype=np.float32) |
|
|
|
if type(i) == bool: |
|
|
|
return np.array(i, dtype=np.bool_) |
|
|
|
return None |
|
|
|
|
|
|
|
scalar_inputs = [ |
|
|
|
add_dtype(i) for i in raw_inputs if not hasattr(i, "dtype") and add_dtype(i) |
|
|
|
] |
|
|
|
inputs = [i for i in raw_inputs if hasattr(i, "dtype")] |
|
|
|
assert len(scalar_inputs + inputs) > 0 |
|
|
|
dtype = None |
|
|
|
if len(inputs) > 0: |
|
|
|
dtype = np.result_type(*inputs) |
|
|
|
dtype_all = np.result_type(*(inputs + scalar_inputs)) |
|
|
|
assert ( |
|
|
|
dtype != np.float64 and dtype != np.int64 |
|
|
|
), "unsupport dtype {} by dtype_promotion, please use explict type convert".format( |
|
|
|
dtype |
|
|
|
) |
|
|
|
if dtype_all == np.bool_: |
|
|
|
for i in raw_inputs: |
|
|
|
if not hasattr(i, "dtype") or i.dtype != np.bool_: |
|
|
|
raise TypeError( |
|
|
|
"bool dtype can not be operated with an element without bool dtype" |
|
|
|
) |
|
|
|
if dtype_all == np.float64: |
|
|
|
dtype_all = np.float32 |
|
|
|
return dtype_all |
|
|
|
def dtype_promotion(inputs): |
|
|
|
""" |
|
|
|
Returns the dtype that would result from performing an arithmetic |
|
|
|
operation on the provided input tensors and scalars. |
|
|
|
""" |
|
|
|
# map numpy.dtype.kind to priority |
|
|
|
category_priority = { |
|
|
|
"f": 3, # floating-point |
|
|
|
"i": 2, # signed integer |
|
|
|
"u": 2, # unsigned integer |
|
|
|
"b": 1, # boolean |
|
|
|
} |
|
|
|
|
|
|
|
def scalar2dtype(x): |
|
|
|
""" |
|
|
|
For scalar `x`, returns its corresponding type. A floating point scalar |
|
|
|
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'. |
|
|
|
A boolean scalar has dtype 'bool'. |
|
|
|
""" |
|
|
|
if isinstance(x, bool): |
|
|
|
return np.bool_ |
|
|
|
if isinstance(x, int): |
|
|
|
return np.int32 |
|
|
|
if isinstance(x, float): |
|
|
|
return np.float32 |
|
|
|
|
|
|
|
def promote_types(types, cat): |
|
|
|
""" |
|
|
|
Returns the data type with sufficient size to hold all types of |
|
|
|
category `cat` in the list `types`. |
|
|
|
""" |
|
|
|
used_types = [ |
|
|
|
i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat |
|
|
|
] |
|
|
|
assert len(used_types) > 0 |
|
|
|
res = used_types[0] |
|
|
|
for i in used_types: |
|
|
|
res = np.promote_types(res, i) |
|
|
|
return res |
|
|
|
|
|
|
|
def max_priority(types): |
|
|
|
""" |
|
|
|
Returns the maximum value of the priority of each type in the list |
|
|
|
`types`. |
|
|
|
""" |
|
|
|
if not types: |
|
|
|
return 0 |
|
|
|
else: |
|
|
|
return max([category_priority.get(np.dtype(i).kind, 0) for i in types]) |
|
|
|
|
|
|
|
scalars = [] |
|
|
|
tensors = [] |
|
|
|
|
|
|
|
for data in inputs: |
|
|
|
if hasattr(data, "dtype"): |
|
|
|
tensors.append(data.dtype) |
|
|
|
elif isinstance(data, (float, int, bool)): |
|
|
|
scalars.append(scalar2dtype(data)) |
|
|
|
|
|
|
|
max_pri_scalars = max_priority(scalars) |
|
|
|
max_pri_tensors = max_priority(tensors) |
|
|
|
|
|
|
|
assert max_pri_scalars > 0 or max_pri_tensors > 0 |
|
|
|
|
|
|
|
if max_pri_scalars > max_pri_tensors: |
|
|
|
return promote_types(scalars, max_pri_scalars) |
|
|
|
else: |
|
|
|
return promote_types(tensors, max_pri_tensors) |
|
|
|
|
|
|
|
|
|
|
|
def get_device(inputs): |
|
|
|