|
|
@@ -64,7 +64,7 @@ class TorchDataLoader(DataLoader): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataset, batch_size: int = 16, |
|
|
|
shuffle: bool = False, sampler = None, batch_sampler = None, |
|
|
|
shuffle: bool = False, sampler=None, batch_sampler=None, |
|
|
|
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', |
|
|
|
pin_memory: bool = False, drop_last: bool = False, |
|
|
|
timeout: float = 0, worker_init_fn: Optional[Callable] = None, |
|
|
|