Signum optimizer #9220

Merged
merged 6 commits into apache:master on Jan 12, 2018

Conversation

4 participants
Contributor

yuxiangw commented Dec 28, 2017

Description

Added the C++ implementation of the Signum optimizer.

Bernstein, Wang, Azizzadenesheli and Anandkumar (2017) "The Signum optimiser: a theory of momentum in quantised stochastic optimisation"
Link to pdf

Also included is an option, 'wd_lh', that implements the alternative (decoupled) weight decay regularization due to Loshchilov and Hutter.
"Fixing Weight Decay Regularization in Adam"
Link to arxiv
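
To make the description concrete, here is a minimal NumPy sketch of the Signum update with both decay options. This is illustrative only: the function and argument names are not MXNet's API, and the rule follows the pseudo-code quoted later in this PR.

    import numpy as np

    def signum_step(weight, grad, state, lr, momentum=0.9,
                    wd=0.0, wd_lh=0.0, rescale_grad=1.0, clip_gradient=None):
        """One in-place Signum step; with momentum=0 it reduces to SignSGD."""
        g = rescale_grad * grad
        if clip_gradient is not None:
            g = np.clip(g, -clip_gradient, clip_gradient)
        g = g + wd * weight                                  # coupled (L2-style) weight decay
        state[:] = momentum * state + (1.0 - momentum) * g   # momentum of (clipped) gradients
        weight[:] = (1.0 - lr * wd_lh) * weight - lr * np.sign(state)  # decoupled decay + sign step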

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with this change

Changes

  • Added the Signum optimizer to MXNet's list of optimizers
  • Added a stand-alone SignSGD update, used as a special case whenever "momentum" is set to 0 (see the usage sketch after this list)
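
For context, a brief usage sketch. The registered name 'signum' and the constructor arguments shown here are assumptions based on the @register decorator and parameter docs in the diff, not confirmed elsewhere in this PR.

    import mxnet as mx

    # Signum with momentum and the decoupled weight-decay option 'wd_lh'
    opt = mx.optimizer.create('signum', learning_rate=0.01,
                              momentum=0.9, wd=0.0, wd_lh=1e-4)

    # With momentum set to 0, updates fall back to the stand-alone SignSGD rule
    signsgd = mx.optimizer.create('signum', learning_rate=0.01, momentum=0.0)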

Comments

  • TODO1: add sparse matrix support for this optimizer
  • TODO2: Take advantage of the 1-bit gradient compression interpretation of SignSGD and Signum.
  • TODO3: Add 'wd_lh' support to Adam and other adaptive gradient optimizers.

Thanks for the contribution!! Please see detailed comments in code.

python/mxnet/optimizer.py
@@ -57,6 +58,10 @@ class Optimizer(object):
The weight decay (or L2 regularization) coefficient. Modifies objective
by adding a penalty for having large weights.
+ wd_lh: float, optional
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

I don't see a change in the Optimizer class constructor. Why is this changed?

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

I added that to the constructor at some point, because wd_lh is more generally applicable to other algorithms too (in particular, Adam).

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

removed that line.

python/mxnet/optimizer.py
+
+@register
+class Signum(Optimizer):
+ """The SGD optimizer with momentum and weight decay.
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

The one line summary should also mention it only takes the sign. Otherwise the readers don't know it until they see line 547

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

added details to the doc accordingly.

python/mxnet/optimizer.py
+ momentum : float, optional
+ The momentum value.
+ wd_lh : float, optional
+ The amount of decoupled weight decay regularization.
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

Let's also add a reference/link to the original paper

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Added a temporary link to the PDF hosted on Jeremy's site; will update to the arXiv or published version when it is ready.

+ float lr;
+ float wd;
+ float rescale_grad;
+ float clip_gradient;
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

If the clip_gradient param has no effect on both SignSGD and Signum, can we just remove this param from signsgd_update and signum_update? That would also simplify the C++ kernels.

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

It has an effect on Signum, because using the clipped gradient rather than the raw gradient to compute the momentum leads to a different result.

@eric-haibin-lin

eric-haibin-lin Jan 4, 2018

Contributor

Ah, I see. Thanks for the explanation!
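
A small illustrative sketch of this point (made-up numbers, not from the PR): clipping changes the relative magnitudes entering the momentum average and can flip the sign of the accumulated state, whereas for a momentum-free SignSGD step clip(g) always has the same sign as g.

    import numpy as np

    def momentum_sign(grads, momentum=0.9, clip=None):
        """Sign of the Signum momentum state after consuming 'grads'."""
        state = 0.0
        for g in grads:
            if clip is not None:
                g = float(np.clip(g, -clip, clip))
            state = momentum * state + (1.0 - momentum) * g
        return np.sign(state)

    grads = [10.0, -1.0]
    print(momentum_sign(grads))            # +1.0: unclipped momentum stays positive
    print(momentum_sign(grads, clip=0.5))  # -1.0: clipping flips the accumulated sign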

src/operator/optimizer_op.cc
+
+NNVM_REGISTER_OP(signsgd_update)
+// MXNET_ADD_SPARSE_OP_ALIAS(signsgd_update)
+.describe(R"code(Update function for SignSGDoptimizer.
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

nit: SignSGD optimizer

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Done, and added a math description block similar to the other optimizers.

src/operator/optimizer_op.cc
+ weight = weight - learning_rate * sign(gradient)
+
+
+** Sparse matrix not supported for this optimizer yet.
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

Same comment for documentation rendering and FInferStorageType in signum_update

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

done.

src/operator/optimizer_op.cc
+
+Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
+
+** Sparse matrix not supported for this optimizer yet.
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

Not sure if sentence starting with ** renders well in API doc. What about adding a "note" section like rint?

.. note::
- For input ``n.5`` ``rint`` returns ``n`` while ``round`` returns ``n+1``.
- For input ``-n.5`` both ``rint`` and ``round`` returns ``-n-1``.

Also, term "sparse ndarray" instead of "sparse matrix" is preferred :)

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Done.

src/operator/optimizer_op.cc
+
+** Sparse matrix not supported for this optimizer yet.
+
+If weight and momentum are both of ``row_sparse`` storage type,
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

I'd rather remove lines 81-87 since sparse update is not supported anyway.

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Done.

src/operator/optimizer_op.cc
+.set_attr_parser(ParamParser<SignumParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<3, 1, false, true, false>)
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

Since the sparse implementation is missing for this operator, removing the FInferStorageType registration is fine; it will then infer all storage types as "default" (dense).

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Done.

src/operator/optimizer_op.cc
+ return std::vector<uint32_t>{2};
+ })
+.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
+// .set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

Please remove the unused lines (also the ones on lines 42 and 65).

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Done.

src/operator/optimizer_op.cu
@@ -28,6 +28,14 @@
namespace mxnet {
namespace op {
+NNVM_REGISTER_OP(signsgd_update)
+.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
+// .set_attr<FComputeEx>("FComputeEx<gpu>", SignSGDUpdateEx<gpu>);
@eric-haibin-lin

eric-haibin-lin Jan 2, 2018

Contributor

Could you remove unused lines?

@yuxiangw

yuxiangw Jan 4, 2018

Contributor

Done.

Contributor

eric-haibin-lin commented Jan 6, 2018

There are new conflicts now. Do you mind resolving them again?
BTW - the files under cpp-package are only needed if you use C++ as the front end to train networks. Do you actually need them?

yuxiangw added some commits Dec 22, 2017

Contributor

yuxiangw commented Jan 8, 2018

Done fixing the conflicts.

Contributor

eric-haibin-lin commented Jan 8, 2018

@lx75249 could you help review the code for cpp-package?

Contributor

lx75249 commented Jan 9, 2018

+ signum_update(weight, grad, state, out=weight,
+ lr=lr, wd=wd, **kwargs)
+ else:
+ signsgd_update(weight, grad, out=weight,
@piiswrong

piiswrong Jan 9, 2018

Contributor

what's this?

@yuxiangw

yuxiangw Jan 9, 2018

Contributor

Well, SignSGD takes the sign of the stochastic gradient.

+
+ rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
+ state = momentum * state + (1-momentum)*rescaled_grad
+ weight = (1 - lr * wd_lh) * weight - lr * sign(state)
@piiswrong

piiswrong Jan 9, 2018

Contributor

what's wd_lh? Is it from the original paper?

@yuxiangw

yuxiangw Jan 9, 2018

Contributor

It is an alternative weight decay. See the descriptions.

@eric-haibin-lin

eric-haibin-lin Jan 11, 2018

Contributor

Since wd_lh is new, I suggest putting a reference link to the original paper by Loshchilov and Frank Hutter in the documentation.
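
For readers unfamiliar with the option, a short sketch (illustrative numbers, momentum omitted) of how the two decay terms differ under a sign-based update: wd is folded into the gradient and largely absorbed by sign(), while wd_lh shrinks the weight directly by the stated amount.

    import numpy as np

    w, g, lr = 2.0, 0.001, 0.01

    # Coupled decay (wd): the penalty enters the gradient and then passes
    # through sign(), so the decay strength is mostly absorbed by the sign.
    wd = 0.1
    w_coupled = w - lr * np.sign(g + wd * w)                # 2.0 - 0.01 = 1.99

    # Decoupled decay (wd_lh): the weight is shrunk multiplicatively before
    # the sign step, so the decay strength is applied exactly.
    wd_lh = 0.1
    w_decoupled = (1 - lr * wd_lh) * w - lr * np.sign(g)    # 1.998 - 0.01 = 1.988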

+ kwargs['wd_lh'] = self.wd_lh
+
+ if state is not None:
+ signum_update(weight, grad, state, out=weight,
@piiswrong

piiswrong Jan 9, 2018

Contributor

call these signum_momentum_update and signum_update to be consistent with others

@yuxiangw

yuxiangw Jan 9, 2018

Contributor

RE: naming.

  • Signum means SIGN momentUM, so the semantics of momentum are already in the name.
  • SignSGD is the special case of Signum without momentum, and that name has been used before.

Unless we change the names in our paper, let's keep them the way they are.
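
A tiny sketch of that reduction (wd, wd_lh, and clipping left out for brevity; values are arbitrary): with momentum = 0 the state is just the rescaled gradient, so the Signum step collapses to the SignSGD rule.

    import numpy as np

    w = np.array([0.5, -0.2])
    g = np.array([0.3, -4.0])
    lr, momentum = 0.01, 0.0

    state = momentum * np.zeros_like(g) + (1.0 - momentum) * g   # state == g
    w_signum = w - lr * np.sign(state)
    w_signsgd = w - lr * np.sign(g)
    assert np.allclose(w_signum, w_signsgd)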

One final comment. Otherwise LGTM. Thanks for the contribution!

Contributor

yuxiangw commented Jan 11, 2018

Added the reference to the documentation as suggested. Thanks, guys, for reviewing the PR!

@piiswrong piiswrong merged commit 5251b86 into apache:master Jan 12, 2018

2 checks passed

continuous-integration/jenkins/pr-head This commit looks good
continuous-integration/jenkins/pr-merge This commit looks good
Contributor

piiswrong commented Jan 12, 2018

Thanks

CodingCat added a commit to CodingCat/mxnet that referenced this pull request Jan 16, 2018

Signum optimizer (#9220)
* the c++ version of signum and signsgd optimizer

* optimizer signum, tested working with mac on cpu using mnist

* unit test for signum

* fix lint and incorporate haibin's code review

* rerun jenkins

* adding link to the Loshchilov and Hutter paper to the documentation