Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--LookupTable.lua38
-rw-r--r--test/test.lua1
2 files changed, 11 insertions, 28 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
index 989bcdf..71511d4 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -118,11 +118,8 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
if input:dim() == 1 then
for i=1,input:size(1) do
local k = input[j]
- local scale = 1
- if self.fairScale then
- scale = self:getFairScale(self.inputs[k])
- end
- self.weight:select(1, input[i]):add(-lr*scale, gradOutput:select(1, i))
+ local kscale = self:scaleUpdateByKey(k)
+ self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
end
elseif input:dim() == 2 then
for i=1,input:size(1) do
@@ -130,37 +127,24 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
local gradOutput = gradOutput:select(1, i)
for j=1,input:size(1) do
local k = input[j]
- local scale = 1
- if self.fairScale then
- scale = self:getFairScale(self.inputs[k])
- end
- self.weight:select(1, k):add(-lr*scale, gradOutput:select(1, j))
+ local kscale = self:scaleUpdateByKey(k)
+ self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
end
end
end
end
function LookupTable:updateParameters(learningRate)
- if not self.fairScale then
- for k,_ in pairs(self.inputs) do
- self.weight:select(1, k):add(-learningRate, self.gradWeight:select(1, k))
- end
- else
- for k,nBackward in pairs(self.inputs) do
- scale = self:getFairScale(nBackward)
- self.weight:select(1, k):add(-learningRate*scale, self.gradWeight:select(1, k))
- end
+ for k,nBackward in pairs(self.inputs) do
+ local kscale = self:scaleUpdateByKey(k)
+ self.weight:select(1, k):add(-learningRate*kscale, self.gradWeight:select(1, k))
end
end
-function LookupTable:getFairScale(nBackward)
- local scale
- if self.batchScaled then
- scale = self.nBackward/nBackward
- else
- scale = 1/nBackward
- end
- return scale
+-- scale the update for each key
+function LookupTable:scaleUpdateByKey(inputKey)
+ -- default is to perform no key-based scalling
+ return 1
end
-- we do not need to accumulate parameters when sharing
diff --git a/test/test.lua b/test/test.lua
index cea1a53..b342c36 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1684,7 +1684,6 @@ function nntest.LookupTable()
'2D error on weight [%s]', t))
end
-
-- IO
module.gradInput = torch.Tensor(3,4):zero() --fixes an error
local ferr,berr = jac.testIO(module,input,minval,maxval)