Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfsuzanomassa <fvsmassa@gmail.com>2015-04-22 19:47:52 +0300
committerfsuzanomassa <fvsmassa@gmail.com>2015-04-22 19:47:52 +0300
commitf5a9cd55277c9d101dc50f42703ce37cf6482250 (patch)
treeb0023792d785e3b6830f8459c100b8b89ab7152d /MulConstant.lua
parent418624f67da0c61dd2a7205373e3ebe816a94aae (diff)
Adding in-place AddConstant and MulConstant
Diffstat (limited to 'MulConstant.lua')
-rw-r--r--MulConstant.lua32
1 files changed, 25 insertions, 7 deletions
diff --git a/MulConstant.lua b/MulConstant.lua
index 982ab41..eb41d36 100644
--- a/MulConstant.lua
+++ b/MulConstant.lua
@@ -1,21 +1,39 @@
local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module')
-function MulConstant:__init(constant_scalar)
+function MulConstant:__init(constant_scalar,ip)
parent.__init(self)
assert(type(constant_scalar) == 'number', 'input is not scalar!')
self.constant_scalar = constant_scalar
+
+ -- default for inplace is false
+ self.inplace = ip or false
+ if (ip and type(ip) ~= 'boolean') then
+ error('in-place flag must be boolean')
+ end
end
function MulConstant:updateOutput(input)
- self.output:resizeAs(input)
- self.output:copy(input)
- self.output:mul(self.constant_scalar)
+ if self.inplace then
+ input:mul(self.constant_scalar)
+ self.output = input
+ else
+ self.output:resizeAs(input)
+ self.output:copy(input)
+ self.output:mul(self.constant_scalar)
+ end
return self.output
end
function MulConstant:updateGradInput(input, gradOutput)
- self.gradInput:resizeAs(gradOutput)
- self.gradInput:copy(gradOutput)
- self.gradInput:mul(self.constant_scalar)
+ if self.inplace then
+ gradOutput:mul(self.constant_scalar)
+ self.gradInput = gradOutput
+ -- restore previous input value
+ input:div(self.constant_scalar)
+ else
+ self.gradInput:resizeAs(gradOutput)
+ self.gradInput:copy(gradOutput)
+ self.gradInput:mul(self.constant_scalar)
+ end
return self.gradInput
end