Utils
Torch_Utils
Miscellaneous PyTorch utilities.
Usage:
>>> from torch_template import torch_utils
>>> torch_utils.func_name()  # call any function in this module
class torch_template.utils.torch_utils.AverageMeters(dic=None, total_num=None)
Average meter: keeps a running average of each named value.
Example:
>>> avg_meters = AverageMeters()
>>> for i in range(100):
...     avg_meters.update({'f': i})
>>> print(str(avg_meters))
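The following minimal pure-Python sketch shows what a running-average meter typically does. It assumes AverageMeters keeps one arithmetic mean per key. `RunningAverages` and its methods are hypothetical names and are not the library's implementation. The roles of `dic` and `total_num` are not documented here, so they are not modelled.

```python
class RunningAverages:
    """Hypothetical stand-in for AverageMeters: one arithmetic mean per key."""

    def __init__(self):
        self.sums = {}    # running sum per key
        self.counts = {}  # number of updates per key

    def update(self, values):
        # values: dict mapping a name (e.g. 'loss') to a number
        for name, v in values.items():
            self.sums[name] = self.sums.get(name, 0.0) + v
            self.counts[name] = self.counts.get(name, 0) + 1

    def __getitem__(self, name):
        # Mean of all values seen so far for this key (KeyError if unseen).
        return self.sums[name] / self.counts[name]

    def __str__(self):
        return ' '.join(f'{k}: {self[k]:.4f}' for k in sorted(self.sums))


meters = RunningAverages()
for i in range(100):
    meters.update({'f': i})
print(meters)  # f: 49.5000
```

Each key gets the same weight for every update, so early values influence the result as much as recent ones. Use the EMA meter when recent values should dominate.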
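class torch_template.utils.torch_utils.ExponentialMovingAverage(decay=0.9, dic=None, total_num=None)
Exponential moving average (EMA) meter: keeps an exponentially weighted average of each named value, with decay as the smoothing factor.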
Example:
>>> ema_meters = ExponentialMovingAverage(0.98)
>>> for i in range(100):
...     ema_meters.update({'f': i})
>>> print(str(ema_meters))
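The sketch below illustrates the standard EMA update, ema ← decay · ema + (1 − decay) · x. `EmaMeters` is a hypothetical name, not the library's class. The sketch also assumes the first observation initialises the average. The real class may instead start from zero or apply a bias correction; the source does not say.

```python
class EmaMeters:
    """Hypothetical stand-in for ExponentialMovingAverage (not the library's code)."""

    def __init__(self, decay=0.9):
        self.decay = decay
        self.values = {}  # current EMA per key

    def update(self, new_values):
        for name, x in new_values.items():
            if name not in self.values:
                # Assumption: the first observation initialises the average.
                self.values[name] = x
            else:
                # ema <- decay * ema + (1 - decay) * x
                self.values[name] = self.decay * self.values[name] + (1 - self.decay) * x

    def __getitem__(self, name):
        return self.values[name]


ema = EmaMeters(0.5)
for x in [0, 4, 8]:
    ema.update({'f': x})
print(ema['f'])  # 5.0
```

A larger decay, such as the 0.98 in the example, gives a smoother average that reacts more slowly to new values.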
class torch_template.utils.torch_utils.LR_Scheduler(mode, base_lr, num_epochs, iters_per_epoch=0, lr_step=0, warmup_epochs=0, logger=None)
Learning rate scheduler.
Example:
>>> scheduler = LR_Scheduler('cos', opt.lr, opt.epochs, len(dataloader), warmup_epochs=20)
>>> for i, data in enumerate(dataloader):
...     scheduler(self.g_optimizer, i, epoch)
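Step mode:
$$\mathrm{lr} = \mathrm{base\_lr} \cdot 0.1^{\lfloor (\mathrm{epoch} - 1) / \mathrm{lr\_step} \rfloor}$$
That is, the learning rate is multiplied by 0.1 every lr_step epochs.
Cosine mode: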
$$\mathrm{lr} = \mathrm{base\_lr} \cdot \frac{1}{2}\left(1 + \cos\left(\pi \cdot \frac{\mathrm{iter}}{\mathrm{maxiter}}\right)\right)$$
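Poly mode:
$$\mathrm{lr} = \mathrm{base\_lr} \cdot \left(1 - \frac{\mathrm{iter}}{\mathrm{maxiter}}\right)^{0.9}$$
In the cosine and poly modes, iter is the current iteration counted from the start of training and maxiter is the total number of training iterations.
Parameters: - iters_per_epoch – number of iterations per epoch.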
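The three schedule formulas can be checked numerically with the pure-Python sketch below. It evaluates each formula directly and is not the LR_Scheduler implementation. The function names are hypothetical, and so is the relation max_iter = num_epochs × iters_per_epoch. Warm-up (warmup_epochs) is not modelled. The cosine function uses the standard cosine-annealing form with π, so the rate falls to 0 at max_iter.

```python
import math


def step_lr(base_lr, epoch, lr_step):
    # Step mode: multiply by 0.1 every lr_step epochs (epochs counted from 1).
    return base_lr * 0.1 ** ((epoch - 1) // lr_step)


def cos_lr(base_lr, it, max_iter):
    # Cosine mode: decays smoothly from base_lr (it = 0) to 0 (it = max_iter).
    return base_lr * 0.5 * (1 + math.cos(math.pi * it / max_iter))


def poly_lr(base_lr, it, max_iter, power=0.9):
    # Poly mode: polynomial decay from base_lr to 0 with exponent 0.9.
    return base_lr * (1 - it / max_iter) ** power


# Hypothetical run: 10 epochs x 100 iterations per epoch.
max_iter = 10 * 100
for epoch in (1, 30, 31, 61):
    print('step', epoch, step_lr(0.1, epoch, lr_step=30))
for it in (0, 250, 500, 1000):
    print('cos ', it, cos_lr(0.1, it, max_iter), 'poly', poly_lr(0.1, it, max_iter))
```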
torch_template.utils.torch_utils.clamp(x, min=0.01, max=0.99)
Clamp a tensor to the range [min, max].
Parameters: - x (torch.Tensor) – input tensor.
- min (float) – values below min are set to min.
- max (float) – values above max are set to max.
Returns: a clamped tensor.
Return type: (torch.Tensor)
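For tensors, this behaviour is what torch.clamp(x, min, max) provides, and the function is presumably equivalent. The pure-Python list version below illustrates the rule with the default bounds; `clamp_values` is a hypothetical name. One common reason for bounds like 0.01/0.99 is to keep probabilities away from 0 and 1 before taking a logarithm, though the source does not state the motivation.

```python
def clamp_values(xs, lo=0.01, hi=0.99):
    # Element-wise clamp: values below lo become lo, values above hi become hi.
    return [min(max(x, lo), hi) for x in xs]


probs = [0.0, 0.5, 1.0]
print(clamp_values(probs))  # [0.01, 0.5, 0.99]
```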
torch_template.utils.torch_utils.create_summary_writer(log_dir)
Create a TensorBoard summary writer.
Parameters: log_dir – log directory.
Returns: a summary writer.
Return type: (SummaryWriter)
Example:
>>> writer = create_summary_writer(os.path.join(self.basedir, 'logs'))
>>> write_meters_loss(writer, 'train', avg_meters, iteration)
>>> write_loss(writer, 'train', 'F1', 0.78, iteration)
>>> write_image(writer, 'train', 'input', img, iteration)
Then start TensorBoard from a shell:
$ tensorboard --logdir {base_path}/logs
torch_template.utils.torch_utils.load_ckpt(model, ckpt_path)
Load a checkpoint.
Parameters: - model (nn.Module) – an instance of an nn.Module subclass.
- ckpt_path (str) – path of the *.pt file to load.
Example:
>>> import torch.nn as nn
>>> class Model(nn.Module):
...     pass
>>> model = Model().cuda()
>>> load_ckpt(model, 'model.pt')
torch_template.utils.torch_utils.print_network(net: nn.Module, print_size=False)
Print the network structure and its number of parameters.
Parameters: - net (nn.Module) – network model.
- print_size (bool) – whether to also print the number of parameters of each layer.
Example:
>>> import torchvision as tv
>>> from torch_template import torch_utils
>>> vgg16 = tv.models.vgg16()
>>> torch_utils.print_network(vgg16)
features.0.weight [3, 64, 3, 3]
features.2.weight [64, 64, 3, 3]
features.5.weight [64, 128, 3, 3]
features.7.weight [128, 128, 3, 3]
features.10.weight [128, 256, 3, 3]
features.12.weight [256, 256, 3, 3]
features.14.weight [256, 256, 3, 3]
features.17.weight [256, 512, 3, 3]
features.19.weight [512, 512, 3, 3]
features.21.weight [512, 512, 3, 3]
features.24.weight [512, 512, 3, 3]
features.26.weight [512, 512, 3, 3]
features.28.weight [512, 512, 3, 3]
classifier.0.weight [25088, 4096]
classifier.3.weight [4096, 4096]
classifier.6.weight [4096, 1000]
Total number of parameters: 138,357,544
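The printed total can be reproduced by hand, since each layer contributes the product of its shape's entries. The weight shapes below are copied from the example output. Summing only those weights gives 138,344,128, which is 13,416 short of the printed total. That difference matches the bias vectors of the standard VGG16 architecture (one bias per output channel or unit). This suggests the total also counts biases, even though only weights are listed. This is an arithmetic check, not print_network's code, and the bias sizes come from the VGG16 architecture rather than from the output above.

```python
from math import prod

# Weight shapes as listed in the print_network example output (VGG16).
weight_shapes = [
    [3, 64, 3, 3], [64, 64, 3, 3],
    [64, 128, 3, 3], [128, 128, 3, 3],
    [128, 256, 3, 3], [256, 256, 3, 3], [256, 256, 3, 3],
    [256, 512, 3, 3], [512, 512, 3, 3], [512, 512, 3, 3],
    [512, 512, 3, 3], [512, 512, 3, 3], [512, 512, 3, 3],
    [25088, 4096], [4096, 4096], [4096, 1000],
]
# One bias per output channel (conv layers) or output unit (linear layers).
bias_sizes = [64, 64, 128, 128, 256, 256, 256,
              512, 512, 512, 512, 512, 512,
              4096, 4096, 1000]

n_weights = sum(prod(s) for s in weight_shapes)
n_biases = sum(bias_sizes)
print(f'{n_weights:,} weights + {n_biases:,} biases = {n_weights + n_biases:,}')
# 138,344,128 weights + 13,416 biases = 138,357,544
```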
torch_template.utils.torch_utils.repeat(x: torch.Tensor, *sizes)
Repeat (tile) a tensor along one or more dimensions.
Parameters: - x (torch.Tensor) – input tensor.
- sizes – number of times to repeat along each dimension.
Returns: a repeated tensor.
Return type: (torch.Tensor)
Example:
>>> t = repeat(t, 1, 3, 1, 1) # same as t = t.repeat(1, 3, 1, 1) or t = torch.cat([t, t, t], dim=1)
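This sketch assumes repeat forwards to torch.Tensor.repeat, as the example's comment indicates. Under that rule, each output dimension is the input dimension times its repeat count. The hypothetical helper `repeated_shape` below computes only that shape rule. It includes PyTorch's behaviour that extra leading repeat counts add new dimensions of size 1 before tiling. The example call turns a one-channel batch into a three-channel one.

```python
def repeated_shape(shape, sizes):
    """Output shape of tiling a tensor of `shape` by `sizes` (torch.Tensor.repeat rule)."""
    if len(sizes) < len(shape):
        raise ValueError('need at least one repeat count per dimension')
    # Missing leading dimensions are treated as size 1.
    padded = [1] * (len(sizes) - len(shape)) + list(shape)
    return tuple(d * r for d, r in zip(padded, sizes))


print(repeated_shape((8, 1, 32, 32), (1, 3, 1, 1)))  # (8, 3, 32, 32)
```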
torch_template.utils.torch_utils.save_ckpt(model, ckpt_path)
Save a checkpoint.
Parameters: - model (nn.Module) – an instance of an nn.Module subclass.
- ckpt_path (str) – path of the *.pt file to save.
Example:
>>> import torch.nn as nn
>>> class Model(nn.Module):
...     pass
>>> model = Model().cuda()
>>> save_ckpt(model, 'model.pt')
torch_template.utils.torch_utils.tensor2im(x: torch.Tensor, norm=False, to_save=False)
Convert a tensor to an image.
Parameters: - x (torch.Tensor) – input tensor of shape [n, c, h, w], dtype float32.
- norm (bool) – whether to denormalize the tensor first.
- to_save (bool) – if False, return a float32 image of shape [h, w, c]; if True, return a uint8 image of shape [h, w, c].
Returns: an image of shape [h, w, c] if to_save else [c, h, w].
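The core of a tensor-to-image conversion (channels-first floats in, channels-last 8-bit values out) can be illustrated without PyTorch. The helper `to_hwc_uint8` below is hypothetical and works on a single [c][h][w] nested list rather than a batch. It assumes float values in [0, 1]. Its `denorm` flag assumes the data was normalised to [-1, 1], which the source does not specify. With PyTorch, the transpose step corresponds to x.permute(1, 2, 0) on a [c, h, w] tensor.

```python
def to_hwc_uint8(img, denorm=False):
    """Hypothetical helper: img is a nested list indexed [c][h][w] of floats.

    Returns a nested list indexed [h][w][c] of ints in 0..255.
    """
    c, h, w = len(img), len(img[0]), len(img[0][0])
    out = []
    for y in range(h):
        row = []
        for x in range(w):
            pixel = []
            for k in range(c):
                v = img[k][y][x]
                if denorm:
                    v = v * 0.5 + 0.5  # assumption: data was normalised to [-1, 1]
                # Scale [0, 1] to [0, 255], round, and clip out-of-range values.
                pixel.append(min(255, max(0, round(v * 255))))
            row.append(pixel)
        out.append(row)
    return out


print(to_hwc_uint8([[[0.0, 1.0]], [[0.2, 0.6]], [[2.0, -1.0]]]))
# [[[0, 51, 255], [255, 153, 0]]]
```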
torch_template.utils.torch_utils.write_graph(writer: SummaryWriter, model, inputs_to_model=None)
Write the network graph into writer.
Parameters: - writer (SummaryWriter) – writer created by create_summary_writer()
- model (nn.Module) – model.
- inputs_to_model (tuple or list) – forward inputs.
Example:
>>> import torch
>>> import torchvision
>>> from tensorboardX import SummaryWriter
>>> input_data = torch.rand(16, 3, 224, 224)
>>> vgg16 = torchvision.models.vgg16()
>>> writer = SummaryWriter(log_dir='logs')
>>> write_graph(writer, vgg16, (input_data,))
torch_template.utils.torch_utils.write_image(writer: SummaryWriter, prefix, image_name: str, img, iteration, dataformats='CHW')
Write an image into writer.
Parameters: - writer (SummaryWriter) – writer created by create_summary_writer()
- prefix (str) – any string, e.g. ‘train’.
- image_name (str) – image name.
- img – image tensor whose layout matches dataformats ([C, H, W] for the default 'CHW').
- iteration (int) – epoch or iteration number at which the image is logged.
- dataformats (str) – ‘CHW’, ‘HWC’ or ‘NCHW’.
Example:
>>> write_image(writer, 'train', 'input', img, iteration)
torch_template.utils.torch_utils.write_loss(writer: SummaryWriter, prefix, loss_name: str, value: float, iteration)
Write a loss value into writer.
Parameters: - writer (SummaryWriter) – writer created by create_summary_writer()
- prefix (str) – any string, e.g. ‘train’.
- loss_name (str) – loss name.
- value (float) – loss value.
- iteration (int) – epoch or iteration number at which the value is logged.
Example:
>>> write_loss(writer, 'train', 'F1', 0.78, iteration)
torch_template.utils.torch_utils.write_meters_loss(writer: SummaryWriter, prefix, avg_meters: torch_template.utils.torch_utils.Meters, iteration)
Write every value tracked by a meters object into writer.
Parameters: - writer (SummaryWriter) – writer created by create_summary_writer()
- prefix (str) – any string, e.g. ‘train’.
- avg_meters (AverageMeters or ExponentialMovingAverage) – meters.
- iteration (int) – epoch or iteration number at which the values are logged.
Example:
>>> writer = create_summary_writer(os.path.join(self.basedir, 'logs'))
>>> ema_meters = ExponentialMovingAverage(0.98)
>>> for i in range(100):
...     ema_meters.update({'f1': i, 'f2': i * 0.5})
...     write_meters_loss(writer, 'train', ema_meters, i)
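With a real writer, each logged scalar typically ends up as a call to SummaryWriter.add_scalar(tag, scalar_value, global_step), which both tensorboardX and torch.utils.tensorboard provide. The sketch below assumes write_meters_loss works this way and that tags are formed as 'prefix/name'; the source confirms neither. `write_scalars` and `RecordingWriter` are hypothetical names. The recording writer lets the logic run without TensorBoard installed, and passing a real SummaryWriter instead should log the same tags.

```python
class RecordingWriter:
    """Hypothetical stand-in for a SummaryWriter that records add_scalar calls."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.calls.append((tag, scalar_value, global_step))


def write_scalars(writer, prefix, values, step):
    # Hypothetical helper: one scalar per meter entry, tagged 'prefix/name' (assumed scheme).
    for name, value in values.items():
        writer.add_scalar(f'{prefix}/{name}', value, step)


writer = RecordingWriter()
write_scalars(writer, 'train', {'f1': 3.0, 'f2': 1.5}, step=7)
print(writer.calls)  # [('train/f1', 3.0, 7), ('train/f2', 1.5, 7)]
```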