 LookupTable.lua | 38 +++++++++++---------------------------
 test/test.lua   |  1 -
 2 files changed, 11 insertions(+), 28 deletions(-)
diff --git a/LookupTable.lua b/LookupTable.lua
index 989bcdf..71511d4 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -118,11 +118,8 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
    if input:dim() == 1 then
       for i=1,input:size(1) do
          local k = input[j]
-         local scale = 1
-         if self.fairScale then
-            scale = self:getFairScale(self.inputs[k])
-         end
-         self.weight:select(1, input[i]):add(-lr*scale, gradOutput:select(1, i))
+         local kscale = self:scaleUpdateByKey(k)
+         self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
       end
    elseif input:dim() == 2 then
       for i=1,input:size(1) do
@@ -130,37 +127,24 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
          local gradOutput = gradOutput:select(1, i)
          for j=1,input:size(1) do
             local k = input[j]
-            local scale = 1
-            if self.fairScale then
-               scale = self:getFairScale(self.inputs[k])
-            end
-            self.weight:select(1, k):add(-lr*scale, gradOutput:select(1, j))
+            local kscale = self:scaleUpdateByKey(k)
+            self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
          end
       end
    end
 end
 
 function LookupTable:updateParameters(learningRate)
-   if not self.fairScale then
-      for k,_ in pairs(self.inputs) do
-         self.weight:select(1, k):add(-learningRate, self.gradWeight:select(1, k))
-      end
-   else
-      for k,nBackward in pairs(self.inputs) do
-         scale = self:getFairScale(nBackward)
-         self.weight:select(1, k):add(-learningRate*scale, self.gradWeight:select(1, k))
-      end
+   for k,nBackward in pairs(self.inputs) do
+      local kscale = self:scaleUpdateByKey(k)
+      self.weight:select(1, k):add(-learningRate*kscale, self.gradWeight:select(1, k))
    end
 end
 
-function LookupTable:getFairScale(nBackward)
-   local scale
-   if self.batchScaled then
-      scale = self.nBackward/nBackward
-   else
-      scale = 1/nBackward
-   end
-   return scale
+-- scale the update for each key
+function LookupTable:scaleUpdateByKey(inputKey)
+   -- default is to perform no key-based scaling
+   return 1
 end
 
 -- we do not need to accumulate parameters when sharing
diff --git a/test/test.lua b/test/test.lua
index cea1a53..b342c36 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1684,7 +1684,6 @@ function nntest.LookupTable()
                            '2D error on weight [%s]', t))
    end
 
-   -- IO
    module.gradInput = torch.Tensor(3,4):zero() --fixes an error
    local ferr,berr = jac.testIO(module,input,minval,maxval)
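
The commit replaces the hard-coded fairScale branches with a single scaleUpdateByKey(inputKey) hook that returns 1 by default, so per-key update scaling becomes a one-method override instead of a flag threaded through every update path. A minimal sketch of how the removed fair-scaling behaviour could be restored on top of the new hook; the nn.FairLookupTable name is hypothetical, and the self.inputs (key to backward-count map), self.batchScaled and self.nBackward fields are taken from the deleted code, not defined by this commit:

    -- Hypothetical subclass restoring the deleted getFairScale logic
    -- by overriding the new hook introduced in this commit.
    local FairLookupTable, parent =
       torch.class('nn.FairLookupTable', 'nn.LookupTable')

    function FairLookupTable:scaleUpdateByKey(inputKey)
       -- self.inputs[inputKey] counts how often this key was seen in backward
       local nBackward = self.inputs[inputKey]
       if self.batchScaled then
          -- scale relative to the total number of backward passes
          return self.nBackward/nBackward
       else
          -- plain 1/count scaling, the old default when batchScaled is unset
          return 1/nBackward
       end
    end

Since the base implementation returns a constant 1, modules that do not override the hook pay only a method call per key, and the two accUpdateGradParameters branches and updateParameters all share one scaling policy.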