diff options
Diffstat (limited to 'FlattenTable.lua')
-rw-r--r-- | FlattenTable.lua | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/FlattenTable.lua b/FlattenTable.lua index 1c18255..3fe2fd5 100644 --- a/FlattenTable.lua +++ b/FlattenTable.lua @@ -12,7 +12,7 @@ end local function flatten(output, input) local input_map -- has the same structure as input, but stores the -- indices to the corresponding output - if type(input) == 'table' then + if torch.type(input) == 'table' then input_map = {} -- forward DFS order for i = 1, #input do @@ -30,8 +30,8 @@ local function checkMapping(output, input, input_map) if input_map == nil or output == nil or input == nil then return false end - if type(input) == 'table' then - if type(input_map) ~= 'table' then + if torch.type(input) == 'table' then + if torch.type(input_map) ~= 'table' then return false end if #input ~= #input_map then @@ -46,7 +46,7 @@ local function checkMapping(output, input, input_map) end return true else - if type(input_map) ~= 'number' then + if torch.type(input_map) ~= 'number' then return false end return output[input_map] == input @@ -56,7 +56,7 @@ end -- During BPROP we have to build a gradInput with the same shape as the -- input. This is a recursive function to build up a gradInput local function inverseFlatten(gradOutput, input_map) - if type(input_map) == 'table' then + if torch.type(input_map) == 'table' then local gradInput = {} for i = 1, #input_map do gradInput[#gradInput + 1] = inverseFlatten(gradOutput, input_map[i]) @@ -68,7 +68,7 @@ local function inverseFlatten(gradOutput, input_map) end function FlattenTable:updateOutput(input) - assert(type(input) == 'table', 'input must be a table') + assert(torch.type(input) == 'table', 'input must be a table') -- to avoid updating rebuilding the flattened table every updateOutput call -- we will do a DFS pass over the existing output table and the inputs to -- see if it needs to be rebuilt. @@ -80,8 +80,8 @@ function FlattenTable:updateOutput(input) end function FlattenTable:updateGradInput(input, gradOutput) - assert(type(input) == 'table', 'input must be a table') - assert(type(input) == 'table', 'gradOutput must be a table') + assert(torch.type(input) == 'table', 'input must be a table') + assert(torch.type(input) == 'table', 'gradOutput must be a table') -- If the input changes between the updateOutput and updateGradInput call, -- then we may have to rebuild the input_map! However, let's assume that -- the input_map is valid and that forward has already been called. |