import random
import sys
import threading

import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog

from torch._six import queue

def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale,
             seed, init_fn, worker_id):
    # Worker loop: mirrors torch.utils.data._utils.worker._worker_loop, but
    # additionally picks a random scale index per batch and returns it
    # alongside the collated samples.
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                # The main process is shutting this worker down.
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Shutdown has begun; drain any remaining index requests.
                continue

            idx, batch_indices = r
            try:
                # Pick a random scale for this batch during training.
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        pass

class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator that runs ``_ms_loop`` in its worker processes, so every
    fetched batch also carries the scale index it was sampled at."""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()
            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # Copy fetched batches into pinned (page-locked) memory on a
                # background thread for faster host-to-GPU transfers.
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the pipeline with two batches of indices per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()

class MSDataLoader(DataLoader):
    """DataLoader for multi-scale training: each batch is drawn at a randomly
    chosen scale and returned with that scale index appended (see _ms_loop)."""

    def __init__(self, cfg, *args, **kwargs):
        super(MSDataLoader, self).__init__(
            *args, **kwargs, num_workers=cfg.n_threads
        )
        self.scale = cfg.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
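
# --- Usage sketch (not part of the original module) ---
# A minimal, illustrative example of how MSDataLoader is meant to be driven.
# The `ToyMSDataset` class and the SimpleNamespace `cfg` below are
# hypothetical stand-ins, not part of this codebase: the loader only assumes
# that `cfg` exposes `n_threads` and `scale`, and that the dataset exposes a
# `train` flag and a `set_scale(idx_scale)` method, as used by `_ms_loop`
# above. It also assumes a legacy PyTorch version (~1.0/1.1) that still
# provides torch.utils.data.dataloader._DataLoaderIter and torch._six.
if __name__ == '__main__':
    from types import SimpleNamespace

    class ToyMSDataset(torch.utils.data.Dataset):
        # Hypothetical dataset: returns a random LR/HR pair whose HR size
        # depends on the currently selected scale.
        def __init__(self, scales, train=True):
            self.scales = scales
            self.train = train
            self.idx_scale = 0

        def set_scale(self, idx_scale):
            self.idx_scale = idx_scale

        def __len__(self):
            return 16

        def __getitem__(self, idx):
            s = self.scales[self.idx_scale]
            lr = torch.randn(3, 24, 24)
            hr = torch.randn(3, 24 * s, 24 * s)
            return lr, hr

    cfg = SimpleNamespace(n_threads=2, scale=[2, 3, 4])
    loader = MSDataLoader(
        cfg, ToyMSDataset(cfg.scale), batch_size=4, shuffle=True
    )

    # Each batch arrives as [lr_batch, hr_batch, idx_scale] because the
    # worker appends the chosen scale index to the collated samples.
    for lr, hr, idx_scale in loader:
        print(lr.shape, hr.shape, idx_scale)
        break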