diff options
author | soumith <soumith@fb.com> | 2016-02-20 01:19:52 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-02-20 01:19:52 +0300 |
commit | f4a5951c28666ce36317ba4cf8f9a28a92cf2624 (patch) | |
tree | bda5e6c4456d6a0c96f62523ab391f1d5f7040c6 /LookupTable.lua | |
parent | f6b2a3bec54ee95949c6993128f2009214cb71fe (diff) |
allow non-contiguous gradOutput in LookupTable
Diffstat (limited to 'LookupTable.lua')
-rw-r--r-- | LookupTable.lua | 7 |
1 file changed, 7 insertions, 0 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index 95bf0fc..88eabab 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -70,6 +70,12 @@ function LookupTable:accGradParameters(input, gradOutput, scale) error("input must be a vector or matrix") end + if not gradOutput:isContiguous() then + self._gradOutput = self._gradOutput or gradOutput.new() + self._gradOutput:resizeAs(gradOutput):copy(gradOutput) + gradOutput = self._gradOutput + end + self.gradWeight.THNN.LookupTable_accGradParameters( input:cdata(), gradOutput:cdata(), @@ -101,6 +107,7 @@ function LookupTable:type(type, tensorCache) end function LookupTable:clearState() + self._gradOutput = nil return self end |