diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-08 00:18:34 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-08 00:18:34 +0400 |
commit | eb0140257b71ec5bc366f395c5a45756de884dcc (patch) | |
tree | 34a7ae6535bff1d1dcaed1745925a6b3fd16a9af /Replicate.lua | |
parent | e9dc19de2d81e8b0237e4810f28f28b9a0391d67 (diff) |
Added Replicate module.
Diffstat (limited to 'Replicate.lua')
-rw-r--r-- | Replicate.lua | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/Replicate.lua b/Replicate.lua new file mode 100644 index 0000000..4c3f925 --- /dev/null +++ b/Replicate.lua @@ -0,0 +1,40 @@ + +local Replicate, parent = torch.class('nn.Replicate','nn.Module') + +function Replicate:__init(nf) + parent.__init(self) + self.nfeatures = nf +end + +function Replicate:forward(input) + local sz = torch.LongStorage(input:dim()+1) + sz[1] = self.nfeatures + for i = 1,input:dim() do + sz[i+1] = input:size(i) + end + local st = torch.LongStorage(input:stride()+1) + st[1] = 0 + for i = 1,input:stride() do + sz[i+1] = input:stride(i) + end + self.output:set(input:storage(),input:storageOffset(),sz,st) + return self.output +end + +function Replicate:backward(input, gradOutput) + self.gradInput:resizeAs(input):zero() + for k = 1,gradOutput:size(1) do + self.gradInput:add(gradOutput[k]) + end + return self.gradInput +end + +function Replicate:write(file) + parent.write(self,file) + file:writeInt(self.nfeatures) +end + +function Replicate:read(file) + parent.read(self,file) + self.nfeatures = file:readInt() +end |