Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-09 00:57:20 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-09 00:57:20 +0400
commitd34a5b96b7a15d5b3fa329ac9e3589b48b27da3c (patch)
treea9f9c782d1caa8bf4b533f18b7edff1114532d0c /SpatialSparseCriterion.lua
parent44c2713c7027da9bda69c235606ed246057d00d5 (diff)
Added missng file.
Diffstat (limited to 'SpatialSparseCriterion.lua')
-rw-r--r--SpatialSparseCriterion.lua68
1 files changed, 68 insertions, 0 deletions
diff --git a/SpatialSparseCriterion.lua b/SpatialSparseCriterion.lua
new file mode 100644
index 0000000..abb2301
--- /dev/null
+++ b/SpatialSparseCriterion.lua
@@ -0,0 +1,68 @@
+local SpatialSparseCriterion, parent = torch.class('nn.SpatialSparseCriterion', 'nn.SparseCriterion')
+
+function SpatialSparseCriterion:__init(...)
+ parent.__init(self)
+
+ xlua.unpack_class(self, {...},
+ 'nn.SpatialSparseCriterion',
+ 'A spatial extension of the SparseCriterion class.\n'
+ ..' Provides a set of parameters to deal with spatial mini-batch training.',
+ {arg='nbGradients', type='number', help='number of gradients to backpropagate (-1:all, >=1:nb)', default=-1},
+ {arg='sizeAverage', type='number', help='if true, forward() returns an average instead of a sum of errors', default=true}
+ )
+end
+
+function SpatialSparseCriterion:forward(input)
+ self.fullOutput = self.fullOutput or torch.Tensor()
+ self.fullOutput:resize(input:size(2), input:size(3))
+ input.nn.SpatialSparseCriterion_forward(self, input)
+ if self.sizeAverage then
+ self.output = self.fullOutput:mean()
+ else
+ self.output = self.fullOutput:sum()
+ end
+ return self.output
+end
+
+function SpatialSparseCriterion:backward(input,target)
+ -- (1) retrieve adjusted target
+ target = self.target
+ -- (2) resize input gradient map
+ self.gradInput:resizeAs(input):zero()
+ -- (3) compute input gradients, based on the nbGradients param
+ if self.nbGradients == -1 then
+ -- dense gradients
+ input.nn.SpatialSparseCriterion_backward(self, input, self.gradInput)
+ elseif self.nbGradients == 1 then
+ -- only 1 gradient is computed, sampled in the center
+ self.fullGradInput = torch.Tensor() or self.fullGradInput
+ self.fullGradInput:resizeAs(input):zero()
+ input.nn.SpatialSparseCriterion_backward(self, input, self.fullGradInput)
+ local y = math.ceil(self.gradInput:size(2)/2)
+ local x = math.ceil(self.gradInput:size(3)/2)
+ self.gradInput:select(3,x):select(2,y):copy(self.fullGradInput:select(3,x):select(2,y))
+ else
+ -- only N gradients are computed, sampled in random locations
+ self.fullGradInput = torch.Tensor() or self.fullGradInput
+ self.fullGradInput:resizeAs(input):zero()
+ input.nn.SpatialSparseCriterion_backward(self, input, self.fullGradInput)
+ for i = 1,self.nbGradients do
+ local x = math.random(1,self.gradInput:size(1))
+ local y = math.random(1,self.gradInput:size(2))
+ self.gradInput:select(3,x):select(2,y):copy(self.fullGradInput:select(3,x):select(2,y))
+ end
+ end
+ return self.gradInput
+end
+
+function SpatialSparseCriterion:write(file)
+ parent.write(self, file)
+ file:writeDouble(self.resampleTarget)
+ file:writeInt(self.nbGradients)
+end
+
+function SpatialSparseCriterion:read(file)
+ parent.read(self, file)
+ self.resampleTarget= file:readDouble()
+ self.nbGradients = file:readInt()
+end