Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-04-26 19:34:24 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-04-26 21:48:35 +0300
commit7715d0b86f386b478f75b0bd5d4aaa604a9a0681 (patch)
treeb6e984f3977d9b580780171cb3836826f88e8c87 /SpatialClassNLLCriterion.lua
parentc8806b80ee211ce70c612addecebf236abdf8734 (diff)
Add SpatialClassNLLCriterion
Diffstat (limited to 'SpatialClassNLLCriterion.lua')
-rw-r--r--SpatialClassNLLCriterion.lua74
1 files changed, 74 insertions, 0 deletions
diff --git a/SpatialClassNLLCriterion.lua b/SpatialClassNLLCriterion.lua
new file mode 100644
index 0000000..8652e88
--- /dev/null
+++ b/SpatialClassNLLCriterion.lua
@@ -0,0 +1,74 @@
+local THNN = require 'nn.THNN'
+local SpatialClassNLLCriterion, parent = torch.class('nn.SpatialClassNLLCriterion', 'nn.Criterion')
+
+function SpatialClassNLLCriterion:__init(weights, sizeAverage)
+ parent.__init(self)
+ if sizeAverage ~= nil then
+ self.sizeAverage = sizeAverage
+ else
+ self.sizeAverage = true
+ end
+ if weights then
+ assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+ self.weights = weights
+ end
+
+ self.output_tensor = torch.zeros(1)
+ self.total_weight_tensor = torch.ones(1)
+ self.target = torch.zeros(1):long()
+end
+
+function SpatialClassNLLCriterion:__len()
+ if (self.weights) then
+ return #self.weights
+ else
+ return 0
+ end
+end
+
+function SpatialClassNLLCriterion:updateOutput(input, target)
+ if type(target) == 'number' then
+ if input:type() ~= 'torch.CudaTensor' then
+ self.target = self.target:long()
+ end
+ self.target[1] = target
+ elseif target:type() == 'torch.CudaTensor' then
+ self.target = target
+ else
+ self.target = target:long()
+ end
+
+ input.THNN.SpatialClassNLLCriterion_updateOutput(
+ input:cdata(),
+ self.target:cdata(),
+ self.output_tensor:cdata(),
+ self.sizeAverage,
+ THNN.optionalTensor(self.weights),
+ self.total_weight_tensor:cdata()
+ )
+ self.output = self.output_tensor[1]
+ return self.output, self.total_weight_tensor[1]
+end
+
+function SpatialClassNLLCriterion:updateGradInput(input, target)
+ if type(target) == 'number' then
+ self.target[1] = target
+ elseif target:type() == 'torch.CudaTensor' then
+ self.target = target
+ else
+ self.target = target:long()
+ end
+
+ self.gradInput:resizeAs(input):zero()
+
+ input.THNN.SpatialClassNLLCriterion_updateGradInput(
+ input:cdata(),
+ self.target:cdata(),
+ self.gradInput:cdata(),
+ self.sizeAverage,
+ THNN.optionalTensor(self.weights),
+ self.total_weight_tensor:cdata()
+ )
+
+ return self.gradInput
+end