|
@@ -134,7 +134,7 @@ def full(shape, value, dtype="float32", device=None) -> Tensor: |
|
|
if device is None: |
|
|
if device is None: |
|
|
device = get_default_device() |
|
|
device = get_default_device() |
|
|
(x,) = Const(value, dtype=dtype, device=device)() |
|
|
(x,) = Const(value, dtype=dtype, device=device)() |
|
|
if shape is (): # scalar.shape |
|
|
|
|
|
|
|
|
if type(shape) in (list, tuple) and len(shape) == 0: |
|
|
return x |
|
|
return x |
|
|
return broadcast_to(x, shape) |
|
|
return broadcast_to(x, shape) |
|
|
|
|
|
|
|
@@ -278,7 +278,7 @@ def full_like( |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) |
|
|
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) |
|
|
if inp.shape is (): |
|
|
|
|
|
|
|
|
if inp.ndim == 0: |
|
|
return x |
|
|
return x |
|
|
return broadcast_to(x, inp.shape) |
|
|
return broadcast_to(x, inp.shape) |
|
|
|
|
|
|
|
|