Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-04-27 22:04:59 +0300
committersoumith <soumith@fb.com>2016-04-27 22:15:01 +0300
commit4438ec99b8ed74c1b5691b4eea12e66216e8be07 (patch)
tree85d6519f885b95269437bb08862ade0fa0966d8b /Sum.lua
parentc8806b80ee211ce70c612addecebf236abdf8734 (diff)
MultiLabelMarginCriterion fixes for CUDA
Diffstat (limited to 'Sum.lua')
-rw-r--r--Sum.lua10
1 files changed, 10 insertions, 0 deletions
diff --git a/Sum.lua b/Sum.lua
index 77d4fce..5d61c28 100644
--- a/Sum.lua
+++ b/Sum.lua
@@ -41,6 +41,11 @@ function Sum:updateGradInput(input, gradOutput)
-- Instead, do a deepcopy
local size = input:size()
size[dimension] = 1
+ if not gradOutput:isContiguous() then
+ self._gradOutput = self._gradOutput or gradOutput.new()
+ self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
+ gradOutput = self._gradOutput
+ end
gradOutput = gradOutput:view(size)
self.gradInput:resizeAs(input)
self.gradInput:copy(gradOutput:expandAs(input))
@@ -49,3 +54,8 @@ function Sum:updateGradInput(input, gradOutput)
end
return self.gradInput
end
+
+function Sum:clearState()
+ nn.utils.clear(self, '_gradOutput')
+ return parent.clearState(self)
+end