diff options
author | WenLi Zhuang <iamalbert@users.noreply.github.com> | 2016-08-04 18:53:35 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-08-04 18:53:35 +0300 |
commit | 40e42078dd166b349aa19fd28af0b97a57e43ec2 (patch) | |
tree | fca44763aa2a15e6ee3786fdbaff2b9d9b63eb28 | |
parent | 7ff1f86ecffe6816e92049ef219e8922db0c2b67 (diff) |
overload operator __unm and __sub to support module chaining (#113)
* add operator support
* Update README.md
-rw-r--r-- | README.md | 41 | ||||
-rw-r--r-- | init.lua | 19 |
2 files changed, 60 insertions, 0 deletions
@@ -51,6 +51,17 @@ To save the *graph* on file, specify the file name, and both a `dot` and `svg` f graph.dot(mlp.fg, 'MLP', 'myMLP') ``` +You can also use the `__unm__` and `__sub__` operators to replace all `__call__`: +```lua +h1 = - nn.Linear(20,10) +h2 = h1 + - nn.Tanh() + - nn.Linear(10,10) + - nn.Tanh() + - nn.Linear(10, 1) +mlp = nn.gModule({h1}, {h2}) +``` + ### A network with 2 inputs and 2 outputs @@ -72,6 +83,19 @@ gmod:updateGradInput({x1, x2}, {torch.rand(1), torch.rand(1)}) graph.dot(gmod.fg, 'Big MLP') ``` +Alternatively, you can use `-` to make your code looks like the data flow: + +```lua +h1 = - nn.Linear(20,20) +h2 = - nn.Linear(10,10) +hh1 = h1 - nn.Tanh() - nn.Linear(20,1) +hh2 = h2 - nn.Tanh() - nn.Linear(10,1) +madd = {hh1,hh2} - nn.CAddTable() +oA = madd - nn.Sigmoid() +oB = madd - nn.Tanh() +gmod = nn.gModule( {h1,h2}, {oA,oB} ) +``` + <img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp2.png" width="300px"/> @@ -123,6 +147,23 @@ graph.dot(g.fg, 'Forward Graph') graph.dot(g.bg, 'Backward Graph') ``` +As your graph getting bigger and more complicated, the nested parentheses may become confusing. In this case, using `-` to chain the modules is a clearer and easier way: +```lua +input = - nn.Identity() +L1 = input + - nn.Linear(10, 20) + - nn.Tanh() +L2 = { input, L1 } + - nn.JoinTable(1) + - nn.Linear(30,60) + - nn.Tanh() +L3 = { L1,L2 } + - nn.JoinTable(1) + - nn.Linear(80,160) + - nn.Tanh() +g = nn.gModule({input},{L3}) +``` + <img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp4_forward.png" width="300px"/> <img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp4_backward.png" width="300px"/> @@ -58,4 +58,23 @@ function Criterion:__call__(...) return nn.ModuleFromCriterion(self)(...) end + + + +Module.__unm__ = function( obj ) + return obj() +end + +Module.__sub__ = function( prev, next ) + return next(prev) +end + + +do + local Node = torch.getmetatable('nngraph.Node') + Node.__sub__ = function( prev, next ) + return next(prev) + end +end + return nngraph |