diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-04-26 19:34:24 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-04-26 21:48:35 +0300 |
commit | 7715d0b86f386b478f75b0bd5d4aaa604a9a0681 (patch) | |
tree | b6e984f3977d9b580780171cb3836826f88e8c87 /SpatialClassNLLCriterion.lua | |
parent | c8806b80ee211ce70c612addecebf236abdf8734 (diff) |
Add SpatialClassNLLCriterion
Diffstat (limited to 'SpatialClassNLLCriterion.lua')
-rw-r--r-- | SpatialClassNLLCriterion.lua | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/SpatialClassNLLCriterion.lua b/SpatialClassNLLCriterion.lua new file mode 100644 index 0000000..8652e88 --- /dev/null +++ b/SpatialClassNLLCriterion.lua @@ -0,0 +1,74 @@ +local THNN = require 'nn.THNN' +local SpatialClassNLLCriterion, parent = torch.class('nn.SpatialClassNLLCriterion', 'nn.Criterion') + +function SpatialClassNLLCriterion:__init(weights, sizeAverage) + parent.__init(self) + if sizeAverage ~= nil then + self.sizeAverage = sizeAverage + else + self.sizeAverage = true + end + if weights then + assert(weights:dim() == 1, "weights input should be 1-D Tensor") + self.weights = weights + end + + self.output_tensor = torch.zeros(1) + self.total_weight_tensor = torch.ones(1) + self.target = torch.zeros(1):long() +end + +function SpatialClassNLLCriterion:__len() + if (self.weights) then + return #self.weights + else + return 0 + end +end + +function SpatialClassNLLCriterion:updateOutput(input, target) + if type(target) == 'number' then + if input:type() ~= 'torch.CudaTensor' then + self.target = self.target:long() + end + self.target[1] = target + elseif target:type() == 'torch.CudaTensor' then + self.target = target + else + self.target = target:long() + end + + input.THNN.SpatialClassNLLCriterion_updateOutput( + input:cdata(), + self.target:cdata(), + self.output_tensor:cdata(), + self.sizeAverage, + THNN.optionalTensor(self.weights), + self.total_weight_tensor:cdata() + ) + self.output = self.output_tensor[1] + return self.output, self.total_weight_tensor[1] +end + +function SpatialClassNLLCriterion:updateGradInput(input, target) + if type(target) == 'number' then + self.target[1] = target + elseif target:type() == 'torch.CudaTensor' then + self.target = target + else + self.target = target:long() + end + + self.gradInput:resizeAs(input):zero() + + input.THNN.SpatialClassNLLCriterion_updateGradInput( + input:cdata(), + self.target:cdata(), + self.gradInput:cdata(), + self.sizeAverage, + THNN.optionalTensor(self.weights), + self.total_weight_tensor:cdata() + ) + + return self.gradInput +end |