Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-11-27 02:10:40 +0300
committerNicholas Leonard <nick@nikopia.org>2014-11-27 02:10:40 +0300
commitbdf79811221f2b814454dd29f2a44096dc4d82ba (patch)
tree8610594d1fce22bab2eb42c8fbae1439a65189f1
parentcc7ce5c95ebde85126039eb203fd8e60c628d521 (diff)
LookupTable keeps track of inputs for accUpdate
-rw-r--r--LookupTable.lua2
1 files changed, 2 insertions, 0 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
index 71d7f62..5b5f565 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -117,6 +117,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
for i=1,input:size(1) do
local k = input[i]
local kscale = self:scaleUpdateByKey(k)
+ self.inputs[k] = (self.inputs[k] or 0) + 1
self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
end
elseif input:dim() == 2 then
@@ -126,6 +127,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
for j=1,input:size(1) do
local k = input[j]
local kscale = self:scaleUpdateByKey(k)
+ self.inputs[k] = (self.inputs[k] or 0) + 1
self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
end
end