Welcome to mirror list, hosted at ThFree Co, Russian Federation.

LookupTable.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 7dfda7a34ac9937bbd6fa98a6c9e064120b9ca34 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
-- Register the class under the name 'nn.LookupTable' with 'nn.Module' named as
-- its parent. torch.class returns two values, kept here as the class table
-- (LookupTable) and the parent table (parent, used by __init below).
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')

-- Class-level version tag.
-- NOTE(review): presumably consulted when loading older serialized instances — confirm.
LookupTable.__version = 2

-- Constructor.
--   nIndex : number of rows of the table (first dimension of the weight).
--   ...    : remaining dimensions of each row, given either as separate
--            numbers or as one non-number value that supports `#` and integer
--            indexing (presumably a LongStorage or a Lua table — confirm).
-- Builds self.size = {nIndex, dims...}, allocates self.weight and a zeroed
-- self.gradWeight of that size, starts with an empty self.inputs set, and
-- finally calls self:reset() to initialise the weights.
function LookupTable:__init(nIndex, ...)
   parent.__init(self)

   local extraCount = select('#', ...)
   local extras = {...}

   -- By default the trailing dimensions are the varargs themselves; a single
   -- non-number argument is treated as a container holding them instead.
   local dims, dimCount = extras, extraCount
   if extraCount == 1 and type(extras[1]) ~= "number" then
      dims = extras[1]
      dimCount = #dims
   end

   -- Slot 1 is reserved for nIndex; the row dimensions follow it.
   local fullSize = torch.LongStorage(dimCount + 1)
   for d = 1, dimCount do
      fullSize[d + 1] = dims[d]
   end
   fullSize[1] = nIndex

   self.size = fullSize
   self.weight = torch.Tensor(fullSize)
   self.gradWeight = torch.Tensor(fullSize):zero()
   -- Set of row indices touched by accGradParameters since the last zeroing.
   self.inputs = {}

   self:reset()
end

-- Re-initialises every element of self.weight with a value drawn from
-- torch.normal(0, stdv).
--   stdv : second argument handed to torch.normal; defaults to 1 when omitted.
function LookupTable:reset(stdv)
   local sigma = stdv or 1

   -- apply() stores the value returned by the callback into each element.
   local function drawSample()
      return torch.normal(0, sigma)
   end

   self.weight:apply(drawSample)
end

-- Forward pass: copies one row of self.weight into self.output for every
-- entry of input.
--   input : indexed as input[i] for i = 1..input:size(1); each value is used
--           as a row index into self.weight (presumably a 1D tensor of
--           indices — confirm against callers).
-- Returns self.output, resized to {input:size(1), row dims...}.
-- Note: self.size[1] is overwritten with the number of input entries.
function LookupTable:updateOutput(input)
   local count = input:size(1)
   local result = self.output
   local table_ = self.weight

   -- Reuse self.size as the target shape, with the first dimension replaced.
   self.size[1] = count
   result:resize(self.size)

   for pos = 1, count do
      local source = table_:select(1, input[pos])
      result:select(1, pos):copy(source)
   end

   return result
end

-- Zeroes only the rows of self.gradWeight whose indices are recorded as keys
-- of self.inputs, then replaces self.inputs with a fresh empty table.
function LookupTable:zeroGradParameters()
   local grads = self.gradWeight
   for rowIndex in pairs(self.inputs) do
      local row = grads:select(1, rowIndex)
      row:zero()
   end
   self.inputs = {}
end

-- Accumulates gradients into the rows of self.gradWeight selected by input.
--   input      : indexed as input[i] for i = 1..input:size(1); each value is
--                used as a row index into self.gradWeight.
--   gradOutput : its i-th slice along dimension 1 is the gradient for input[i].
--   scale      : optional multiplier applied to each gradient slice before it
--                is added; defaults to 1 when omitted.
-- Every index seen is recorded as a key of self.inputs, which
-- zeroGradParameters and updateParameters iterate over.
function LookupTable:accGradParameters(input, gradOutput, scale)
   -- Without this default a direct call that omits scale would pass nil to add().
   scale = scale or 1
   for i=1,input:size(1) do
      local k = input[i]
      self.inputs[k] = true
      self.gradWeight:select(1, k):add(scale, gradOutput:select(1, i))
   end
end

-- Applies the gradient directly to the weights, bypassing self.gradWeight:
-- for each entry of input, adds -lr times the matching slice of gradOutput to
-- the selected row of self.weight.
--   input      : indexed as input[i] for i = 1..input:size(1); each value is
--                used as a row index into self.weight.
--   gradOutput : its i-th slice along dimension 1 pairs with input[i].
--   lr         : multiplier; its negation scales every added slice.
function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
   local table_ = self.weight
   local count = input:size(1)
   for pos = 1, count do
      local target = table_:select(1, input[pos])
      local delta = gradOutput:select(1, pos)
      target:add(-lr, delta)
   end
end

-- Updates only the rows whose indices are recorded as keys of self.inputs:
-- each such row of self.weight receives -learningRate times the matching row
-- of self.gradWeight.
--   learningRate : multiplier; its negation scales the added gradient row.
function LookupTable:updateParameters(learningRate)
   local table_ = self.weight
   local grads = self.gradWeight
   for rowIndex in pairs(self.inputs) do
      local target = table_:select(1, rowIndex)
      local gradRow = grads:select(1, rowIndex)
      target:add(-learningRate, gradRow)
   end
end

-- We do not need to accumulate parameters when sharing, so the shared variant
-- is simply the same function object as accUpdateGradParameters, which adds
-- the scaled gradient straight into self.weight.
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters