Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenLi Zhuang <iamalbert@users.noreply.github.com>2016-08-04 18:53:35 +0300
committerSoumith Chintala <soumith@gmail.com>2016-08-04 18:53:35 +0300
commit40e42078dd166b349aa19fd28af0b97a57e43ec2 (patch)
treefca44763aa2a15e6ee3786fdbaff2b9d9b63eb28
parent7ff1f86ecffe6816e92049ef219e8922db0c2b67 (diff)
overload operator __unm and __sub to support module chaining (#113)
* add operator support * Update README.md
-rw-r--r--README.md41
-rw-r--r--init.lua19
2 files changed, 60 insertions, 0 deletions
diff --git a/README.md b/README.md
index e11ee12..10c8ad0 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,17 @@ To save the *graph* on file, specify the file name, and both a `dot` and `svg` f
graph.dot(mlp.fg, 'MLP', 'myMLP')
```
+You can also use the `__unm__` and `__sub__` operators to replace all `__call__`:
+```lua
+h1 = - nn.Linear(20,10)
+h2 = h1
+ - nn.Tanh()
+ - nn.Linear(10,10)
+ - nn.Tanh()
+ - nn.Linear(10, 1)
+mlp = nn.gModule({h1}, {h2})
+```
+
### A network with 2 inputs and 2 outputs
@@ -72,6 +83,19 @@ gmod:updateGradInput({x1, x2}, {torch.rand(1), torch.rand(1)})
graph.dot(gmod.fg, 'Big MLP')
```
+Alternatively, you can use `-` to make your code looks like the data flow:
+
+```lua
+h1 = - nn.Linear(20,20)
+h2 = - nn.Linear(10,10)
+hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
+hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
+madd = {hh1,hh2} - nn.CAddTable()
+oA = madd - nn.Sigmoid()
+oB = madd - nn.Tanh()
+gmod = nn.gModule( {h1,h2}, {oA,oB} )
+```
+
<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp2.png" width="300px"/>
@@ -123,6 +147,23 @@ graph.dot(g.fg, 'Forward Graph')
graph.dot(g.bg, 'Backward Graph')
```
+As your graph getting bigger and more complicated, the nested parentheses may become confusing. In this case, using `-` to chain the modules is a clearer and easier way:
+```lua
+input = - nn.Identity()
+L1 = input
+ - nn.Linear(10, 20)
+ - nn.Tanh()
+L2 = { input, L1 }
+ - nn.JoinTable(1)
+ - nn.Linear(30,60)
+ - nn.Tanh()
+L3 = { L1,L2 }
+ - nn.JoinTable(1)
+ - nn.Linear(80,160)
+ - nn.Tanh()
+g = nn.gModule({input},{L3})
+```
+
<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp4_forward.png" width="300px"/>
<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp4_backward.png" width="300px"/>
diff --git a/init.lua b/init.lua
index 455e1b0..0e354f6 100644
--- a/init.lua
+++ b/init.lua
@@ -58,4 +58,23 @@ function Criterion:__call__(...)
return nn.ModuleFromCriterion(self)(...)
end
+
+
+
+Module.__unm__ = function( obj )
+ return obj()
+end
+
+Module.__sub__ = function( prev, next )
+ return next(prev)
+end
+
+
+do
+ local Node = torch.getmetatable('nngraph.Node')
+ Node.__sub__ = function( prev, next )
+ return next(prev)
+ end
+end
+
return nngraph