Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas LĂ©onard <nick@nikopia.org>2017-06-23 20:24:50 +0300
committerGitHub <noreply@github.com>2017-06-23 20:24:50 +0300
commitee35779c3a5b849f45290c942d66c6cdc7dd6b45 (patch)
tree2da6e709fff6c7396c6cfef93641a37792e75248
parentb41b5ffad8d68fb7f33709ce7067f159f3092dd3 (diff)
parentad95cb622fad88e2108f4d3d00ecd41070fe8672 (diff)
Merge pull request #1245 from nicholas-leonard/recursive-resize
recursive functions use resize instead of resizeAs
-rw-r--r--utils.lua15
1 files changed, 10 insertions, 5 deletions
diff --git a/utils.lua b/utils.lua
index 8f9c203..17b52af 100644
--- a/utils.lua
+++ b/utils.lua
@@ -87,7 +87,7 @@ function nn.utils.recursiveResizeAs(t1,t2)
end
elseif torch.isTensor(t2) then
t1 = torch.isTensor(t1) and t1 or t2.new()
- t1:resizeAs(t2)
+ t1:resize(t2:size())
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
@@ -130,15 +130,20 @@ function nn.utils.recursiveAdd(t1, val, t2)
return t1, t2
end
-function nn.utils.recursiveCopy(t1,t2)
+function nn.utils.recursiveCopy(t1,t2,async)
if torch.type(t2) == 'table' then
t1 = (torch.type(t1) == 'table') and t1 or {t1}
for key,_ in pairs(t2) do
- t1[key], t2[key] = nn.utils.recursiveCopy(t1[key], t2[key])
+ t1[key], t2[key] = nn.utils.recursiveCopy(t1[key], t2[key], async)
end
elseif torch.isTensor(t2) then
t1 = torch.isTensor(t1) and t1 or t2.new()
- t1:resizeAs(t2):copy(t2)
+ t1:resize(t2:size())
+ if async then
+ t1:copyAsync(t2)
+ else
+ t1:copy(t2)
+ end
else
error("expecting nested tensors or tables. Got "..
torch.type(t1).." and "..torch.type(t2).." instead")
@@ -185,7 +190,7 @@ function nn.utils.contiguousView(output, input, ...)
if input:isContiguous() then
output:view(input, ...)
else
- output:resizeAs(input)
+ output:resize(input:size())
output:copy(input)
output:view(output, ...)
end