Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorano <anoidgit@users.noreply.github.com>2016-11-11 18:06:56 +0300
committerSoumith Chintala <soumith@gmail.com>2016-11-11 18:06:56 +0300
commitc463e548ecf795b84b171121a0206e5e326d4858 (patch)
tree71ad500e9add2ac5001a7810cce1ef3a1e4c999b /CMaxTable.lua
parent7254dc5972200ef9ab4899f607ca9d9185eae099 (diff)
support cuda (#1028)
fix break caused by different type between input[i] and self.gradInput[i] during updateGradInput, This may happened while you use cuda.
Diffstat (limited to 'CMaxTable.lua')
-rw-r--r--CMaxTable.lua2
1 files changed, 1 insertions, 1 deletions
diff --git a/CMaxTable.lua b/CMaxTable.lua
index 3907faf..62cede9 100644
--- a/CMaxTable.lua
+++ b/CMaxTable.lua
@@ -19,7 +19,7 @@ end
function CMaxTable:updateGradInput(input, gradOutput)
for i=1,#input do
- self.gradInput[i] = torch.Tensor()
+ self.gradInput[i] = input[i].new()
self.gradInput[i]:resizeAs(input[i]):fill(0.0)
local mask = torch.eq(self.maxIdx, i)
self.gradInput[i]:maskedCopy(mask, gradOutput[mask])