diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-06-26 18:14:13 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-06-26 18:14:13 +0400 |
commit | 1310a045ebc69a9f9e8c57d07af587a6535d5ae9 (patch) | |
tree | 46e52e5ae16e0653953424423aa81c81dd6b526f | |
parent | 896ad1c1bf5588b2944c79fb24a0aee1ae7db726 (diff) | |
parent | 9386e79b7eaf324c34ec1c16fbc873add39dff22 (diff) |
Merge pull request #20 from sergomezcol/view_module
Add nn.View module
-rw-r--r-- | View.lua | 36 | ||||
-rw-r--r-- | doc/simple.md | 69 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 19 |
4 files changed, 125 insertions, 0 deletions
diff --git a/View.lua b/View.lua
new file mode 100644
index 0000000..75eda26
--- /dev/null
+++ b/View.lua
@@ -0,0 +1,36 @@
+local View, parent = torch.class('nn.View', 'nn.Module')
+
+-- Constructor: accepts either a single torch.LongStorage, or one-or-more
+-- numbers which are packed into a LongStorage. numElements caches the
+-- product of the target sizes so minibatch inputs can be detected cheaply.
+function View:__init(...)
+   parent.__init(self)
+   self.size = ...
+   if select('#', ...) > 1 or type(self.size) == "number" then
+      self.size = torch.LongStorage({...})
+   end
+   assert(torch.typename(self.size)=="torch.LongStorage", "expecting a LongStorage")
+   self.numElements = 1
+   for i = 1,#self.size do
+      self.numElements = self.numElements * self.size[i]
+   end
+
+   -- output/gradInput are views of the input, so no buffers are pre-allocated
+   self.output = nil
+   self.gradInput = nil
+end
+
+-- An input is treated as a minibatch when its trailing dimensions hold
+-- exactly numElements per sample (the first dimension is the batch index).
+local function isMinibatch(input, numElements)
+   return input:dim() > 1 and
+          input:nElement()/input:size(1) == numElements
+end
+
+function View:updateOutput(input)
+   if isMinibatch(input, self.numElements) then
+      -- keep the batch dimension, reshape each sample
+      self.output = input:view(input:size(1), unpack(self.size:totable()))
+   else
+      self.output = input:view(self.size)
+   end
+   return self.output
+end
+
+function View:updateGradInput(input, gradOutput)
+   -- gradient is just the gradOutput viewed back in the input's shape
+   self.gradInput = gradOutput:view(input:size())
+   return self.gradInput
+end
diff --git a/doc/simple.md b/doc/simple.md
index 0fdffef..6f5b13d 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -468,8 +468,77 @@ Example:
 12 16
[torch.Tensor of dimension 16]
+```
+
+<a name="nn.View"/>
+## View ##
+
+`module` = `View(sizes)`
+
+This module creates a new view of the input tensor using the `sizes` passed to
+the constructor. The parameter `sizes` can either be a `LongStorage` or numbers.
+
+Example:
+```lua
+> x=torch.Tensor(4,4)
+> for i=1,4 do
+> for j=1,4 do
+> x[i][j]=(i-1)*4+j;
+> end
+> end
+> print(x)
+
+x=torch.Tensor(4,4)
+for i=1,4 do
+  for j=1,4 do
+    x[i][j]=(i-1)*4+j;
+  end
+end
+print(x)
+
+  1   2   3   4
+  5   6   7   8
+  9  10  11  12
+ 13  14  15  16
+[torch.Tensor of dimension 4x4]
+
+> print(nn.View(2,8):forward(x))
+
+  1   2   3   4   5   6   7   8
+  9  10  11  12  13  14  15  16
+[torch.DoubleTensor of dimension 2x8]
+
+> print(nn.View(torch.LongStorage{8,2}):forward(x))
+
+  1   2
+  3   4
+  5   6
+  7   8
+  9  10
+ 11  12
+ 13  14
+ 15  16
+[torch.DoubleTensor of dimension 8x2]
+
+> print(nn.View(16):forward(x))
+
+  1
+  2
+  3
+  4
+  5
+  6
+  7
+  8
+  9
+ 10
+ 11
+ 12
+ 13
+ 14
+ 15
+ 16
+[torch.DoubleTensor of dimension 16]
+```
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -10,6 +10,7 @@ include('Sequential.lua')
 include('Linear.lua')
 include('SparseLinear.lua')
 include('Reshape.lua')
+include('View.lua')
 include('Select.lua')
 include('Narrow.lua')
 include('Replicate.lua')
diff --git a/test/test.lua b/test/test.lua
index 3c4ec12..5db941a 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1867,6 +1867,25 @@ function nntest.SplitTable()
 end
+
+function nntest.View()
+   local input = torch.rand(10)
+   local template = torch.rand(5,2)
+   local target = template:size():totable()
+   local module = nn.View(template:size())
+   mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (1)")
+   local module = nn.View(unpack(target))
+   mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (2)")
+
+   -- Minibatch: size(1) and nElement() return numbers, so use asserteq
+   -- (assertTableEq is for table comparison only)
+   local minibatch = torch.rand(5,10)
+   mytester:asserteq(module:forward(minibatch):size(1),
+                     minibatch:size(1),
+                     "Error in minibatch dimension")
+   mytester:asserteq(module:forward(minibatch):nElement(),
+                     minibatch:nElement(),
+                     "Error in minibatch nElement")
+end
+
 mytester:add(nntest)

 if not nn then