Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Update optimizers for TF2. (#13246)
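1. remove `epsilon` and `decay` (Nadam: `schedule_decay`) from the explicit
constructor signatures of all optimizers; they are still accepted as keyword arguments
2. add `iterations` to the weight lists of RMSprop, Adagrad and Adadelta, and
`m_schedule` to the weight list of Nadam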
  • Loading branch information
tanzhenyu authored and fchollet committed Aug 26, 2019
1 parent 387aea3 commit 5446255
Showing 1 changed file with 74 additions and 66 deletions.
140 changes: 74 additions & 66 deletions keras/optimizers.py
Expand Up @@ -6,6 +6,7 @@

import six
import copy
import numpy as np
from six.moves import zip

from . import backend as K
Expand Down Expand Up @@ -170,20 +171,19 @@ class SGD(Optimizer):
learning_rate: float >= 0. Learning rate.
momentum: float >= 0. Parameter that accelerates SGD
in the relevant direction and dampens oscillations.
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
"""

def __init__(self, learning_rate=0.01, momentum=0., decay=0.,
def __init__(self, learning_rate=0.01, momentum=0.,
nesterov=False, **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
learning_rate = kwargs.pop('lr', learning_rate)
self.initial_decay = kwargs.pop('decay', 0.0)
super(SGD, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.momentum = K.variable(momentum, name='momentum')
self.decay = K.variable(decay, name='decay')
self.initial_decay = decay
self.decay = K.variable(self.initial_decay, name='decay')
self.nesterov = nesterov

@interfaces.legacy_get_updates_support
Expand Down Expand Up @@ -236,27 +236,22 @@ class RMSprop(Optimizer):
# Arguments
learning_rate: float >= 0. Learning rate.
rho: float >= 0.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
# References
- [rmsprop: Divide the gradient by a running average of its recent magnitude
](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
"""

def __init__(self, learning_rate=0.001, rho=0.9, epsilon=None, decay=0.,
**kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.001, rho=0.9, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(RMSprop, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.rho = K.variable(rho, name='rho')
self.decay = K.variable(decay, name='decay')
self.decay = K.variable(self.initial_decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay

@interfaces.legacy_get_updates_support
@K.symbolic
Expand All @@ -266,7 +261,7 @@ def get_updates(self, loss, params):
dtype=K.dtype(p),
name='accumulator_' + str(i))
for (i, p) in enumerate(params)]
self.weights = accumulators
self.weights = [self.iterations] + accumulators
self.updates = [K.update_add(self.iterations, 1)]

lr = self.learning_rate
Expand All @@ -287,6 +282,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(p, new_p))
return self.updates

def set_weights(self, weights):
    """Load RMSprop state, also accepting the Keras 2.2.4 weight layout.

    The current layout is `[iterations] + accumulators`. Keras 2.2.4 did
    not store `iterations`, so when exactly one entry is missing the
    counter is prepended and reset to 0.

    # Arguments
        weights: list (or tuple) of Numpy arrays in either the current
            layout or the legacy layout without `iterations`.
    """
    params = self.weights
    # Copy into a list so tuple input works with the list concatenation
    # below (`[x] + some_tuple` raises TypeError).
    weights = list(weights)
    # Backward compatibility with Keras 2.2.4: that layout lacks the
    # leading `iterations` entry, so restart the counter at 0.
    if len(params) == len(weights) + 1:
        weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'rho': float(K.get_value(self.rho)),
Expand All @@ -309,25 +313,21 @@ class Adagrad(Optimizer):
# Arguments
learning_rate: float >= 0. Initial learning rate.
epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
# References
- [Adaptive Subgradient Methods for Online Learning and Stochastic
Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
"""

def __init__(self, learning_rate=0.01, epsilon=None, decay=0., **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.01, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adagrad, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.decay = K.variable(decay, name='decay')
self.decay = K.variable(self.initial_decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay

@interfaces.legacy_get_updates_support
@K.symbolic
Expand All @@ -336,7 +336,7 @@ def get_updates(self, loss, params):
shapes = [K.int_shape(p) for p in params]
accumulators = [K.zeros(shape, name='accumulator_' + str(i))
for (i, shape) in enumerate(shapes)]
self.weights = accumulators
self.weights = [self.iterations] + accumulators
self.updates = [K.update_add(self.iterations, 1)]

lr = self.learning_rate
Expand All @@ -356,6 +356,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(p, new_p))
return self.updates

def set_weights(self, weights):
    """Load Adagrad state, also accepting the Keras 2.2.4 weight layout.

    The current layout is `[iterations] + accumulators`. Keras 2.2.4 did
    not store `iterations`, so when exactly one entry is missing the
    counter is prepended and reset to 0.

    # Arguments
        weights: list (or tuple) of Numpy arrays in either the current
            layout or the legacy layout without `iterations`.
    """
    params = self.weights
    # Copy into a list so tuple input works with the list concatenation
    # below (`[x] + some_tuple` raises TypeError).
    weights = list(weights)
    # Backward compatibility with Keras 2.2.4: that layout lacks the
    # leading `iterations` entry, so restart the counter at 0.
    if len(params) == len(weights) + 1:
        weights = [np.array(0)] + weights
    super(Adagrad, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'decay': float(K.get_value(self.decay)),
Expand Down Expand Up @@ -383,27 +392,22 @@ class Adadelta(Optimizer):
It is recommended to leave it at the default value.
rho: float >= 0. Adadelta decay factor, corresponding to fraction of
gradient to keep at each time step.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Initial learning rate decay.
# References
- [Adadelta - an adaptive learning rate method](
https://arxiv.org/abs/1212.5701)
"""

def __init__(self, learning_rate=1.0, rho=0.95, epsilon=None, decay=0.,
**kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=1.0, rho=0.95, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adadelta, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.decay = K.variable(decay, name='decay')
self.decay = K.variable(self.initial_decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
if epsilon is None:
epsilon = K.epsilon()
self.rho = rho
self.epsilon = epsilon
self.initial_decay = decay

@interfaces.legacy_get_updates_support
@K.symbolic
Expand All @@ -414,7 +418,7 @@ def get_updates(self, loss, params):
for (i, shape) in enumerate(shapes)]
delta_accumulators = [K.zeros(shape, name='delta_accumulator_' + str(i))
for (i, shape) in enumerate(shapes)]
self.weights = accumulators + delta_accumulators
self.weights = [self.iterations] + accumulators + delta_accumulators
self.updates = [K.update_add(self.iterations, 1)]

lr = self.learning_rate
Expand Down Expand Up @@ -442,6 +446,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(d_a, new_d_a))
return self.updates

def set_weights(self, weights):
    """Load Adadelta state, also accepting the Keras 2.2.4 weight layout.

    The current layout is
    `[iterations] + accumulators + delta_accumulators`. Keras 2.2.4 did
    not store `iterations`, so when exactly one entry is missing the
    counter is prepended and reset to 0.

    # Arguments
        weights: list (or tuple) of Numpy arrays in either the current
            layout or the legacy layout without `iterations`.
    """
    params = self.weights
    # Copy into a list so tuple input works with the list concatenation
    # below (`[x] + some_tuple` raises TypeError).
    weights = list(weights)
    # Backward compatibility with Keras 2.2.4: that layout lacks the
    # leading `iterations` entry, so restart the counter at 0.
    if len(params) == len(weights) + 1:
        weights = [np.array(0)] + weights
    super(Adadelta, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'rho': self.rho,
Expand All @@ -460,8 +473,6 @@ class Adam(Optimizer):
learning_rate: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
amsgrad: boolean. Whether to apply the AMSGrad variant of this
algorithm from the paper "On the Convergence of Adam and
Beyond".
Expand All @@ -474,19 +485,17 @@ class Adam(Optimizer):
"""

def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., amsgrad=False, **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
amsgrad=False, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.decay = K.variable(self.initial_decay, name='decay')
self.amsgrad = amsgrad

@interfaces.legacy_get_updates_support
Expand Down Expand Up @@ -565,28 +574,23 @@ class Adamax(Optimizer):
learning_rate: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
# References
- [Adam - A Method for Stochastic Optimization](
https://arxiv.org/abs/1412.6980v8)
"""

def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adamax, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.decay = K.variable(self.initial_decay, name='decay')

@interfaces.legacy_get_updates_support
@K.symbolic
Expand Down Expand Up @@ -652,29 +656,24 @@ class Nadam(Optimizer):
learning_rate: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
schedule_decay: float, 0 < schedule_decay < 1.
# References
- [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
- [On the importance of initialization and momentum in deep learning](
http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
"""

def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
epsilon=None, schedule_decay=0.004, **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999, **kwargs):
self.schedule_decay = kwargs.pop('schedule_decay', 0.004)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Nadam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.m_schedule = K.variable(1., name='m_schedule')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.schedule_decay = schedule_decay

@interfaces.legacy_get_updates_support
@K.symbolic
Expand All @@ -699,7 +698,7 @@ def get_updates(self, loss, params):
vs = [K.zeros(shape, name='v_' + str(i))
for (i, shape) in enumerate(shapes)]

self.weights = [self.iterations] + ms + vs
self.weights = [self.iterations, self.m_schedule] + ms + vs

for p, g, m, v in zip(params, grads, ms, vs):
# the following equations given in [1]
Expand All @@ -725,6 +724,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(p, new_p))
return self.updates

def set_weights(self, weights):
    """Load Nadam state, also accepting the Keras 2.2.4 weight layout.

    The current layout is `[iterations, m_schedule] + ms + vs`. Keras 2.2.4
    stored `[iterations] + ms + vs` (no `m_schedule`), so when exactly one
    entry is missing, `m_schedule` is inserted right after `iterations`
    with the value 1., the same value it is initialized to.

    # Arguments
        weights: list (or tuple) of Numpy arrays in either the current
            layout or the legacy layout without `m_schedule`.
    """
    params = self.weights
    # Copy into a list so tuple input works with the slicing and list
    # concatenation below (`[x] + some_tuple` raises TypeError).
    weights = list(weights)
    # Backward compatibility with Keras 2.2.4: re-insert `m_schedule`
    # at index 1, keeping the saved `iterations` value at index 0.
    if len(params) == len(weights) + 1:
        weights = [weights[0], np.array(1.)] + weights[1:]
    super(Nadam, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'beta_1': float(K.get_value(self.beta_1)),
Expand Down

0 comments on commit 5446255

Please sign in to comment.