Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-08 00:18:34 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-08 00:18:34 +0400
commiteb0140257b71ec5bc366f395c5a45756de884dcc (patch)
tree34a7ae6535bff1d1dcaed1745925a6b3fd16a9af /Replicate.lua
parente9dc19de2d81e8b0237e4810f28f28b9a0391d67 (diff)
Added Replicate module.
Diffstat (limited to 'Replicate.lua')
-rw-r--r--Replicate.lua40
1 files changed, 40 insertions, 0 deletions
diff --git a/Replicate.lua b/Replicate.lua
new file mode 100644
index 0000000..4c3f925
--- /dev/null
+++ b/Replicate.lua
@@ -0,0 +1,40 @@
+
+local Replicate, parent = torch.class('nn.Replicate','nn.Module')
+
+function Replicate:__init(nf)
+ parent.__init(self)
+ self.nfeatures = nf
+end
+
+function Replicate:forward(input)
+ local sz = torch.LongStorage(input:dim()+1)
+ sz[1] = self.nfeatures
+ for i = 1,input:dim() do
+ sz[i+1] = input:size(i)
+ end
+ local st = torch.LongStorage(input:stride()+1)
+ st[1] = 0
+ for i = 1,input:stride() do
+ sz[i+1] = input:stride(i)
+ end
+ self.output:set(input:storage(),input:storageOffset(),sz,st)
+ return self.output
+end
+
+function Replicate:backward(input, gradOutput)
+ self.gradInput:resizeAs(input):zero()
+ for k = 1,gradOutput:size(1) do
+ self.gradInput:add(gradOutput[k])
+ end
+ return self.gradInput
+end
+
+function Replicate:write(file)
+ parent.write(self,file)
+ file:writeInt(self.nfeatures)
+end
+
+function Replicate:read(file)
+ parent.read(self,file)
+ self.nfeatures = file:readInt()
+end