WeightNorm.lua - github.com/torch/nn.git
-- Weight Normalization
-- https://arxiv.org/pdf/1602.07868v3.pdf
local WeightNorm, parent = torch.class("nn.WeightNorm", "nn.Decorator")

function WeightNorm:__init(module, outputDim)
    -- This container applies Weight Normalization to any module it wraps.
    -- It accepts a parameter ``outputDim`` giving the dimension of the weight
    -- tensor that indexes the outputs. If outputDim is not 1, the container
    -- transposes the weight so that dimension comes first; if the weight is
    -- not 2D, it is viewed as a 2D tensor of shape nOut x (nIn x kw x dw x ...).
    -- See the usage sketch after this constructor.

    parent.__init(self, module)
    assert(module.weight)

    if module.bias then
        self.bias = module.bias
        self.gradBias = module.gradBias
    end
    self.gradWeight = module.gradWeight
    self.weight = module.weight

    self.outputDim = outputDim or 1

    -- product of the sizes of the non-output weight dimensions
    self.otherDims = 1
    for i = 1, self.weight:dim() do
        if i ~= self.outputDim then
            self.otherDims = self.otherDims * self.weight:size(i)
        end
    end

    -- view size for weight norm 2D calculations
    self.viewIn = torch.LongStorage({self.weight:size(self.outputDim), self.otherDims})

    -- view size back to original weight
    self.viewOut = self.weight:size()
    self.weightSize = self.weight:size()

    -- bubble outputDim size up to the front
    for i = self.outputDim - 1, 1, -1 do
        self.viewOut[i], self.viewOut[i + 1] = self.viewOut[i + 1], self.viewOut[i]
    end

    -- weight is reparametrized to decouple the length from the direction
    -- such that w = g * ( v / ||v|| )
    self.v = torch.Tensor(self.viewIn[1], self.viewIn[2])
    self.g = torch.Tensor(self.viewIn[1])

    self._norm = torch.Tensor(self.viewIn[1])
    self._scale = torch.Tensor(self.viewIn[1])

    -- gradient of g
    self.gradG = torch.Tensor(self.viewIn[1]):zero()
    -- gradient of v
    self.gradV = torch.Tensor(self.viewIn)

    self:resetInit()
end
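
--[[
    Usage sketch (illustrative only; module names and sizes are examples, not
    part of this file): wrap any parameterised module and, if the output
    dimension of its weight is not the first one, pass it as ``outputDim``.

        local wnLinear = nn.WeightNorm(nn.Linear(10, 20))   -- weight is nOut x nIn
        local wnConv   = nn.WeightNorm(nn.SpatialConvolution(3, 16, 5, 5))
        local output   = wnLinear:forward(torch.randn(8, 10))

    During training the wrapped weight is recomputed from g and v on every
    forward pass as w = g * v / ||v||.
--]]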

function WeightNorm:permuteIn(inpt)
    local ans = inpt
    for i = self.outputDim - 1, 1, -1 do
        ans = ans:transpose(i, i+1)
    end
    return ans
end

function WeightNorm:permuteOut(inpt)
    local ans = inpt
    for i = 1, self.outputDim - 1 do
        ans = ans:transpose(i, i+1)
    end
    return ans
end
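
--[[
    Permutation sketch (illustrative sizes only, ``wn`` is a hypothetical
    instance): with a weight of size 3 x 16 x 5 x 5 and outputDim = 2,

        wn:permuteIn(weight)     -- 16 x 3 x 5 x 5 view, output dim first
        wn:permuteOut(permuted)  -- back to the original 3 x 16 x 5 x 5 layout

    which matches viewIn = {16, 75} and viewOut = {16, 3, 5, 5} from __init.
--]]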

function WeightNorm:resetInit(inputSize, outputSize)
    -- initialise the direction v with fan-in scaled Gaussian noise and set g
    -- to the per-row norm of v, so the effective weight starts out equal to v
    self.v:normal(0, math.sqrt(2/self.viewIn[2]))
    self.g:norm(self.v, 2, 2)
    if self.bias then
        self.bias:zero()
    end
end

function WeightNorm:evaluate()
    -- when switching out of training mode, bake w = g * v / ||v|| into the
    -- wrapped module once so inference uses the up-to-date weight
    if not(self.train == false) then
        self:updateWeight()
        parent.evaluate(self)
    end
end

function WeightNorm:updateWeight()
    -- gradV doubles as scratch space here; view the weight as 2D while the
    -- container recomputes it
    self.gradV:copy(self:permuteIn(self.weight))
    self.gradV = self.gradV:view(self.viewIn)

    -- ||v|| (a small stabiliser is added to the squared norm)
    self._norm:norm(self.v, 2, 2):pow(2):add(10e-5):sqrt()
    -- w = g * v / ||v||
    self.gradV:copy(self.v)
    self._scale:copy(self.g):cdiv(self._norm)
    self.gradV:cmul(self._scale:view(self.viewIn[1], 1)
                               :expand(self.viewIn[1], self.viewIn[2]))

    -- view back to the shape of the original module weight
    self.gradV = self.gradV:view(self.viewOut)

    self.weight:copy(self:permuteOut(self.gradV))
end
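
--[[
    Sanity-check sketch (assumes the hypothetical ``wnLinear`` from the usage
    sketch above): after updateWeight every output row of the wrapped weight
    has L2 norm equal to the corresponding entry of g, up to the 10e-5
    stabiliser added to the norm.

        wnLinear:updateWeight()
        print(wnLinear.weight:norm(2, 2))   -- per-row norms, one column
        print(wnLinear.g)                   -- should match the line above
--]]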

function WeightNorm:updateOutput(input)
    if not(self.train == false) then
        self:updateWeight()
    end
    self.output:set(self.modules[1]:updateOutput(input))
    return self.output
end

function WeightNorm:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
    self.modules[1]:accGradParameters(input, gradOutput, scale)

    -- the weight tensor is reused below as scratch space; gradV receives the
    -- wrapped module's dL/dw in the permuted layout
    self.weight:copy(self:permuteIn(self.weight))
    self.gradV:copy(self:permuteIn(self.gradWeight))
    self.weight = self.weight:view(self.viewIn)

    local norm = self._norm:view(self.viewIn[1], 1):expand(self.viewIn[1], self.viewIn[2])
    -- note: this local shadows the ``scale`` argument, which has already been
    -- consumed above; it holds g / ||v|| expanded to the 2D view
    local scale = self._scale:view(self.viewIn[1], 1):expand(self.viewIn[1], self.viewIn[2])

    -- dL/dg = sum over the non-output dim of dL/dw * v / ||v||
    self.weight:copy(self.gradV)
    self.weight:cmul(self.v):cdiv(norm)
    self.gradG:sum(self.weight, 2)

    -- first term of dL/dv: dL/dw * g / ||v||
    self.gradV:cmul(scale)

    -- second term of dL/dv: dL/dg * v * g / ||v||^2
    self.weight:copy(self.v):cmul(scale):cdiv(norm)
    self.weight:cmul(self.gradG:view(self.viewIn[1], 1)
                               :expand(self.viewIn[1], self.viewIn[2]))

    -- dL/dv = first term minus second term
    self.gradV:add(-1, self.weight)

    self.gradV = self.gradV:view(self.viewOut)
    self.weight = self.weight:view(self.viewOut)
    self.gradWeight:copy(self:permuteOut(self.gradV))
end
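
--[[
    The updates above implement the gradients from the paper (Salimans &
    Kingma, 2016), written per output row of the 2D views:

        dL/dg = (dL/dw . v) / ||v||
        dL/dv = (g / ||v||) * dL/dw - (g * dL/dg / ||v||^2) * v

    where "." is a dot product over the non-output dimensions.
--]]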

function WeightNorm:updateGradInput(input, gradOutput)
    self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput))
    return self.gradInput
end

function WeightNorm:zeroGradParameters()
    self.modules[1]:zeroGradParameters()
    self.gradV:zero()
    self.gradG:zero()
end

function WeightNorm:updateParameters(lr)
    self.modules[1]:updateParameters(lr)
    self.g:add(-lr, self.gradG)
    self.v:add(-lr, self.gradV)
end

function WeightNorm:parameters()
    if self.bias then
        return {self.v, self.g, self.bias}, {self.gradV, self.gradG, self.gradBias}
    else
        return {self.v, self.g}, {self.gradV, self.gradG}
    end
end
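
--[[
    Training-loop note (sketch): parameters() exposes v, g (and the bias)
    rather than the derived weight, so the usual flattening, e.g.

        local params, gradParams = wnLinear:getParameters()

    optimises the reparametrised quantities directly; ``wnLinear`` is the
    hypothetical instance from the usage sketch above.
--]]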

function WeightNorm:write(file)
    -- Don't save weight and gradWeight since they can easily be re-computed
    -- from v and g.
    local weight = self.modules[1].weight
    local gradWeight = self.modules[1].gradWeight
    self.weight = nil
    self.gradWeight = nil
    self.modules[1].weight = nil
    self.modules[1].gradWeight = nil
    if not self.weightSize then
        self.weightSize = weight:size()
    end

    parent.write(self, file)

    self.modules[1].weight = weight
    self.modules[1].gradWeight = gradWeight
    self.weight = weight
    self.gradWeight = gradWeight
end

function WeightNorm:read(file)
    parent.read(self, file)

    -- Re-compute weight and gradWeight
    if not self.weight then
        self.modules[1].weight = self.v.new(self.weightSize)
        self.modules[1].gradWeight = self.v.new(self.weightSize)
        self.weight = self.modules[1].weight
        self.gradWeight = self.modules[1].gradWeight
        self:updateWeight()
        self.gradWeight:copy(self:permuteOut(self.gradV))
    end
end
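
--[[
    Serialisation sketch (hypothetical file name): weight and gradWeight are
    dropped on write and rebuilt from v and g on read, so a save/load round
    trip preserves the effective parameters.

        torch.save('wn.t7', wnLinear)
        local restored = torch.load('wn.t7')
        restored:forward(torch.randn(8, 10))
--]]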