Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-12-04 01:23:26 +0300
committerSoumith Chintala <soumith@gmail.com>2014-12-04 01:23:26 +0300
commitd1584c714469a46c70856c041734f781fcb454b9 (patch)
tree959de46597074dddf4ff4424d74e28ab20c94385
parentb9212fde4c889d2e811534777ebb124b08ba2535 (diff)
parentd5ab2ca3c2b4d4cba7bdfafd8d86daa63bea71f7 (diff)
Merge pull request #117 from torch/mulfix
removing the requirement for providing size in nn.Mul
-rw-r--r--Mul.lua10
-rw-r--r--doc/simple.md4
-rw-r--r--test.lua2
3 files changed, 6 insertions, 10 deletions
diff --git a/Mul.lua b/Mul.lua
index 7841470..289d83a 100644
--- a/Mul.lua
+++ b/Mul.lua
@@ -1,15 +1,11 @@
local Mul, parent = torch.class('nn.Mul', 'nn.Module')
-function Mul:__init(inputSize)
+function Mul:__init()
parent.__init(self)
self.weight = torch.Tensor(1)
self.gradWeight = torch.Tensor(1)
- -- state
- self.gradInput:resize(inputSize)
- self.output:resize(inputSize)
-
self:reset()
end
@@ -25,13 +21,13 @@ function Mul:reset(stdv)
end
function Mul:updateOutput(input)
- self.output:copy(input);
+ self.output:resizeAs(input):copy(input);
self.output:mul(self.weight[1]);
return self.output
end
function Mul:updateGradInput(input, gradOutput)
- self.gradInput:zero()
+ self.gradInput:resizeAs(input):zero()
self.gradInput:add(self.weight[1], gradOutput)
return self.gradInput
end
diff --git a/doc/simple.md b/doc/simple.md
index aa1a94d..2cc3ccf 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -262,7 +262,7 @@ to produce the output _y_.
<a name="nn.Mul"/>
## Mul ##
-`module` = `Mul(inputDimension)`
+`module` = `Mul()`
Applies a _single_ scaling factor to the incoming data, i.e.
_y= w x_, where _w_ is a scalar.
@@ -271,7 +271,7 @@ Example:
```lua
y=torch.Tensor(5);
mlp=nn.Sequential()
-mlp:add(nn.Mul(5))
+mlp:add(nn.Mul())
function gradUpdate(mlp, x, y, criterion, learningRate)
local pred = mlp:forward(x)
diff --git a/test.lua b/test.lua
index 4615b23..89ff7d5 100644
--- a/test.lua
+++ b/test.lua
@@ -671,7 +671,7 @@ function nntest.Mul()
local inj = math.random(3,5)
local ink = math.random(3,5)
local input = torch.Tensor(ini,inj,ink):zero()
- local module = nn.Mul(ini*inj*ink)
+ local module = nn.Mul()
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')