diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-09 01:27:45 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-09 01:27:45 +0400 |
commit | 30e99f80b5067654d4363842c7d5136b314f4e83 (patch) | |
tree | 5a1d905775282bfd0dd638fcb83554eaa17b0c20 /SpatialGraph.lua | |
parent | 89f2393671e55332855bca9ac005dea3c2c021e2 (diff) |
Added SpatialGraph module, in its quite early version.
Diffstat (limited to 'SpatialGraph.lua')
-rw-r--r-- | SpatialGraph.lua | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/SpatialGraph.lua b/SpatialGraph.lua new file mode 100644 index 0000000..253a800 --- /dev/null +++ b/SpatialGraph.lua @@ -0,0 +1,69 @@ +local SpatialGraph, parent = torch.class('nn.SpatialGraph', 'nn.Module') + +local help_desc = +[[Creates an edge-weighted graph from a set of N feature +maps. + +The input is a 3D tensor width x height x nInputPlane, the +output is a 3D tensor width x height x 2. The first slice +of the output contains horizontal edges, the second vertical +edges. + +The input features are assumed to be >= 0. +More precisely: ++ dist == 'euclid' and norm == true: the input features should + also be <= 1, to produce properly normalized distances (btwn 0 and 1); ++ dist == 'cosine': the input features do not need to be bounded, + as the cosine dissimilarity normalizes with respect to each vector. + An epsilon is automatically added, so that components that are == 0 + are properly considered as being similar. +]] + +function SpatialGraph:__init(...) + parent.__init(self) + + xlua.unpack_class( + self, {...}, + 'nn.SpatialGraph', help_desc, + {arg='dist', type='string', help='distance metric to use', default='euclid'}, + {arg='normalize', type='boolean', help='normalize euclidean distances btwn 0 and 1 (assumes input range to be btwn 0 and 1)', default=true}, + {arg='connex', type='number', help='connexity', default=4} + ) + + if self.connex ~= 4 then + xlua.error('4 is the only connexity supported, for now', 'nn.SpatialGraph',self.usage) + end + self.dist = ((self.dist == 'euclid') and 0) or ((self.dist == 'cosine') and 1) + or xerror('euclid is the only distance supported, for now','nn.SpatialGraph',self.usage) + self.normalize = (self.normalize and 1) or 0 + if self.dist == 'cosine' and self.normalize == 1 then + xerror('normalized cosine is not supported for now [just because I couldnt figure out the gradient :-)]', + 'nn.SpatialGraph', self.usage) + end +end + +function SpatialGraph:forward(input) + self.output:resize(self.connex / 2, input:size(2), input:size(3)) + input.nn.SpatialGraph_forward(self, input) + return self.output +end + +function SpatialGraph:backward(input, gradOutput) + self.gradInput:resizeAs(input) + input.nn.SpatialGraph_backward(self, input, gradOutput) + return self.gradInput +end + +function SpatialGraph:write(file) + parent.write(self, file) + file:writeInt(self.connex) + file:writeInt(self.dist) + file:writeInt(self.normalize) +end + +function SpatialGraph:read(file) + parent.read(self, file) + self.connex = file:readInt() + self.dist = file:readInt() + self.normalize = file:readInt() +end |