Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-02-20 01:19:52 +0300
committersoumith <soumith@fb.com>2016-02-20 01:19:52 +0300
commitf4a5951c28666ce36317ba4cf8f9a28a92cf2624 (patch)
treebda5e6c4456d6a0c96f62523ab391f1d5f7040c6 /LookupTable.lua
parentf6b2a3bec54ee95949c6993128f2009214cb71fe (diff)
allow non-contiguous gradOutput in LookupTable
Diffstat (limited to 'LookupTable.lua')
-rw-r--r--LookupTable.lua7
1 files changed, 7 insertions, 0 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
index 95bf0fc..88eabab 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -70,6 +70,12 @@ function LookupTable:accGradParameters(input, gradOutput, scale)
error("input must be a vector or matrix")
end
+ if not gradOutput:isContiguous() then
+ self._gradOutput = self._gradOutput or gradOutput.new()
+ self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
+ gradOutput = self._gradOutput
+ end
+
self.gradWeight.THNN.LookupTable_accGradParameters(
input:cdata(),
gradOutput:cdata(),
@@ -101,6 +107,7 @@ function LookupTable:type(type, tensorCache)
end
function LookupTable:clearState()
+ self._gradOutput = nil
return self
end