trident
Table Of Contents
trident
Table Of Contents

Source code for trident.optims.pytorch_optimizers

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools as it
import math
import os
import sys
import time
import uuid
from collections import OrderedDict, defaultdict
from functools import partial
from shutil import copyfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.hooks as hooks
from torch.autograd import Variable
from torch.optim.optimizer import Optimizer

from trident.backend.common import get_session, addindent, get_time_suffix, get_class, format_time, get_terminal_size, snake2camel, camel2snake
from trident.backend.pytorch_backend import *
from trident.backend.pytorch_ops import *
from trident.backend.optimizer import OptimizerBase

# Names exported by `from <module> import *`. `get_optimizer` also tests membership in
# this list before deciding whether to convert the requested name with `snake2camel`.
__all__ = ['Adam','SGD','LBFGS','Adadelta','Adagrad','RMSprop','RAdam','PlainRAdam','AdamW','Lookahead','Ranger','get_optimizer']

class Adam(optim.Adam, OptimizerBase):
    """``torch.optim.Adam`` combined with trident's ``OptimizerBase``.

    The class body adds nothing of its own; all behavior comes from the two bases.
    """
class SGD(optim.SGD, OptimizerBase):
    """``torch.optim.SGD`` combined with trident's ``OptimizerBase``.

    The class body adds nothing of its own; all behavior comes from the two bases.
    """
class LBFGS(get_class('LBFGS', ['torch.optim']), OptimizerBase):
    """Class returned by ``get_class('LBFGS', ['torch.optim'])`` combined with ``OptimizerBase``.

    The first base is looked up by name at import time (presumably resolving to
    ``torch.optim.LBFGS`` -- confirm against ``get_class``). The class body adds
    nothing of its own.
    """
class Adadelta(get_class('Adadelta', ['torch.optim']), OptimizerBase):
    """Class returned by ``get_class('Adadelta', ['torch.optim'])`` combined with ``OptimizerBase``.

    The first base is looked up by name at import time (presumably resolving to
    ``torch.optim.Adadelta`` -- confirm against ``get_class``). The class body adds
    nothing of its own.
    """
class Adagrad(get_class('Adagrad', ['torch.optim']), OptimizerBase):
    """Class returned by ``get_class('Adagrad', ['torch.optim'])`` combined with ``OptimizerBase``.

    The first base is looked up by name at import time (presumably resolving to
    ``torch.optim.Adagrad`` -- confirm against ``get_class``). The class body adds
    nothing of its own.
    """
class RMSprop(get_class('RMSprop', ['torch.optim']), OptimizerBase):
    """Class returned by ``get_class('RMSprop', ['torch.optim'])`` combined with ``OptimizerBase``.

    The first base is looked up by name at import time (presumably resolving to
    ``torch.optim.RMSprop`` -- confirm against ``get_class``). The class body adds
    nothing of its own.
    """
