diff options
author | fsuzanomassa <fvsmassa@gmail.com> | 2015-04-22 19:47:52 +0300 |
---|---|---|
committer | fsuzanomassa <fvsmassa@gmail.com> | 2015-04-22 19:47:52 +0300 |
commit | f5a9cd55277c9d101dc50f42703ce37cf6482250 (patch) | |
tree | b0023792d785e3b6830f8459c100b8b89ab7152d /MulConstant.lua | |
parent | 418624f67da0c61dd2a7205373e3ebe816a94aae (diff) |
Adding in-place AddConstant and MulConstant
Diffstat (limited to 'MulConstant.lua')
-rw-r--r-- | MulConstant.lua | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/MulConstant.lua b/MulConstant.lua index 982ab41..eb41d36 100644 --- a/MulConstant.lua +++ b/MulConstant.lua @@ -1,21 +1,39 @@ local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module') -function MulConstant:__init(constant_scalar) +function MulConstant:__init(constant_scalar,ip) parent.__init(self) assert(type(constant_scalar) == 'number', 'input is not scalar!') self.constant_scalar = constant_scalar + + -- default for inplace is false + self.inplace = ip or false + if (ip and type(ip) ~= 'boolean') then + error('in-place flag must be boolean') + end end function MulConstant:updateOutput(input) - self.output:resizeAs(input) - self.output:copy(input) - self.output:mul(self.constant_scalar) + if self.inplace then + input:mul(self.constant_scalar) + self.output = input + else + self.output:resizeAs(input) + self.output:copy(input) + self.output:mul(self.constant_scalar) + end return self.output end function MulConstant:updateGradInput(input, gradOutput) - self.gradInput:resizeAs(gradOutput) - self.gradInput:copy(gradOutput) - self.gradInput:mul(self.constant_scalar) + if self.inplace then + gradOutput:mul(self.constant_scalar) + self.gradInput = gradOutput + -- restore previous input value + input:div(self.constant_scalar) + else + self.gradInput:resizeAs(gradOutput) + self.gradInput:copy(gradOutput) + self.gradInput:mul(self.constant_scalar) + end return self.gradInput end |