1st's Studio.

超分代码详解

字数统计: 871 | 阅读时长: 4 min
2022/04/24

近段时间在GitHub上clone了很多超分代码,绞尽脑汁也没能把推理跑通,于是就想挑一份代码详细研究一下它的架构。话不多说,开干!

整体架构

超分代码大多存放在src文件夹中,首先来看看代码的大致结构:

可以看出,代码由main函数部分以及三个子文件夹组成:

  • main函数部分
  • data文件夹
  • loss文件夹
  • model文件夹

main函数部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer

# Seed torch's RNG so that repeated runs start from the same random state.
torch.manual_seed(args.seed)
# Checkpoint object built from the parsed command-line options.
checkpoint = utility.checkpoint(args)


def main():
    """Entry point: run the video tester, or the train/test loop.

    When ``args.data_test == ['video']`` the model is evaluated with
    ``VideoTester``. Otherwise, if ``checkpoint.ok`` is truthy, a ``Trainer``
    alternates ``train()`` and ``test()`` until ``terminate()`` returns True,
    and ``checkpoint.done()`` is called afterwards.
    """
    if args.data_test == ['video']:
        # Imported here because VideoTester is only used in this branch.
        from videotester import VideoTester
        # Bind the instance to a local name. The previous ``global model`` /
        # ``model = model.Model(...)`` rebound the imported ``model`` package
        # to a Model instance at module scope.
        _model = model.Model(args, checkpoint)
        t = VideoTester(args, _model, checkpoint)
        t.test()
    elif checkpoint.ok:
        loader = data.Data(args)
        _model = model.Model(args, checkpoint)
        # A loss object is only constructed when not in test-only mode.
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, _model, _loss, checkpoint)
        while not t.terminate():
            t.train()
            t.test()

        checkpoint.done()


if __name__ == '__main__':
    main()

下面我们来仔细剖析这短短的几行代码:

首先设置了随机数种子。固定随机数种子的好处在于:每次从头开始训练网络时,网络的初始权重都是一致的,便于复现实验结果。随后创建了一个checkpoint对象。

1
2
torch.manual_seed(args.seed)  # fixed seed -> repeatable random state
checkpoint = utility.checkpoint(args)  # checkpoint object built from the parsed options

接下来就是main函数部分,结构比较经典:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def main():
    """Entry point: run the video tester, or the train/test loop.

    ``args.data_test == ['video']`` selects ``VideoTester``. Otherwise, when
    ``checkpoint.ok`` is truthy, a ``Trainer`` alternates ``train()`` and
    ``test()`` until ``terminate()`` returns True, after which
    ``checkpoint.done()`` is called.
    """
    if args.data_test == ['video']:  # video test data -> evaluate with VideoTester
        from videotester import VideoTester
        # Use a local name instead of rebinding the imported ``model`` package
        # through ``global model``.
        _model = model.Model(args, checkpoint)
        t = VideoTester(args, _model, checkpoint)
        t.test()
    elif checkpoint.ok:  # otherwise: training / testing
        loader = data.Data(args)  # data loading
        _model = model.Model(args, checkpoint)  # build the model
        # The loss is only built when args.test_only is falsy.
        _loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, _model, _loss, checkpoint)  # build the Trainer
        while not t.terminate():
            t.train()  # training
            t.test()  # evaluation

        # Finalize the checkpoint; what done() actually releases is defined
        # in utility.checkpoint (not shown here).
        checkpoint.done()

Dataloader.py 加载数据部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155

import random
import sys
import threading

import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog

from torch._six import queue

def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    """Worker-process loop used by ``_MSDataLoaderIter`` (multi-scale loading).

    Repeatedly takes ``(idx, batch_indices)`` requests from ``index_queue``.
    For each request it:

    - optionally switches the dataset to a randomly chosen scale (training
      datasets with more than one scale only);
    - collates ``dataset[i]`` for every index;
    - appends the chosen scale index to the collated batch;
    - puts ``(idx, batch)`` on ``data_queue``.

    If loading raises, ``(idx, ExceptionWrapper(sys.exc_info()))`` is put
    instead.

    The loop returns when a ``None`` request arrives (asserted to happen only
    after ``done_event`` is set) or when ``ManagerWatchdog`` stops reporting
    alive. A ``KeyboardInterrupt`` ends the worker silently.

    Args:
        dataset: object providing ``train``, ``set_scale(idx)`` and indexing.
        index_queue: queue of ``(idx, batch_indices)`` requests or ``None``.
        data_queue: queue receiving ``(idx, batch)`` results.
        done_event: event that, once set, makes the loop skip new requests.
        collate_fn: merges a list of samples; its result must support
            ``append`` because the scale index is appended to it.
        scale: sequence of scales; only its length is used here.
        seed: per-worker seed, either an ``int`` or a 0-d integer tensor.
        init_fn: optional worker-init callable, called with ``worker_id``.
        worker_id: index of this worker.
    """
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        # random.seed() accepts only None/int/float/str/bytes/bytearray on
        # Python >= 3.11. Older versions seed from hash(obj), which for a
        # tensor is identity-based and therefore not reproducible. Normalize
        # a 0-d tensor seed to a plain int before seeding either RNG.
        seed = int(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        # multiprocessing.Queue.cancel_join_thread: do not let this queue's
        # feeder thread block the worker from exiting.
        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # Nothing queued yet; loop to re-check the watchdog.
                continue

            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Shutting down: discard remaining requests without loading.
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    # Multi-scale training: one random scale for the whole batch.
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                # Record which scale this batch was built with.
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        pass

class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator for ``MSDataLoader`` whose worker processes run ``_ms_loop``.

    Only ``__init__`` is overridden. Everything else (e.g. ``_put_indices``
    used below) is inherited from torch's ``_DataLoaderIter``. Note that this
    relies on private torch internals.
    """

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        # Draw the base worker seed once, as a plain Python int. It used to be
        # redrawn below as a 0-d tensor (``random_()[0]``), which the workers'
        # random.seed() rejects on Python >= 3.11 and hashes by object id on
        # older versions.
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            # One index queue and one daemon process per worker; every worker
            # writes into the shared worker_result_queue.
            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # A background thread moves worker results into data_queue
                # via torch's _pin_memory_loop.
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime every worker with two outstanding batch requests.
            for _ in range(2 * self.num_workers):
                self._put_indices()


class MSDataLoader(DataLoader):
    """``DataLoader`` whose iterator is the multi-scale ``_MSDataLoaderIter``.

    The worker count comes from ``cfg.n_threads`` and the scale list from
    ``cfg.scale``. All remaining positional and keyword arguments are
    forwarded unchanged to ``torch.utils.data.DataLoader``. Passing
    ``num_workers`` explicitly as well raises ``TypeError``.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, num_workers=cfg.n_threads, **kwargs)
        self.scale = cfg.scale

    def __iter__(self):
        iterator = _MSDataLoaderIter(self)
        return iterator
CATALOG
  1. 1. 整体架构
  2. 2. main函数部分
  3. 3. Dataloader.py 加载数据部分