diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-04 01:23:26 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-04 01:23:26 +0300 |
commit | d1584c714469a46c70856c041734f781fcb454b9 (patch) | |
tree | 959de46597074dddf4ff4424d74e28ab20c94385 | |
parent | b9212fde4c889d2e811534777ebb124b08ba2535 (diff) | |
parent | d5ab2ca3c2b4d4cba7bdfafd8d86daa63bea71f7 (diff) |
Merge pull request #117 from torch/mulfix
removing the requirement for providing size in nn.Mul
-rw-r--r-- | Mul.lua | 10 | ||||
-rw-r--r-- | doc/simple.md | 4 | ||||
-rw-r--r-- | test.lua | 2 |
3 files changed, 6 insertions, 10 deletions
@@ -1,15 +1,11 @@ local Mul, parent = torch.class('nn.Mul', 'nn.Module') -function Mul:__init(inputSize) +function Mul:__init() parent.__init(self) self.weight = torch.Tensor(1) self.gradWeight = torch.Tensor(1) - -- state - self.gradInput:resize(inputSize) - self.output:resize(inputSize) - self:reset() end @@ -25,13 +21,13 @@ function Mul:reset(stdv) end function Mul:updateOutput(input) - self.output:copy(input); + self.output:resizeAs(input):copy(input); self.output:mul(self.weight[1]); return self.output end function Mul:updateGradInput(input, gradOutput) - self.gradInput:zero() + self.gradInput:resizeAs(input):zero() self.gradInput:add(self.weight[1], gradOutput) return self.gradInput end diff --git a/doc/simple.md b/doc/simple.md index aa1a94d..2cc3ccf 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -262,7 +262,7 @@ to produce the output _y_. <a name="nn.Mul"/> ## Mul ## -`module` = `Mul(inputDimension)` +`module` = `Mul()` Applies a _single_ scaling factor to the incoming data, i.e. _y= w x_, where _w_ is a scalar. @@ -271,7 +271,7 @@ Example: ```lua y=torch.Tensor(5); mlp=nn.Sequential() -mlp:add(nn.Mul(5)) +mlp:add(nn.Mul()) function gradUpdate(mlp, x, y, criterion, learningRate) local pred = mlp:forward(x) @@ -671,7 +671,7 @@ function nntest.Mul() local inj = math.random(3,5) local ink = math.random(3,5) local input = torch.Tensor(ini,inj,ink):zero() - local module = nn.Mul(ini*inj*ink) + local module = nn.Mul() local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') |