diff options
author | soumith <soumith@fb.com> | 2016-04-27 22:04:59 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-04-27 22:15:01 +0300 |
commit | 4438ec99b8ed74c1b5691b4eea12e66216e8be07 (patch) | |
tree | 85d6519f885b95269437bb08862ade0fa0966d8b /Sum.lua | |
parent | c8806b80ee211ce70c612addecebf236abdf8734 (diff) |
MultiLabelMarginCriterion fixes for CUDA
Diffstat (limited to 'Sum.lua')
-rw-r--r-- | Sum.lua | 10 |
1 files changed, 10 insertions, 0 deletions
@@ -41,6 +41,11 @@ function Sum:updateGradInput(input, gradOutput) -- Instead, do a deepcopy local size = input:size() size[dimension] = 1 + if not gradOutput:isContiguous() then + self._gradOutput = self._gradOutput or gradOutput.new() + self._gradOutput:resizeAs(gradOutput):copy(gradOutput) + gradOutput = self._gradOutput + end gradOutput = gradOutput:view(size) self.gradInput:resizeAs(input) self.gradInput:copy(gradOutput:expandAs(input)) @@ -49,3 +54,8 @@ function Sum:updateGradInput(input, gradOutput) end return self.gradInput end + +function Sum:clearState() + nn.utils.clear(self, '_gradOutput') + return parent.clearState(self) +end |