diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-11-27 02:35:33 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-11-27 02:35:33 +0300 |
commit | 4415d82948b0cf8317e7a5ba39b47f31eda4bccf (patch) | |
tree | 8610594d1fce22bab2eb42c8fbae1439a65189f1 | |
parent | 70f542492cbde0cf7f16afc915a3b8b674b77bd0 (diff) | |
parent | bdf79811221f2b814454dd29f2a44096dc4d82ba (diff) |
Merge pull request #114 from nicholas-leonard/LookupTable
LookupTable + Concat small fixes
-rw-r--r-- | Concat.lua | 2 | ||||
-rw-r--r-- | LookupTable.lua | 2 |
2 files changed, 3 insertions, 1 deletions
@@ -142,7 +142,7 @@ function Concat:__tostring__() local ext = ' | ' local extlast = ' ' local last = ' ... -> ' - local str = 'nn.Concat' + local str = torch.type(self) str = str .. ' {' .. line .. tab .. 'input' for i=1,#self.modules do if i == self.modules then diff --git a/LookupTable.lua b/LookupTable.lua index 71d7f62..5b5f565 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -117,6 +117,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr) for i=1,input:size(1) do local k = input[i] local kscale = self:scaleUpdateByKey(k) + self.inputs[k] = (self.inputs[k] or 0) + 1 self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i)) end elseif input:dim() == 2 then @@ -126,6 +127,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr) for j=1,input:size(1) do local k = input[j] local kscale = self:scaleUpdateByKey(k) + self.inputs[k] = (self.inputs[k] or 0) + 1 self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j)) end end |