
Commit ad103b4

v0.7.0 changes -- no more explicit residual layers
1 parent ff53cac commit ad103b4

2 files changed

Lines changed: 60 additions & 61 deletions


README.md

Lines changed: 4 additions & 1 deletion
@@ -26,7 +26,7 @@ Goals:
* torch- and python-idiomatic
* hackable
* few external dependencies (currently only torch and torchvision)
-* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!!) seconds on an A100, down from ~18.1 seconds originally).
+* ~world-record single-GPU training time (this repo holds the current world record at ~<6.3 (!!!) seconds on an A100, down from ~18.1 seconds originally).
* <2 seconds training time in <2 years (yep!)

This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, for example.
@@ -39,6 +39,9 @@ What we've added:
* dirac initializations on non-depth-transitional layers (information passthrough on init)
* and more!

+What we've removed:
+* explicit residual layers. yep.
+
This code, in comparison to David's original code, is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploration -- almost everywhere in the pipeline is exposed and built to be user-friendly. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support is encouraged. Share this repo with someone you know who might like it!

Feel free to check out my [Patreon](https://www.patreon.com/user/posts?u=83632131) if you like what I'm doing here and want more! Additionally, if you want me to work up to a part-time amount of hours with you, feel free to reach out to me at hire.tysam@gmail.com. I'd love to hear from you.
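The "dirac initializations" bullet and the removal of explicit residual layers are two sides of the same idea: a dirac-initialized convolution is an identity map at init, so information passes straight through without a separate skip branch. A minimal standalone sketch (illustrative only, not code from this repo):

```python
import torch
import torch.nn as nn

# A dirac-initialized conv starts out as an identity map, so it passes its input
# through unchanged -- the "information passthrough on init" mentioned above.
conv = nn.Conv2d(16, 16, kernel_size=3, padding='same', bias=False)
torch.nn.init.dirac_(conv.weight)

x = torch.randn(2, 16, 8, 8)
with torch.no_grad():
    print(torch.allclose(conv(x), x))  # True: behaves like a skip connection at init
```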

main.py

Lines changed: 56 additions & 60 deletions
@@ -43,25 +43,24 @@
default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}

batchsize = 1024
-bias_scaler = 56
-# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75
+bias_scaler = 64
+# To replicate the ~95.79%-accuracy-in-110-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->90, ['ema'] epochs 10->80, cutmix_size 3->10, and cutmix_epochs 6->80
hyp = {
    'opt': {
-        'bias_lr':        1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
-        'non_bias_lr':    1.64 / 512,
-        'bias_decay':     1.08 * 6.45e-4 * batchsize/bias_scaler,
-        'non_bias_decay': 1.08 * 6.45e-4 * batchsize,
+        'bias_lr':        1.525 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
+        'non_bias_lr':    1.525 / 512,
+        'bias_decay':     6.687e-4 * batchsize/bias_scaler,
+        'non_bias_decay': 6.687e-4 * batchsize,
        'scaling_factor': 1./9,
        'percent_start': .23,
-        'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
+        'loss_scale_scaler': 1./32, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
    },
    'net': {
        'whitening': {
            'kernel_size': 2,
            'num_examples': 50000,
        },
-        'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
-        'conv_norm_pow': 2.6,
+        'batch_norm_momentum': .4, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
        'cutmix_size': 3,
        'cutmix_epochs': 6,
        'pad_amount': 2,
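The batch_norm_momentum comment above refers to PyTorch's convention for running statistics, which is inverted relative to the usual EMA definition. A small sketch of the standard torch.nn.BatchNorm2d update (library behavior, not code from main.py):

```python
import torch
import torch.nn as nn

# PyTorch updates running stats as: running <- (1 - momentum) * running + momentum * batch_stat,
# so momentum=.4 here keeps 60% of the old running mean/var each step -- i.e. it corresponds
# to a 0.6 momentum in the more common EMA convention the comment is warning about.
bn = nn.BatchNorm2d(8, momentum=0.4)
x = torch.randn(32, 8, 4, 4)
bn.train()
bn(x)
expected_mean = 0.4 * x.mean(dim=(0, 2, 3))  # running_mean starts at zero
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # True
```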
@@ -162,42 +161,34 @@ def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_mome
# Having an outer class like this does add space and complexity but offers us
# a ton of freedom when it comes to hacking in unique functionality for each layer type
class Conv(nn.Conv2d):
-    def __init__(self, *args, norm=False, **kwargs):
+    def __init__(self, *args, **kwargs):
        kwargs = {**default_conv_kwargs, **kwargs}
        super().__init__(*args, **kwargs)
        self.kwargs = kwargs
-        self.norm = norm
-
-    def forward(self, x):
-        if self.training and self.norm:
-            # TODO: Do/should we always normalize along dimension 1 of the weight vector(s), or the height x width dims too?
-            with torch.no_grad():
-                F.normalize(self.weight.data, p=self.norm)
-        return super().forward(x)

class Linear(nn.Linear):
-    def __init__(self, *args, norm=False, **kwargs):
+    def __init__(self, *args, temperature=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.kwargs = kwargs
-        self.norm = norm
+        self.temperature = temperature

    def forward(self, x):
-        if self.training and self.norm:
-            # TODO: Normalize on dim 1 or dim 0 for this guy?
-            with torch.no_grad():
-                F.normalize(self.weight.data, p=self.norm)
-        return super().forward(x)
+        if self.temperature is not None:
+            weight = self.weight * self.temperature
+        else:
+            weight = self.weight
+        return x @ weight.T

-# can hack any changes to each residual group that you want directly in here
+# can hack any changes to each convolution group that you want directly in here
class ConvGroup(nn.Module):
-    def __init__(self, channels_in, channels_out, norm):
+    def __init__(self, channels_in, channels_out):
        super().__init__()
-        self.channels_in = channels_in
+        self.channels_in = channels_in
        self.channels_out = channels_out

        self.pool1 = nn.MaxPool2d(2)
-        self.conv1 = Conv(channels_in, channels_out, norm=norm)
-        self.conv2 = Conv(channels_out, channels_out, norm=norm)
+        self.conv1 = Conv(channels_in, channels_out)
+        self.conv2 = Conv(channels_out, channels_out)

        self.norm1 = BatchNorm(channels_out)
        self.norm2 = BatchNorm(channels_out)
@@ -210,20 +201,11 @@ def forward(self, x):
        x = self.pool1(x)
        x = self.norm1(x)
        x = self.activ(x)
-        residual = x
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activ(x)
-        x = x + residual # haiku
-        return x

-class TemperatureScaler(nn.Module):
-    def __init__(self, init_val):
-        super().__init__()
-        self.scaler = torch.tensor(init_val)
-
-    def forward(self, x):
-        return x.mul(self.scaler)
+        return x

class FastGlobalMaxPooling(nn.Module):
    def __init__(self):
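With the TemperatureScaler module gone, the output scaling survives as the temperature argument on the Linear class above: for a bias-free linear layer, scaling the weight by a constant is the same as scaling the output. A quick standalone check of that equivalence (illustrative shapes only):

```python
import torch

# Scaling the weight of a bias-free linear layer is equivalent to scaling its output,
# which is why the temperature can be folded into Linear instead of a separate module.
temperature = 1. / 9                 # matches hyp['opt']['scaling_factor']
weight = torch.randn(10, 512)        # (num_classes, features) -- illustrative sizes
x = torch.randn(4, 512)

scaled_after = (x @ weight.T) * temperature       # old: Linear output, then TemperatureScaler
scaled_weight = x @ (weight * temperature).T      # new: temperature baked into the weight
print(torch.allclose(scaled_after, scaled_weight, atol=1e-6))  # True
```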
@@ -275,7 +257,7 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block
            eigenvalue_list.append(eigenvalues)
            eigenvector_list.append(eigenvectors)

-        eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
+        eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
        eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0)
        # i believe the eigenvalues and eigenvectors come out in float32 for this because we implicitly cast it to float32 in the patches function (for numerical stability)
        set_whitening_conv(layer, eigenvalues.to(dtype=layer.weight.dtype), eigenvectors.to(dtype=layer.weight.dtype), freeze=freeze)
@@ -284,7 +266,8 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block

def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True):
    shape = conv_layer.weight.data.shape
-    conv_layer.weight.data[-eigenvectors.shape[0]:, :, :, :] = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors
+    eigenvectors_sliced = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors
+    conv_layer.weight.data = torch.cat((eigenvectors_sliced, -eigenvectors_sliced), dim=0)
    ## We don't want to train this, since this is implicitly whitening over the whole dataset
    ## For more info, see David Page's original blogposts (link in the README.md as of this commit.)
    if freeze:
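One consequence of the change above: concatenating the whitening filters with their negations doubles the whitening conv's effective output channels, which lines up with the first ConvGroup in make_net now taking 2*whiten_conv_depth input channels. A shape-only sketch with made-up sizes:

```python
import torch

# Shape bookkeeping only (made-up sizes): the (filters, -filters) concatenation along dim 0
# doubles the number of whitening filters, so whatever consumes the whitened activations
# sees twice as many channels.
whiten_conv_depth = 12                                            # hypothetical value
eigenvectors_sliced = torch.randn(whiten_conv_depth, 3, 2, 2)     # (filters, in_ch, k, k)
weight = torch.cat((eigenvectors_sliced, -eigenvectors_sliced), dim=0)
print(weight.shape)  # torch.Size([24, 3, 2, 2]) -- 2 * whiten_conv_depth filters
```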
@@ -304,7 +287,7 @@ def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=T
    'num_classes': 10
}

-class SpeedyResNet(nn.Module):
+class SpeedyConvNet(nn.Module):
    def __init__(self, network_dict):
        super().__init__()
        self.net_dict = network_dict # flexible, defined in the make_net function
@@ -314,14 +297,12 @@ def forward(self, x):
        if not self.training:
            x = torch.cat((x, torch.flip(x, (-1,))))
        x = self.net_dict['initial_block']['whiten'](x)
-        x = self.net_dict['initial_block']['project'](x)
        x = self.net_dict['initial_block']['activation'](x)
-        x = self.net_dict['residual1'](x)
-        x = self.net_dict['residual2'](x)
-        x = self.net_dict['residual3'](x)
+        x = self.net_dict['conv_group_1'](x)
+        x = self.net_dict['conv_group_2'](x)
+        x = self.net_dict['conv_group_3'](x)
        x = self.net_dict['pooling'](x)
        x = self.net_dict['linear'](x)
-        x = self.net_dict['temperature'](x)
        if not self.training:
            # Average the predictions from the lr-flipped inputs during eval
            orig, flipped = x.split(x.shape[0]//2, dim=0)
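For context on the unchanged eval branch above: at eval time the batch is concatenated with its left-right mirror, and the two halves of the logits are averaged. The model does this inside forward(); the sketch below just restates the pattern as a standalone helper (hypothetical function name, illustrative only):

```python
import torch

# Hypothetical standalone restatement of the eval-time flip augmentation done inside forward():
# score the batch and its horizontal mirror in one pass, then average the two sets of logits.
def predict_with_flip_averaging(net, x):
    x = torch.cat((x, torch.flip(x, (-1,))))          # batch + left-right flipped batch
    logits = net(x)
    orig, flipped = logits.split(logits.shape[0] // 2, dim=0)
    return (orig + flipped) / 2
```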
@@ -335,18 +316,16 @@ def make_net():
    network_dict = nn.ModuleDict({
        'initial_block': nn.ModuleDict({
            'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0),
-            'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 for this as the power for the norm, each step
            'activation': nn.GELU(),
        }),
-        'residual1': ConvGroup(depths['init'], depths['block1'], hyp['net']['conv_norm_pow']),
-        'residual2': ConvGroup(depths['block1'], depths['block2'], hyp['net']['conv_norm_pow']),
-        'residual3': ConvGroup(depths['block2'], depths['block3'], hyp['net']['conv_norm_pow']),
+        'conv_group_1': ConvGroup(2*whiten_conv_depth, depths['block1']),
+        'conv_group_2': ConvGroup(depths['block1'], depths['block2']),
+        'conv_group_3': ConvGroup(depths['block2'], depths['block3']),
        'pooling': FastGlobalMaxPooling(),
-        'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.),
-        'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
+        'linear': Linear(depths['block3'], depths['num_classes'], bias=False, temperature=hyp['opt']['scaling_factor']),
    })

-    net = SpeedyResNet(network_dict)
+    net = SpeedyConvNet(network_dict)
    net = net.to(hyp['misc']['device'])
    net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
    net.train()
@@ -365,18 +344,35 @@ def make_net():
    ## the index lookup in the dataloader may give you some trouble depending
    ## upon exactly how memory-limited you are

-    ## We initialize the projections layer to return exactly the spatial inputs, this way we start
-    ## at a nice clean place (the whitened image in feature space, directly) and can iterate directly from there.
-    torch.nn.init.dirac_(net.net_dict['initial_block']['project'].weight)

    for layer_name in net.net_dict.keys():
-        if 'residual' in layer_name:
-            ## We do the same for the second layer in each residual block, since this only
+        if 'conv_group' in layer_name:
+            # Create an implicit residual via a dirac-initialized tensor
+            dirac_weights_in = torch.nn.init.dirac_(torch.empty_like(net.net_dict[layer_name].conv1.weight))
+
+            # Add the implicit residual to the already-initialized convolutional transition layer.
+            # One can use more sophisticated initializations, but this one appeared to work best in testing.
+            # What this does is bring up the features from the previous residual block virtually, so not only
+            # do we have residual information flow within each block, we have a nearly direct connection from
+            # the early layers of the network to the loss function.
+            std_pre, mean_pre = torch.std_mean(net.net_dict[layer_name].conv1.weight.data)
+            net.net_dict[layer_name].conv1.weight.data = net.net_dict[layer_name].conv1.weight.data + dirac_weights_in
+            std_post, mean_post = torch.std_mean(net.net_dict[layer_name].conv1.weight.data)
+
+            # Renormalize the weights to match the original initialization statistics
+            net.net_dict[layer_name].conv1.weight.data.sub_(mean_post).div_(std_post).mul_(std_pre).add_(mean_pre)
+
+            ## We do the same for the second layer in each convolution group block, since this only
            ## adds a simple multiplier to the inputs instead of the noise of a randomly-initialized
            ## convolution. This can be easily scaled down by the network, and the weights can more easily
            ## pivot in whichever direction they need to go now.
+            ## The reason that I believe that this works so well is because of a combination of MaxPool2d
+            ## and the nn.GELU function's positive bias encouraging values towards the nearly-linear
+            ## region of the GELU activation function at network initialization. I am not currently
+            ## sure about this, however; it will require some more investigation. For now -- it works! D:
            torch.nn.init.dirac_(net.net_dict[layer_name].conv2.weight)

+
    return net

#############################################
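The block above is the heart of the v0.7.0 change: instead of an explicit x + residual branch, each group's first conv has an identity (dirac) kernel folded into its random init and is then rescaled back to the original weight statistics, while the second conv is dirac-initialized outright. A standalone sketch of the same recipe on a single hypothetical conv layer (illustrative sizes, not repo code):

```python
import torch
import torch.nn as nn

# Standalone restatement of the implicit-residual init applied to conv1 above.
conv = nn.Conv2d(32, 64, kernel_size=3, padding='same', bias=False)

# Identity (dirac) kernel of the same shape as the existing random init
dirac_weights_in = torch.nn.init.dirac_(torch.empty_like(conv.weight))

# Fold the identity into the random init...
std_pre, mean_pre = torch.std_mean(conv.weight.data)
conv.weight.data = conv.weight.data + dirac_weights_in
std_post, mean_post = torch.std_mean(conv.weight.data)

# ...then rescale so the weight's mean/std match the pre-addition statistics,
# keeping the effective initialization scale unchanged.
conv.weight.data.sub_(mean_post).div_(std_post).mul_(std_pre).add_(mean_pre)
```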
