github.com/torch/nn.git
author    Ronan Collobert <ronan@collobert.com>   2012-01-25 17:55:20 +0400
committer Ronan Collobert <ronan@collobert.com>   2012-01-25 17:55:20 +0400
commit    4df3893abd1b9f840f1d9a8c1859799ccbf941de (patch)
tree      e8a1e1cc1b6ea6e47855347b157eaf419fdb357b /LookupTable.lua

initial revamp of torch7 tree
Diffstat (limited to 'LookupTable.lua')
 -rw-r--r--  LookupTable.lua  76
 1 file changed, 76 insertions, 0 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
new file mode 100644
index 0000000..115f19c
--- /dev/null
+++ b/LookupTable.lua
@@ -0,0 +1,76 @@
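+-- nn.LookupTable: maps each index of a 1D input tensor to the corresponding
+-- row of a learnable weight matrix (an embedding lookup).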
+local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')
+
+LookupTable.__version = 2
+
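+-- nIndex is the number of rows (valid indices); the remaining arguments give
+-- the size of each entry, either as separate numbers or as a single
+-- table/LongStorage of sizes.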
+function LookupTable:__init(nIndex, ...)
+ parent.__init(self)
+
+ if select('#', ...) == 1 and type(select(1, ...)) ~= "number" then
+ local size = select(1, ...)
+ self.size = torch.LongStorage(#size + 1)
+ for i=1,#size do
+ self.size[i+1] = size[i]
+ end
+ else
+ self.size = torch.LongStorage(select('#', ...)+1)
+ for i=1,select('#',...) do
+ self.size[i+1] = select(i, ...)
+ end
+ end
+
+ self.size[1] = nIndex
+ self.weight = torch.Tensor(self.size)
+ self.gradWeight = torch.Tensor(self.size):zero()
+ self.inputs = {}
+
+ self:reset()
+end
+
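+-- reinitialize the weights from a normal distribution with standard
+-- deviation stdv (default 1)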
+function LookupTable:reset(stdv)
+ stdv = stdv or 1
+ self.weight:apply(function()
+ return torch.normal(0, stdv)
+ end)
+end
+
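+-- forward pass: copy the weight row selected by each input index into the
+-- corresponding row of the output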
+function LookupTable:updateOutput(input)
+ local nIndex = input:size(1)
+ self.size[1] = nIndex
+ self.output:resize(self.size)
+
+ for i=1,nIndex do
+ self.output:select(1, i):copy(self.weight:select(1, input[i]))
+ end
+
+ return self.output
+end
+
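+-- zero only the gradWeight rows touched since the last call, then forget them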
+function LookupTable:zeroGradParameters()
+ for k,_ in pairs(self.inputs) do
+ self.gradWeight:select(1, k):zero()
+ end
+ self.inputs = {}
+end
+
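+-- accumulate scale * gradOutput into the gradWeight rows selected by input,
+-- recording the touched rows for zeroGradParameters and updateParameters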
+function LookupTable:accGradParameters(input, gradOutput, scale)
+ for i=1,input:size(1) do
+ local k = input[i]
+ self.inputs[k] = true
+ self.gradWeight:select(1, k):add(scale, gradOutput:select(1, i))
+ end
+end
+
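+-- fused step: apply the gradient directly to the weights, skipping gradWeight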
+function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
+ for i=1,input:size(1) do
+ self.weight:select(1, input[i]):add(-lr, gradOutput:select(1, i))
+ end
+end
+
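+-- sparse SGD step: update only the rows recorded by accGradParameters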
+function LookupTable:updateParameters(learningRate)
+ for k,_ in pairs(self.inputs) do
+ self.weight:select(1, k):add(-learningRate, self.gradWeight:select(1, k))
+ end
+end
+
+-- we do not need to accumulate parameters when sharing
+LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters
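
A minimal usage sketch for the module added above (not part of the commit; it assumes a working Torch7 install where require 'nn' loads this module):

    require 'nn'

    -- a table of 10 embeddings, each of dimension 5
    local lut = nn.LookupTable(10, 5)

    -- forward: one weight row per index (output is 4x5 here)
    local input = torch.Tensor{1, 4, 4, 9}
    local output = lut:forward(input)

    -- backward/update: only rows 1, 4 and 9 are touched
    local gradOutput = torch.Tensor(4, 5):fill(0.1)
    lut:zeroGradParameters()
    lut:accGradParameters(input, gradOutput, 1)
    lut:updateParameters(0.01)

Note that index 4 appears twice in the input, so its gradWeight row accumulates two gradOutput rows before the update; this is the point of tracking touched rows in self.inputs rather than updating the whole weight matrix.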