Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-11-27 02:35:33 +0300
committerSoumith Chintala <soumith@gmail.com>2014-11-27 02:35:33 +0300
commit4415d82948b0cf8317e7a5ba39b47f31eda4bccf (patch)
tree8610594d1fce22bab2eb42c8fbae1439a65189f1
parent70f542492cbde0cf7f16afc915a3b8b674b77bd0 (diff)
parentbdf79811221f2b814454dd29f2a44096dc4d82ba (diff)
Merge pull request #114 from nicholas-leonard/LookupTable
LookupTable + Concat small fixes
-rw-r--r--Concat.lua2
-rw-r--r--LookupTable.lua2
2 files changed, 3 insertions, 1 deletions
diff --git a/Concat.lua b/Concat.lua
index 5743af9..c94808d 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -142,7 +142,7 @@ function Concat:__tostring__()
local ext = ' | '
local extlast = ' '
local last = ' ... -> '
- local str = 'nn.Concat'
+ local str = torch.type(self)
str = str .. ' {' .. line .. tab .. 'input'
for i=1,#self.modules do
if i == self.modules then
diff --git a/LookupTable.lua b/LookupTable.lua
index 71d7f62..5b5f565 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -117,6 +117,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
for i=1,input:size(1) do
local k = input[i]
local kscale = self:scaleUpdateByKey(k)
+ self.inputs[k] = (self.inputs[k] or 0) + 1
self.weight:select(1, input[i]):add(-lr*kscale, gradOutput:select(1, i))
end
elseif input:dim() == 2 then
@@ -126,6 +127,7 @@ function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
for j=1,input:size(1) do
local k = input[j]
local kscale = self:scaleUpdateByKey(k)
+ self.inputs[k] = (self.inputs[k] or 0) + 1
self.weight:select(1, k):add(-lr*kscale, gradOutput:select(1, j))
end
end