local MulConstant, parent = torch.class('nn.MulConstant', 'nn.Module') function MulConstant:__init(constant_scalar,ip) parent.__init(self) assert(type(constant_scalar) == 'number', 'input is not scalar!') self.constant_scalar = constant_scalar -- default for inplace is false self.inplace = ip or false if (ip and type(ip) ~= 'boolean') then error('in-place flag must be boolean') end end function MulConstant:updateOutput(input) if self.inplace then input:mul(self.constant_scalar) self.output:set(input) else self.output:resizeAs(input) self.output:copy(input) self.output:mul(self.constant_scalar) end return self.output end function MulConstant:updateGradInput(input, gradOutput) if self.gradInput then if self.inplace then gradOutput:mul(self.constant_scalar) self.gradInput:set(gradOutput) -- restore previous input value input:div(self.constant_scalar) else self.gradInput:resizeAs(gradOutput) self.gradInput:copy(gradOutput) self.gradInput:mul(self.constant_scalar) end return self.gradInput end end