author    | ano <anoidgit@users.noreply.github.com>  | 2016-11-11 18:06:56 +0300
committer | Soumith Chintala <soumith@gmail.com>     | 2016-11-11 18:06:56 +0300
commit    | c463e548ecf795b84b171121a0206e5e326d4858
tree      | 71ad500e9add2ac5001a7810cce1ef3a1e4c999b /CMaxTable.lua
parent    | 7254dc5972200ef9ab4899f607ca9d9185eae099
support cuda (#1028)
Fix a break caused by a type mismatch between input[i] and self.gradInput[i] during updateGradInput. This can happen when using CUDA.
Diffstat (limited to 'CMaxTable.lua')
 CMaxTable.lua | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/CMaxTable.lua b/CMaxTable.lua
index 3907faf..62cede9 100644
--- a/CMaxTable.lua
+++ b/CMaxTable.lua
@@ -19,7 +19,7 @@ end
 function CMaxTable:updateGradInput(input, gradOutput)
    for i=1,#input do
-      self.gradInput[i] = torch.Tensor()
+      self.gradInput[i] = input[i].new()
       self.gradInput[i]:resizeAs(input[i]):fill(0.0)
       local mask = torch.eq(self.maxIdx, i)
       self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
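The one-line change replaces a hard-coded `torch.Tensor()` with `input[i].new()`, so the gradient buffer is allocated with the same tensor type as the corresponding input (a CPU tensor for CPU inputs, a CUDA tensor for CUDA inputs), and the subsequent `maskedCopy` never mixes types. A minimal plain-Python sketch of that principle, not Torch code; the class names `CpuTensor`/`CudaTensor` are hypothetical stand-ins:

```python
class CpuTensor(list):
    """Stand-in for torch.Tensor (CPU storage)."""

class CudaTensor(list):
    """Stand-in for torch.CudaTensor (GPU storage)."""

def grad_buffer_hardcoded(inp):
    # Old behavior: always allocate a CPU-typed buffer,
    # regardless of the input's type -> mismatch for CUDA inputs.
    return CpuTensor([0.0] * len(inp))

def grad_buffer_matching(inp):
    # New behavior (analogous to input[i].new()): allocate a buffer
    # of the same class as the input, so types always agree.
    return type(inp)([0.0] * len(inp))

cuda_input = CudaTensor([1.0, 2.0, 3.0])
print(type(grad_buffer_hardcoded(cuda_input)).__name__)  # CpuTensor: mismatch
print(type(grad_buffer_matching(cuda_input)).__name__)   # CudaTensor: matches input
```

The same pattern applies whenever a module caches scratch buffers: deriving the buffer's type from the input keeps the module agnostic to where its data lives.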