author     Luca Antiga <luca.antiga@orobix.com>    2016-10-16 20:02:52 +0300
committer  Luca Antiga <luca.antiga@orobix.com>    2016-10-16 20:02:52 +0300
commit     e0925d78689211d7e9215f751f3ff945327f155c (patch)
tree       c5cfa07c1637dc83e750b6c4a43bde8d6cb632df /test
parent     befcf335c7b4b97d0521b13db26b91dac1878aa2 (diff)
Keep the Lua implementation for the Linear module
Diffstat (limited to 'test')
-rw-r--r--   test/LinearTHNN.lua   94
1 file changed, 94 insertions, 0 deletions
diff --git a/test/LinearTHNN.lua b/test/LinearTHNN.lua
new file mode 100644
index 0000000..4ac944e
--- /dev/null
+++ b/test/LinearTHNN.lua
@@ -0,0 +1,94 @@
+local Linear, parent = torch.class('nn.Linear', 'nn.Module')
+
+function Linear:__init(inputSize, outputSize, bias)
+   parent.__init(self)
+   local bias = ((bias == nil) and true) or bias
+   self.weight = torch.Tensor(outputSize, inputSize)
+   self.gradWeight = torch.Tensor(outputSize, inputSize)
+   if bias then
+      self.bias = torch.Tensor(outputSize)
+      self.gradBias = torch.Tensor(outputSize)
+   end
+   self.addBuffer = torch.Tensor(outputSize)
+   self:reset()
+end
+
+function Linear:noBias()
+   self.bias = nil
+   self.gradBias = nil
+   return self
+end
+
+function Linear:reset(stdv)
+   if stdv then
+      stdv = stdv * math.sqrt(3)
+   else
+      stdv = 1./math.sqrt(self.weight:size(2))
+   end
+   if nn.oldSeed then
+      for i=1,self.weight:size(1) do
+         self.weight:select(1, i):apply(function()
+            return torch.uniform(-stdv, stdv)
+         end)
+      end
+      if self.bias then
+         for i=1,self.bias:nElement() do
+            self.bias[i] = torch.uniform(-stdv, stdv)
+         end
+      end
+   else
+      self.weight:uniform(-stdv, stdv)
+      if self.bias then self.bias:uniform(-stdv, stdv) end
+   end
+   return self
+end
+
+function Linear:updateOutput(input)
+   input.THNN.Linear_updateOutput(
+      input:cdata(),
+      self.output:cdata(),
+      self.weight:cdata(),
+      self.bias and self.bias:cdata(),
+      self.addBuffer:cdata()
+   )
+   return self.output
+end
+
+function Linear:updateGradInput(input, gradOutput)
+   input.THNN.Linear_updateGradInput(
+      input:cdata(),
+      gradOutput:cdata(),
+      self.gradInput:cdata(),
+      self.weight:cdata()
+   )
+   return self.gradInput
+end
+
+function Linear:accGradParameters(input, gradOutput, scale)
+   input.THNN.Linear_accGradParameters(
+      input:cdata(),
+      gradOutput:cdata(),
+      self.gradInput:cdata(),
+      self.weight:cdata(),
+      self.bias and self.bias:cdata(),
+      self.gradWeight:cdata(),
+      self.bias and self.gradBias:cdata(),
+      self.addBuffer:cdata(),
+      scale or 1
+   )
+   return self.gradWeight
+end
+
+-- we do not need to accumulate parameters when sharing
+Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters
+
+function Linear:clearState()
+   if self.addBuffer then self.addBuffer:set() end
+   return parent.clearState(self)
+end
+
+function Linear:__tostring__()
+   return torch.type(self) ..
+      string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) ..
+      (self.bias == nil and ' without bias' or '')
+end
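
For reference, a minimal smoke test of the module added above; this is not part of the commit, and it assumes a working torch/nn install with this implementation reachable as nn.Linear (the name registered by the torch.class call in the diff).

require 'nn'

-- Hypothetical smoke test (an assumption, not part of this commit):
-- exercises the forward/backward path through the THNN bindings above.
local layer = nn.Linear(10, 5)       -- 10 inputs -> 5 outputs, bias on by default
local input = torch.randn(4, 10)     -- mini-batch of 4 samples

local output = layer:forward(input)  -- dispatches to Linear:updateOutput
local gradOutput = torch.randn(4, 5)
local gradInput = layer:backward(input, gradOutput)  -- updateGradInput + accGradParameters

print(layer)              -- __tostring__ prints "nn.Linear(10 -> 5)"
print(output:size())      -- 4x5
print(gradInput:size())   -- 4x10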