
github.com/torch/nn.git
Diffstat (limited to 'FlattenTable.lua')
-rw-r--r--  FlattenTable.lua  16
1 file changed, 8 insertions(+), 8 deletions(-)
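
This patch swaps Lua's builtin type() for torch.type() in every type check in FlattenTable.lua. For the plain tables and numbers tested here the two functions agree, but torch.type() additionally resolves Torch objects to their registered class name, where the builtin only reports 'userdata'. A minimal sketch of the difference, assuming a standard Torch install (the variable name t is illustrative):

require 'torch'

local t = torch.Tensor(3)

print(type(t))        -- 'userdata': the builtin sees only the raw Lua type
print(torch.type(t))  -- 'torch.DoubleTensor': the registered Torch class name

print(type({}))       -- 'table': both agree on plain Lua values
print(torch.type({})) -- 'table'
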
diff --git a/FlattenTable.lua b/FlattenTable.lua
index 1c18255..3fe2fd5 100644
--- a/FlattenTable.lua
+++ b/FlattenTable.lua
@@ -12,7 +12,7 @@ end
 local function flatten(output, input)
   local input_map -- has the same structure as input, but stores the
                   -- indices to the corresponding output
-  if type(input) == 'table' then
+  if torch.type(input) == 'table' then
     input_map = {}
     -- forward DFS order
     for i = 1, #input do
@@ -30,8 +30,8 @@ local function checkMapping(output, input, input_map)
   if input_map == nil or output == nil or input == nil then
     return false
   end
-  if type(input) == 'table' then
-    if type(input_map) ~= 'table' then
+  if torch.type(input) == 'table' then
+    if torch.type(input_map) ~= 'table' then
       return false
     end
     if #input ~= #input_map then
@@ -46,7 +46,7 @@ local function checkMapping(output, input, input_map)
     end
     return true
   else
-    if type(input_map) ~= 'number' then
+    if torch.type(input_map) ~= 'number' then
       return false
     end
     return output[input_map] == input
@@ -56,7 +56,7 @@ end
 -- During BPROP we have to build a gradInput with the same shape as the
 -- input. This is a recursive function to build up a gradInput
 local function inverseFlatten(gradOutput, input_map)
-  if type(input_map) == 'table' then
+  if torch.type(input_map) == 'table' then
     local gradInput = {}
     for i = 1, #input_map do
       gradInput[#gradInput + 1] = inverseFlatten(gradOutput, input_map[i])
@@ -68,7 +68,7 @@ local function inverseFlatten(gradOutput, input_map)
 end
 
 function FlattenTable:updateOutput(input)
-  assert(type(input) == 'table', 'input must be a table')
+  assert(torch.type(input) == 'table', 'input must be a table')
   -- to avoid updating rebuilding the flattened table every updateOutput call
   -- we will do a DFS pass over the existing output table and the inputs to
   -- see if it needs to be rebuilt.
@@ -80,8 +80,8 @@ function FlattenTable:updateOutput(input)
 end
 
 function FlattenTable:updateGradInput(input, gradOutput)
-  assert(type(input) == 'table', 'input must be a table')
-  assert(type(input) == 'table', 'gradOutput must be a table')
+  assert(torch.type(input) == 'table', 'input must be a table')
+  assert(torch.type(input) == 'table', 'gradOutput must be a table')
   -- If the input changes between the updateOutput and updateGradInput call,
   -- then we may have to rebuild the input_map! However, let's assume that
   -- the input_map is valid and that forward has already been called.
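
Note that both asserts in updateGradInput test input; the message on the second suggests it was meant to test gradOutput, a slip this patch preserves. For context, a short usage sketch of the module these checks guard, assuming nn is installed (shapes and names are illustrative):

require 'nn'

local m = nn.FlattenTable()

-- a nested table of tensors...
local input = {
  torch.rand(2),
  { torch.rand(3), { torch.rand(4) } },
}

-- ...is flattened into a flat table of tensors, in depth-first order
local output = m:forward(input)
print(#output)                   -- 3

-- the gradient comes back with the original nested structure restored
local gradInput = m:backward(input, output)
print(torch.type(gradInput[2]))  -- 'table'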