diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-04-10 06:40:13 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-04-10 06:40:13 +0400 |
commit | e5fbc5b3ee23207978b16f017153b4a67b98fcf1 (patch) | |
tree | 7cad48ccc05696880e4648ba5dcc84c761bd0654 | |
parent | 6783768eb6e06ebbf0c1bd2d6b5f4bc9f709b32f (diff) |
unit tests complete
-rw-r--r-- | LookupTable.lua | 21 | ||||
-rw-r--r-- | test/test.lua | 36 |
2 files changed, 30 insertions, 27 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index 7db20f8..989bcdf 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -72,7 +72,9 @@ function LookupTable:updateOutput(input) for i=1,nExample do local output = self.output:select(1, i) + local input = input:select(1, i) for j=1,nIndex do + --print('test', i, j, input[j], output:size(), self.weight:size()) output:select(1, j):copy(self.weight:select(1, input[j])) end end @@ -105,7 +107,7 @@ function LookupTable:accGradParameters(input, gradOutput, scale) local gradOutput = gradOutput:select(1, i) for j=1,input:size(1) do local k = input[j] - self.input[k] = (self.inputs[k] or 0) + 1 + self.inputs[k] = (self.inputs[k] or 0) + 1 self.gradWeight:select(1, k):add(scale, gradOutput:select(1, j)) end end @@ -115,15 +117,24 @@ end function LookupTable:accUpdateGradParameters(input, gradOutput, lr) if input:dim() == 1 then for i=1,input:size(1) do - self.weight:select(1, input[i]):add(-lr, gradOutput:select(1, i)) + local k = input[j] + local scale = 1 + if self.fairScale then + scale = self:getFairScale(self.inputs[k]) + end + self.weight:select(1, input[i]):add(-lr*scale, gradOutput:select(1, i)) end elseif input:dim() == 2 then for i=1,input:size(1) do local input = input:select(1, i) local gradOutput = gradOutput:select(1, i) - for j=1,input:size(2) do - scale = self:getFairScale(nBackward) - self.weight:select(1, input[j]):add(-lr*scale, gradOutput:select(1, j)) + for j=1,input:size(1) do + local k = input[j] + local scale = 1 + if self.fairScale then + scale = self:getFairScale(self.inputs[k]) + end + self.weight:select(1, k):add(-lr*scale, gradOutput:select(1, j)) end end end diff --git a/test/test.lua b/test/test.lua index 91b38e8..cea1a53 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1650,52 +1650,44 @@ function nntest.LookupTable() local totalIndex = math.random(10,100) local nIndex = math.random(5,7) local entry_size = math.random(5,7) - local input = torch.Tensor(nIndex):zero() + local input = torch.IntTensor(nIndex):zero() local module = nn.LookupTable(totalIndex, entry_size) local minval = 1 local maxval = totalIndex -- 1D local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight, minval, maxval) - mytester:assertlt(err,precision, 'error on weight ') + mytester:assertlt(err,precision, '1D error on weight ') local err = jac.testJacobianUpdateParameters(module, input, module.weight, minval, maxval) - mytester:assertlt(err,precision, 'error on weight [direct update] ') + mytester:assertlt(err,precision, '1D error on weight [direct update] ') module.gradWeight:zero() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + '1D error on weight [%s]', t)) end -- 2D local nframe = math.random(50,70) - local input = torch.Tensor(nframe, nIndex):zero() - - local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) - mytester:assertlt(err,precision, 'error on weight ') - - local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) - mytester:assertlt(err,precision, 'error on weight ') - - local err = jac.testJacobianUpdateParameters(module, input, module.weight) - mytester:assertlt(err,precision, 'error on weight [direct update] ') + local input = torch.IntTensor(nframe, nIndex):zero() - local err = jac.testJacobianUpdateParameters(module, input, module.bias) - mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight, minval, maxval) + mytester:assertlt(err,precision, '2D error on weight ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight, minval, maxval) + mytester:assertlt(err,precision, '2D error on weight [direct update] ') + module.gradWeight:zero() for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( - 'error on weight [%s]', t)) + '2D error on weight [%s]', t)) end - for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do - mytester:assertlt(err, precision, string.format( - 'error on bias [%s]', t)) - end -- IO - local ferr,berr = jac.testIO(module,input) + module.gradInput = torch.Tensor(3,4):zero() --fixes an error + local ferr,berr = jac.testIO(module,input,minval,maxval) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end |