Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-09-24 19:12:45 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-09-24 19:12:45 +0400
commit1d45c274eef362398b68761a6bbd30cc55742812 (patch)
treef71473508d923e5ad8ff7e574c9c2768e6df9b38
parentcef08770f419484fdd8dd76f19945d8a9250a97d (diff)
Added missing module.
-rw-r--r--Type.lua34
1 files changed, 34 insertions, 0 deletions
diff --git a/Type.lua b/Type.lua
new file mode 100644
index 0000000..f265263
--- /dev/null
+++ b/Type.lua
@@ -0,0 +1,34 @@
+local Type, parent = torch.class('nn.Type', 'nn.Sequential')
+
+function Type:__init(type)
+ parent.__init(self)
+ if not type:find('torch%..+Tensor') then
+ type = 'torch.' .. type .. 'Tensor'
+ end
+ self.type = type
+ self.defaulttype = torch.getdefaulttensortype()
+ self.convert_input = nn.Copy(self.defaulttype, self.type)
+ self.convert_gradOutput = nn.Copy(self.defaulttype, self.type)
+ self.convert_output = nn.Copy(self.type, self.defaulttype)
+ self.convert_gradInput = nn.Copy(self.type, self.defaulttype)
+end
+
+function Type:add(module)
+ parent.add(self, module:type(self.type))
+ return self
+end
+
+function Type:forward(input)
+ input = self.convert_input:forward(input)
+ local output = parent.forward(self, input)
+ self.output = self.convert_output:forward(output)
+ return self.output
+end
+
+function Type:backward(input, gradOutput)
+ input = self.convert_input:forward(input)
+ gradOutput = self.convert_gradOutput:forward(gradOutput)
+ local gradInput = parent.backward(self, input, gradOutput)
+ self.gradInput = self.convert_gradInput:forward(gradInput)
+ return self.gradInput
+end