diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-11-27 02:10:40 +0300 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-11-27 02:10:40 +0300 |
commit | bdf79811221f2b814454dd29f2a44096dc4d82ba (patch) | |
tree | 8610594d1fce22bab2eb42c8fbae1439a65189f1 | |
parent | cc7ce5c95ebde85126039eb203fd8e60c628d521 (diff) |
LookupTable keeps track of inputs for accUpdate
-rw-r--r-- | LookupTable.lua | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index 71d7f62..5b5f565 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -117,6 +117,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr) for i=1,input:size(1) do local k = input[i] local kscale = self:scaleUpdateByKey(k) + self.inputs[k] = (self.inputs[k] or 0) + 1 self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i)) end elseif input:dim() == 2 then @@ -126,6 +127,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr) for j=1,input:size(1) do local k = input[j] local kscale = self:scaleUpdateByKey(k) + self.inputs[k] = (self.inputs[k] or 0) + 1 self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j)) end end |