Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-09 01:27:45 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-09 01:27:45 +0400
commit30e99f80b5067654d4363842c7d5136b314f4e83 (patch)
tree5a1d905775282bfd0dd638fcb83554eaa17b0c20 /SpatialGraph.lua
parent89f2393671e55332855bca9ac005dea3c2c021e2 (diff)
Added SpatialGraph module, in its quite early version.
Diffstat (limited to 'SpatialGraph.lua')
-rw-r--r--SpatialGraph.lua69
1 files changed, 69 insertions, 0 deletions
diff --git a/SpatialGraph.lua b/SpatialGraph.lua
new file mode 100644
index 0000000..253a800
--- /dev/null
+++ b/SpatialGraph.lua
@@ -0,0 +1,69 @@
+local SpatialGraph, parent = torch.class('nn.SpatialGraph', 'nn.Module')
+
+local help_desc =
+[[Creates an edge-weighted graph from a set of N feature
+maps.
+
+The input is a 3D tensor width x height x nInputPlane, the
+output is a 3D tensor width x height x 2. The first slice
+of the output contains horizontal edges, the second vertical
+edges.
+
+The input features are assumed to be >= 0.
+More precisely:
++ dist == 'euclid' and norm == true: the input features should
+ also be <= 1, to produce properly normalized distances (btwn 0 and 1);
++ dist == 'cosine': the input features do not need to be bounded,
+ as the cosine dissimilarity normalizes with respect to each vector.
+ An epsilon is automatically added, so that components that are == 0
+ are properly considered as being similar.
+]]
+
+function SpatialGraph:__init(...)
+ parent.__init(self)
+
+ xlua.unpack_class(
+ self, {...},
+ 'nn.SpatialGraph', help_desc,
+ {arg='dist', type='string', help='distance metric to use', default='euclid'},
+ {arg='normalize', type='boolean', help='normalize euclidean distances btwn 0 and 1 (assumes input range to be btwn 0 and 1)', default=true},
+ {arg='connex', type='number', help='connexity', default=4}
+ )
+
+ if self.connex ~= 4 then
+ xlua.error('4 is the only connexity supported, for now', 'nn.SpatialGraph',self.usage)
+ end
+ self.dist = ((self.dist == 'euclid') and 0) or ((self.dist == 'cosine') and 1)
+ or xerror('euclid is the only distance supported, for now','nn.SpatialGraph',self.usage)
+ self.normalize = (self.normalize and 1) or 0
+ if self.dist == 'cosine' and self.normalize == 1 then
+ xerror('normalized cosine is not supported for now [just because I couldnt figure out the gradient :-)]',
+ 'nn.SpatialGraph', self.usage)
+ end
+end
+
+function SpatialGraph:forward(input)
+ self.output:resize(self.connex / 2, input:size(2), input:size(3))
+ input.nn.SpatialGraph_forward(self, input)
+ return self.output
+end
+
+function SpatialGraph:backward(input, gradOutput)
+ self.gradInput:resizeAs(input)
+ input.nn.SpatialGraph_backward(self, input, gradOutput)
+ return self.gradInput
+end
+
+function SpatialGraph:write(file)
+ parent.write(self, file)
+ file:writeInt(self.connex)
+ file:writeInt(self.dist)
+ file:writeInt(self.normalize)
+end
+
+function SpatialGraph:read(file)
+ parent.read(self, file)
+ self.connex = file:readInt()
+ self.dist = file:readInt()
+ self.normalize = file:readInt()
+end