class RAdam(Optimizer, OptimizerBase):
    """RAdam optimizer with a small per-group cache of the rectification term.

    Each parameter keeps Adam-style exponential moving averages of its gradient
    (``exp_avg``) and squared gradient (``exp_avg_sq``). On every step the value
    ``N_sma`` is derived from ``beta2`` and the step count:

    * ``N_sma >= 5``: the update is ``exp_avg / (sqrt(exp_avg_sq) + eps)`` scaled by
      the rectified step size.
    * otherwise, if ``degenerated_to_sgd`` is true: the update is the
      bias-corrected ``exp_avg`` alone (no adaptive denominator).
    * otherwise: the parameter is left unchanged for that step.

    ``N_sma`` and the step size depend only on the step count and the betas, so
    they are cached in ``group['buffer']`` (10 slots indexed by ``step % 10``) and
    reused for every parameter that is at the same step.

    Args:
        params: iterable of parameters or list/tuple of parameter-group dicts.
        lr: learning rate; must be ``>= 0``.
        betas: pair of decay rates for ``exp_avg`` and ``exp_avg_sq``; each must
            lie in ``[0, 1)``.
        eps: constant added to the denominator; must be ``>= 0``.
        weight_decay: coefficient of the decay applied directly to the weights
            (scaled by ``lr``) before the gradient update.
        degenerated_to_sgd: whether to fall back to a momentum-only update while
            ``N_sma < 5``.

    Raises:
        ValueError: if ``lr``, ``eps`` or either beta is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd

        # A group that overrides the betas needs its own cache: the cached step size
        # is only valid for the betas it was computed with.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable; when given it is invoked once and its
                return value is returned as the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All arithmetic is done on a float32 copy that is written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword forms replace the deprecated add_(scalar, tensor) overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    # Cache hit: another parameter already computed this step's values.
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        # Negative sentinel: skip the update entirely below.
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)

        return loss
class PlainRAdam(Optimizer, OptimizerBase):
    """RAdam optimizer without the step-size cache.

    Each parameter keeps exponential moving averages of its gradient
    (``exp_avg``) and squared gradient (``exp_avg_sq``). ``N_sma`` and the step
    size are recomputed for every parameter on every step:

    * ``N_sma >= 5``: the update is ``exp_avg / (sqrt(exp_avg_sq) + eps)`` scaled by
      the rectified step size.
    * otherwise, if ``degenerated_to_sgd`` is true: the update is the
      bias-corrected ``exp_avg`` alone.
    * otherwise: the parameter is left unchanged for that step.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate; must be ``>= 0``.
        betas: pair of decay rates for ``exp_avg`` and ``exp_avg_sq``; each must
            lie in ``[0, 1)``.
        eps: constant added to the denominator; must be ``>= 0``.
        weight_decay: coefficient of the decay applied directly to the weights
            (scaled by ``lr``) before the gradient update.
        degenerated_to_sgd: whether to fall back to a momentum-only update while
            ``N_sma < 5``.

    Raises:
        ValueError: if ``lr``, ``eps`` or either beta is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable; when given it is invoked once and its
                return value is returned as the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All arithmetic is done on a float32 copy that is written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword forms replace the deprecated add_(scalar, tensor) overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                            N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
                    p.data.copy_(p_data_fp32)

        return loss
