|
|
@@ -169,7 +169,7 @@ class OneflowTensorPadder(Padder): |
|
|
|
else: |
|
|
|
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] |
|
|
|
|
|
|
|
tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device) |
|
|
|
tensor = oneflow.full(max_shape, fill_value=pad_val, dtype=dtype, device=device) |
|
|
|
for i, field in enumerate(batch_field): |
|
|
|
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) |
|
|
|
tensor[slices] = field |
|
|
@@ -221,6 +221,6 @@ def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0): |
|
|
|
:return: |
|
|
|
""" |
|
|
|
shapes = get_shape(batch_field) |
|
|
|
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val) |
|
|
|
tensor = oneflow.full(shapes, dtype=dtype, fill_value=pad_val) |
|
|
|
tensor = fill_tensor(batch_field, tensor, dtype=dtype) |
|
|
|
return tensor |