diff options
-rw-r--r-- | SparseJacobian.lua | 58 | ||||
-rw-r--r-- | test/test.lua | 10 |
2 files changed, 68 insertions, 0 deletions
diff --git a/SparseJacobian.lua b/SparseJacobian.lua index 3dd489d..b778e67 100644 --- a/SparseJacobian.lua +++ b/SparseJacobian.lua @@ -217,3 +217,61 @@ function nn.SparseJacobian.testIO(module,input, minval, maxval) local errb = bo - bo2 return errf:abs():max(), errb:abs():max() end + +function nn.SparseJacobian.testAllUpdate(module, input, weight, gradWeight) + local gradOutput + local lr = torch.uniform(0.1, 1) + local errors = {} + + -- accGradParameters + local maccgp = module:clone() + local weightc = maccgp[weight]:clone() + maccgp:forward(input) + gradOutput = torch.rand(maccgp.output:size()) + maccgp:zeroGradParameters() + maccgp:updateGradInput(input, gradOutput) + maccgp:accGradParameters(input, gradOutput) + maccgp:updateParameters(lr) + errors["accGradParameters"] = (weightc-maccgp[gradWeight]*lr-maccgp[weight]):norm() + + -- accUpdateGradParameters + local maccugp = module:clone() + maccugp:forward(input) + maccugp:updateGradInput(input, gradOutput) + maccugp:accUpdateGradParameters(input, gradOutput, lr) + errors["accUpdateGradParameters"] = (maccugp[weight]-maccgp[weight]):norm() + + -- shared, accGradParameters + local macsh1 = module:clone() + local macsh2 = module:clone() + macsh2:share(macsh1, weight) + macsh1:forward(input) + macsh2:forward(input) + macsh1:zeroGradParameters() + macsh2:zeroGradParameters() + macsh1:updateGradInput(input, gradOutput) + macsh2:updateGradInput(input, gradOutput) + macsh1:accGradParameters(input, gradOutput) + macsh2:accGradParameters(input, gradOutput) + macsh1:updateParameters(lr) + macsh2:updateParameters(lr) + local err = (weightc-maccgp[gradWeight]*(lr*2)-macsh1[weight]):norm() + err = err + (weightc-maccgp[gradWeight]*(lr*2)-macsh2[weight]):norm() + errors["accGradParameters [shared]"] = err + + -- shared, accUpdateGradParameters + local macshu1 = module:clone() + local macshu2 = module:clone() + macshu2:share(macshu1, weight) + macshu1:forward(input) + macshu2:forward(input) + macshu1:updateGradInput(input, gradOutput) + macshu2:updateGradInput(input, gradOutput) + macshu1:accUpdateGradParameters(input, gradOutput, lr) + macshu2:accUpdateGradParameters(input, gradOutput, lr) + local err = (weightc-maccgp[gradWeight]*(lr*2)-macshu1[weight]):norm() + err = err + (weightc-maccgp[gradWeight]*(lr*2)-macshu2[weight]):norm() + errors["accUpdateGradParameters [shared]"] = err + + return errors +end diff --git a/test/test.lua b/test/test.lua index eb6dede..df80616 100644 --- a/test/test.lua +++ b/test/test.lua @@ -350,6 +350,16 @@ function nntest.SparseLinear() local err = sjac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + + for t,err in pairs(sjac.testAllUpdate(module, input, 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + for t,err in pairs(sjac.testAllUpdate(module, input, 'bias', 'gradBias')) do + mytester:assertlt(err, precision, string.format( + 'error on bias [%s]', t)) + end local ferr, berr = sjac.testIO(module, input) mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') |