diff options
-rw-r--r-- | LookupTable.lua | 4 | ||||
-rw-r--r-- | test/test.lua | 54 |
2 files changed, 56 insertions, 2 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index 8c23e82..7db20f8 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -22,7 +22,7 @@ function LookupTable:__init(nIndex, ...) self.size[1] = nIndex batchSize = torch.LongTensor(#self.size + 1) - batchSize:narrow(1, 2,#self.size):copy(self.size) + batchSize:narrow(1, 2,#self.size):copy(torch.LongTensor(self.size)) batchSize[1] = 1 self.batchSize = batchSize:storage() @@ -60,7 +60,6 @@ function LookupTable:updateOutput(input) local nIndex = input:size(1) self.size[1] = nIndex self.output:resize(self.size) - for i=1,nIndex do self.output:select(1, i):copy(self.weight:select(1, input[i])) end @@ -91,6 +90,7 @@ function LookupTable:zeroGradParameters() end function LookupTable:accGradParameters(input, gradOutput, scale) + scale = scale or 1 if input:dim() == 1 then self.nBackward = self.nBackward + 1 for i=1,input:size(1) do diff --git a/test/test.lua b/test/test.lua index b2b5f92..91b38e8 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1646,6 +1646,60 @@ function nntest.PairwiseDistance() end end +function nntest.LookupTable() + local totalIndex = math.random(10,100) + local nIndex = math.random(5,7) + local entry_size = math.random(5,7) + local input = torch.Tensor(nIndex):zero() + local module = nn.LookupTable(totalIndex, entry_size) + local minval = 1 + local maxval = totalIndex + + -- 1D + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight, minval, maxval) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight, minval, maxval) + mytester:assertlt(err,precision, 'error on weight [direct update] ') + + module.gradWeight:zero() + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + -- 2D + local nframe = math.random(50,70) + local input = torch.Tensor(nframe, nIndex):zero() + + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight) + mytester:assertlt(err,precision, 'error on weight [direct update] ') + + local err = jac.testJacobianUpdateParameters(module, input, module.bias) + mytester:assertlt(err,precision, 'error on bias [direct update] ') + + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do + mytester:assertlt(err, precision, string.format( + 'error on bias [%s]', t)) + end + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') +end + mytester:add(nntest) if not nn then |