From eb0140257b71ec5bc366f395c5a45756de884dcc Mon Sep 17 00:00:00 2001 From: Clement Farabet Date: Thu, 7 Jul 2011 16:18:34 -0400 Subject: Added Replicate module. --- Replicate.lua | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 Replicate.lua (limited to 'Replicate.lua') diff --git a/Replicate.lua b/Replicate.lua new file mode 100644 index 0000000..4c3f925 --- /dev/null +++ b/Replicate.lua @@ -0,0 +1,40 @@ + +local Replicate, parent = torch.class('nn.Replicate','nn.Module') + +function Replicate:__init(nf) + parent.__init(self) + self.nfeatures = nf +end + +function Replicate:forward(input) + local sz = torch.LongStorage(input:dim()+1) + sz[1] = self.nfeatures + for i = 1,input:dim() do + sz[i+1] = input:size(i) + end + local st = torch.LongStorage(input:stride()+1) + st[1] = 0 + for i = 1,input:stride() do + sz[i+1] = input:stride(i) + end + self.output:set(input:storage(),input:storageOffset(),sz,st) + return self.output +end + +function Replicate:backward(input, gradOutput) + self.gradInput:resizeAs(input):zero() + for k = 1,gradOutput:size(1) do + self.gradInput:add(gradOutput[k]) + end + return self.gradInput +end + +function Replicate:write(file) + parent.write(self,file) + file:writeInt(self.nfeatures) +end + +function Replicate:read(file) + parent.read(self,file) + self.nfeatures = file:readInt() +end -- cgit v1.2.3