class AdamW(Optimizer, OptimizerBase):
    """Adam with weight decay applied directly to the weights and a linear warmup.

    Each parameter keeps exponential moving averages of its gradient
    (``exp_avg``) and squared gradient (``exp_avg_sq``). While a parameter's
    step count is below ``warmup`` the effective learning rate is
    ``1e-8 + step * lr / warmup``; afterwards it is ``lr``. Weight decay shrinks
    the weights by ``weight_decay * effective_lr`` and is not mixed into the
    gradient.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate; must be ``>= 0``.
        betas: pair of decay rates for ``exp_avg`` and ``exp_avg_sq``; each must
            lie in ``[0, 1)``.
        eps: constant added to the denominator; must be ``>= 0``.
        weight_decay: weight-decay coefficient.
        warmup: number of steps over which the learning rate ramps up.

    Raises:
        ValueError: if ``lr``, ``eps`` or either beta is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup=warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable; when given it is invoked once and its
                return value is returned as the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                # All arithmetic is done on a float32 copy that is written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Keyword forms replace the deprecated add_(scalar, tensor) overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Linear ramp of the learning rate during the warmup steps.
                if group['warmup'] > state['step']:
                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
                else:
                    scheduled_lr = group['lr']

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
class Lookahead(Optimizer, OptimizerBase):
    """Wrap another optimizer and periodically blend its weights with a slow copy.

    Every call to :meth:`step` first steps the wrapped optimizer. On the first
    call, and then on every ``k``-th call after that, each parameter's stored
    ``slow_param`` is moved a fraction ``alpha`` toward the parameter's current
    value and the parameter is overwritten with the result.

    Args:
        optimizer: the wrapped optimizer; its ``param_groups`` list is shared
            with this object.
        k: number of ``step`` calls between two blends.
        alpha: interpolation factor used for the blend.
    """

    def __init__(self, optimizer, k=5, alpha=0.5):
        # NOTE(review): Optimizer.__init__ is never called, so attributes that the
        # base class normally sets up (e.g. ``defaults``) do not exist on this
        # object -- confirm that inherited methods relying on them are not used.
        self.k = k
        self.alpha = alpha
        self.optimizer = optimizer
        self.state = defaultdict(dict)
        self.fast_state = optimizer.state
        self.param_groups = optimizer.param_groups
        # Per-group call counter that decides when the next blend happens.
        for grp in self.param_groups:
            grp["counter"] = 0

    def update(self, group):
        """Blend every parameter of ``group`` with its slow copy, in place."""
        for weights in group["params"]:
            slot = self.state[weights]
            if "slow_param" not in slot:
                # First visit: the slow copy starts out equal to the current weights.
                initial = torch.zeros_like(weights.data)
                initial.copy_(weights.data)
                slot["slow_param"] = initial
            slow_weights = slot["slow_param"]
            slow_weights += (weights.data - slow_weights) * self.alpha
            weights.data.copy_(slow_weights)

    def update_lookahead(self):
        """Run :meth:`update` on every parameter group, regardless of the counters."""
        for grp in self.param_groups:
            self.update(grp)

    def step(self, closure=None):
        """Step the wrapped optimizer, then blend the groups whose counter is zero.

        Returns:
            Whatever the wrapped optimizer's ``step(closure)`` returns.
        """
        result = self.optimizer.step(closure)
        for grp in self.param_groups:
            if grp["counter"] == 0:
                self.update(grp)
            advanced = grp["counter"] + 1
            # Wrap around once ``k`` calls have been counted.
            grp["counter"] = 0 if advanced >= self.k else advanced
        return result

    def state_dict(self):
        """Return a dict with keys ``fast_state``, ``slow_state`` and ``param_groups``.

        ``fast_state`` and ``param_groups`` come from the wrapped optimizer's
        ``state_dict()``. ``slow_state`` is this object's own state, with tensor
        keys replaced by their ``id()``.
        """
        inner = self.optimizer.state_dict()
        # NOTE(review): ``id()`` values are process-specific -- confirm that slow
        # state saved this way can be matched to parameters again on reload.
        slow_state = {}
        for key, value in self.state.items():
            slow_state[id(key) if isinstance(key, torch.Tensor) else key] = value
        return {
            "fast_state": inner["state"],
            "slow_state": slow_state,
            "param_groups": inner["param_groups"],
        }

    def load_state_dict(self, state_dict):
        """Restore the slow state into this object and the fast state into the wrapped optimizer."""
        groups = state_dict["param_groups"]
        slow_part = {"state": state_dict["slow_state"], "param_groups": groups}
        fast_part = {"state": state_dict["fast_state"], "param_groups": groups}
        super(Lookahead, self).load_state_dict(slow_part)
        self.optimizer.load_state_dict(fast_part)
        # The wrapped optimizer may have rebuilt its state object; re-alias it.
        self.fast_state = self.optimizer.state

    def add_param_group(self, param_group):
        """Initialise the group's counter and register the group with the wrapped optimizer."""
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)
class Ranger(Optimizer, OptimizerBase):
    """RAdam-style update with an integrated lookahead step.

    Source: https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer/blob/master/ranger/ranger.py

    Each parameter keeps exponential moving averages of its gradient
    (``exp_avg``) and squared gradient (``exp_avg_sq``), plus a ``slow_buffer``
    copy of the weights. Every step applies a rectified-Adam update (or a
    momentum-only update while ``N_sma`` is not above ``N_sma_threshhold``).
    Whenever a parameter's step count is a multiple of the group's ``k``, the
    slow buffer is moved a fraction ``alpha`` toward the weights and the weights
    are overwritten with it.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate; must be ``> 0``.
        alpha: lookahead interpolation factor; must lie in ``[0, 1]``.
        k: number of steps between lookahead updates; must be ``>= 1``.
        N_sma_threshhold: threshold ``N_sma`` must exceed before the adaptive
            denominator is used.
        betas: pair of decay rates for ``exp_avg`` and ``exp_avg_sq``.
        eps: constant added to the denominator; must be ``> 0``.
        weight_decay: coefficient of the decay applied directly to the weights
            (scaled by ``lr``).

    Raises:
        ValueError: if ``alpha``, ``k``, ``lr`` or ``eps`` is out of range.
    """

    def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5,
                 weight_decay=0):
        # Parameter checks. The messages are formatted explicitly; the original
        # strings contained un-interpolated "{alpha}"-style placeholders.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError('Invalid slow update rate: {}'.format(alpha))
        if not 1 <= k:
            raise ValueError('Invalid lookahead steps: {}'.format(k))
        if not lr > 0:
            raise ValueError('Invalid Learning Rate: {}'.format(lr))
        if not eps > 0:
            raise ValueError('Invalid eps: {}'.format(eps))

        # Parameter comments:
        # beta1 (momentum) of .95 seems to work better than .90...
        # N_sma_threshold of 5 seems better in testing than 4.
        # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.

        # Prep defaults and init torch.optim base.
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
                        eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # Adjustable threshold.
        self.N_sma_threshhold = N_sma_threshhold

        # Lookahead params.
        self.alpha = alpha
        self.k = k

        # Cache of (step, N_sma, step_size) triples, indexed by step % 10.
        # NOTE(review): this cache is shared by all parameter groups, while betas
        # are read per group -- confirm groups never use different betas.
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # Lookahead weights live in the per-parameter state (``slow_buffer``), so
        # they are saved and restored together with the rest of the optimizer state.

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: accepted for interface compatibility but never invoked.
                The upstream author disabled the call because some callers pass
                the loss as a float and not as a callable.

        Returns:
            Always ``None``.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None

        # Evaluate averages and grad, update param tensors.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')

                # All arithmetic is done on a float32 copy that is written back below.
                p_data_fp32 = p.data.float()

                state = self.state[p]  # state dict for this param

                if len(state) == 0:
                    # First time this parameter is seen: create its state entries.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # Lookahead weight storage, kept in the state dict.
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                # Begin computations.
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword forms replace the deprecated add_(scalar, tensor) overloads.
                # Variance moving average.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Mean moving average.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1

                buffered = self.radam_buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    # Cache hit: another parameter already computed this step's values.
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

                # Integrated lookahead, done at the param level instead of group level.
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']  # slow param tensor
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)  # (fast weights - slow weights) * alpha
                    p.data.copy_(slow_p)  # copy interpolated weights back to the param tensor

        return loss
def get_optimizer(optimizer_name):
    """Look up an optimizer class by name.

    The name is searched in ``trident.optims.pytorch_optimizers`` and
    ``torch.optim`` via ``get_class``. A name listed in ``__all__`` is looked up
    as given. Any other name is first converted with ``snake2camel``; if that
    lookup raises, the name is looked up again unconverted.

    Args:
        optimizer_name: name of the optimizer, or ``None``.

    Returns:
        Whatever ``get_class`` returns for the name, or ``None`` when
        ``optimizer_name`` is ``None``.
    """
    if optimizer_name is None:
        return None

    search_path = ['trident.optims.pytorch_optimizers', 'torch.optim']
    if optimizer_name in __all__:
        return get_class(optimizer_name, search_path)

    try:
        resolved = get_class(snake2camel(optimizer_name), search_path)
    except Exception :
        # Converted name not found: retry with the name exactly as given.
        resolved = get_class(optimizer_name, search_path)
    return resolved