diff options
author | Luca Antiga <luca.antiga@orobix.com> | 2016-10-17 01:16:00 +0300 |
---|---|---|
committer | Luca Antiga <luca.antiga@orobix.com> | 2016-10-17 01:16:00 +0300 |
commit | f7b0accf5dafd14e95a346b51a8149252af7d73a (patch) | |
tree | bcbaf7868e964031e604dab5c531b1d5dfb79e5e /test | |
parent | e0925d78689211d7e9215f751f3ff945327f155c (diff) |
Add fn for using THNN-only modules during tests
Diffstat (limited to 'test')
-rw-r--r-- | test/LinearTHNN.lua | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/test/LinearTHNN.lua b/test/LinearTHNN.lua index 4ac944e..dc690dc 100644 --- a/test/LinearTHNN.lua +++ b/test/LinearTHNN.lua @@ -1,6 +1,6 @@ -local Linear, parent = torch.class('nn.Linear', 'nn.Module') +local LinearTHNN, parent = torch.class('nn.LinearTHNN', 'nn.Module') -function Linear:__init(inputSize, outputSize, bias) +function LinearTHNN:__init(inputSize, outputSize, bias) parent.__init(self) local bias = ((bias == nil) and true) or bias self.weight = torch.Tensor(outputSize, inputSize) @@ -13,13 +13,13 @@ function Linear:__init(inputSize, outputSize, bias) self:reset() end -function Linear:noBias() +function LinearTHNN:noBias() self.bias = nil self.gradBias = nil return self end -function Linear:reset(stdv) +function LinearTHNN:reset(stdv) if stdv then stdv = stdv * math.sqrt(3) else @@ -43,7 +43,7 @@ function Linear:reset(stdv) return self end -function Linear:updateOutput(input) +function LinearTHNN:updateOutput(input) input.THNN.Linear_updateOutput( input:cdata(), self.output:cdata(), @@ -54,7 +54,7 @@ function Linear:updateOutput(input) return self.output end -function Linear:updateGradInput(input, gradOutput) +function LinearTHNN:updateGradInput(input, gradOutput) input.THNN.Linear_updateGradInput( input:cdata(), gradOutput:cdata(), @@ -64,7 +64,7 @@ function Linear:updateGradInput(input, gradOutput) return self.gradInput end -function Linear:accGradParameters(input, gradOutput, scale) +function LinearTHNN:accGradParameters(input, gradOutput, scale) input.THNN.Linear_accGradParameters( input:cdata(), gradOutput:cdata(), @@ -80,14 +80,14 @@ function Linear:accGradParameters(input, gradOutput, scale) end -- we do not need to accumulate parameters when sharing -Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters +LinearTHNN.sharedAccUpdateGradParameters = LinearTHNN.accUpdateGradParameters -function Linear:clearState() +function LinearTHNN:clearState() if self.addBuffer then self.addBuffer:set() end return parent.clearState(self) end -function Linear:__tostring__() +function LinearTHNN:__tostring__() return torch.type(self) .. string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) .. (self.bias == nil and ' without bias' or '') |