diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-24 19:12:45 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-24 19:12:45 +0400 |
commit | 1d45c274eef362398b68761a6bbd30cc55742812 (patch) | |
tree | f71473508d923e5ad8ff7e574c9c2768e6df9b38 | |
parent | cef08770f419484fdd8dd76f19945d8a9250a97d (diff) |
Added missing module.
-rw-r--r-- | Type.lua | 34 |
1 file changed, 34 insertions, 0 deletions
-- nn.Type: a Sequential container that runs its children in a chosen tensor
-- type while presenting the process-default tensor type at its boundaries.
-- Inputs/gradOutputs are converted to the internal type on the way in;
-- outputs/gradInputs are converted back to the default type on the way out.
local Type, parent = torch.class('nn.Type', 'nn.Sequential')

--- Constructor.
-- @param type internal tensor type, either a short name ('Float') or a
--             fully-qualified one ('torch.FloatTensor').
-- NOTE(review): storing a string in self.type shadows the inherited
-- Module:type() method on this instance — confirm no caller invokes
-- container:type(...) on an nn.Type.
function Type:__init(type)
   parent.__init(self)
   -- Expand a short name such as 'Float' into 'torch.FloatTensor'.
   if not type:find('torch%..+Tensor') then
      type = 'torch.' .. type .. 'Tensor'
   end
   local default = torch.getdefaulttensortype()
   self.type = type
   self.defaulttype = default
   -- Boundary converters: default -> internal entering the container,
   -- internal -> default leaving it.
   self.convert_input      = nn.Copy(default, type)
   self.convert_gradOutput = nn.Copy(default, type)
   self.convert_output     = nn.Copy(type, default)
   self.convert_gradInput  = nn.Copy(type, default)
end

--- Append a child module, casting it to the internal type first.
-- Returns self so calls can be chained.
function Type:add(module)
   parent.add(self, module:type(self.type))
   return self
end

--- Forward pass: convert input to the internal type, run the children,
-- then convert the result back to the default type.
function Type:forward(input)
   local converted = self.convert_input:forward(input)
   self.output = self.convert_output:forward(parent.forward(self, converted))
   return self.output
end

--- Backward pass: convert both input and gradOutput to the internal type,
-- backpropagate through the children, and convert the gradient back to the
-- default type.
function Type:backward(input, gradOutput)
   local converted_input = self.convert_input:forward(input)
   local converted_grad  = self.convert_gradOutput:forward(gradOutput)
   self.gradInput = self.convert_gradInput:forward(
      parent.backward(self, converted_input, converted_grad))
   return self.gradInput
end