diff options
Diffstat (limited to 'nesting.lua')
-rw-r--r-- | nesting.lua | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/nesting.lua b/nesting.lua index 18899c1..e1ddd7b 100644 --- a/nesting.lua +++ b/nesting.lua @@ -51,6 +51,23 @@ function nesting.resizeNestedAs(output, input) end end +-- Copies all tensors in the output. +function nesting.copyNested(output, input) + if torch.isTensor(output) then + output:copy(input) + else + for key, child in pairs(input) do + nesting.copyNested(output[key], child) + end + -- Extra elements are removed from the output. + for key, child in pairs(output) do + if not input[key] then + output[key] = nil + end + end + end +end + -- Adds the input to the output. -- The input can contain nested tables. -- The output will contain the same nesting of tables. |