diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-09-04 07:12:21 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-09-04 07:12:21 +0300 |
commit | bab75d48973ce1767a575fa971a2fb9dc7363ff4 (patch) | |
tree | 3ba71368a554420c2d71d035e04df64ac33506a2 | |
parent | 98ecaf98858ccb373c605437422949c1d0998d7c (diff) | |
parent | 7105419c8382c3a9c5c257661c8ab49872fdcddc (diff) |
Merge pull request #363 from adamlerer/getParameters
getParameters: improve memory use, fix bug with non-compact tensors
-rw-r--r-- | Module.lua | 159 | ||||
-rw-r--r-- | test.lua | 39 |
2 files changed, 140 insertions, 58 deletions
```diff
@@ -113,12 +113,15 @@ function Module:clone(...)
    return clone
 end
 
-function Module:type(type)
+function Module:type(type, tensorCache)
    assert(type, 'Module: must provide a type to convert to')
+
    -- find all tensors and convert them
    for key,param in pairs(self) do
       self[key] = nn.utils.recursiveType(param, type)
+
    end
+
    return self
 end
 
@@ -137,95 +140,135 @@ end
 function Module:reset()
 end
 
--- this function flattens arbitrary lists of parameters,
--- even complex shared ones
+-- This function is not easy to understand. It works as follows:
+--
+-- - gather all parameter tensors for this module (and children);
+--   count all parameter values (floats)
+-- - create one ginormous memory area (Storage object) with room for all
+--   parameters
+-- - remap each parameter tensor to point to an area within the ginormous
+--   Storage, and copy it there
+--
+-- It has the effect of making all parameters point to the same memory area,
+-- which is then returned.
+--
+-- The purpose is to allow operations over all parameters (such as momentum
+-- updates and serialization), but it assumes that all parameters are of
+-- the same type (and, in the case of CUDA, on the same device), which
+-- is not always true. Use for_each() to iterate over this module and
+-- children instead.
+--
+-- Module._flattenTensorBuffer can be used by other packages (e.g. cunn)
+-- to specify the type of temporary buffers. For example, the temporary
+-- buffers for CudaTensor could be FloatTensor, to avoid GPU memory usage.
+--
+-- TODO: This logically belongs to torch.Tensor, not nn.
+Module._flattenTensorBuffer = {}
 function Module.flatten(parameters)
-   local function storageInSet(set, storage)
-      local storageAndOffset = set[torch.pointer(storage)]
-      if storageAndOffset == nil then
-         return nil
-      end
-      local _, offset = table.unpack(storageAndOffset)
-      return offset
+
+   -- returns true if tensor occupies a contiguous region of memory (no holes)
+   local function isCompact(tensor)
+      local sortedStride, perm = torch.sort(
+            torch.LongTensor(tensor:nDimension()):set(tensor:stride()), 1, true)
+      local sortedSize = torch.LongTensor(tensor:nDimension()):set(
+            tensor:size()):index(1, perm)
+      local nRealDim = torch.clamp(sortedStride, 0, 1):sum()
+      sortedStride = sortedStride:narrow(1, 1, nRealDim):clone()
+      sortedSize = sortedSize:narrow(1, 1, nRealDim):clone()
+      local t = tensor.new():set(tensor:storage(), 1,
+                                 sortedSize:storage(),
+                                 sortedStride:storage())
+      return t:isContiguous()
    end
 
    if not parameters or #parameters == 0 then
       return torch.Tensor()
    end
    local Tensor = parameters[1].new
-   local dtype = parameters[1]:type()
+   local TmpTensor = Module._flattenTensorBuffer[torch.type(parameters[1])] or Tensor
 
+   -- 1. construct the set of all unique storages referenced by parameter tensors
    local storages = {}
    local nParameters = 0
+   local parameterMeta = {}
    for k = 1,#parameters do
-      if parameters[k]:type() ~= dtype then
-         error("Inconsistent parameter types. " .. parameters[k]:type() ..
-               " ~= " .. dtype)
-      end
+      local param = parameters[k]
       local storage = parameters[k]:storage()
-      if not storageInSet(storages, storage) then
-         storages[torch.pointer(storage)] = {storage, nParameters}
+      local storageKey = torch.pointer(storage)
+
+      if not storages[storageKey] then
+         storages[storageKey] = {storage, nParameters}
          nParameters = nParameters + storage:size()
       end
+
+      parameterMeta[k] = {storageOffset = param:storageOffset() +
+                                          storages[storageKey][2],
+                          size = param:size(),
+                          stride = param:stride()}
    end
 
-   local flatParameters = Tensor(nParameters):fill(1)
-   local flatStorage = flatParameters:storage()
+   -- 2. construct a single tensor that will hold all the parameters
+   local flatParameters = TmpTensor(nParameters):zero()
 
+   -- 3. determine if there are elements in the storage that none of the
+   --    parameter tensors reference ('holes')
+   local tensorsCompact = true
    for k = 1,#parameters do
-      local storageOffset = storageInSet(storages, parameters[k]:storage())
-      parameters[k]:set(flatStorage,
-                        storageOffset + parameters[k]:storageOffset(),
-                        parameters[k]:size(),
-                        parameters[k]:stride())
-      parameters[k]:zero()
+      local meta = parameterMeta[k]
+      local tmp = TmpTensor():set(
+         flatParameters:storage(), meta.storageOffset, meta.size, meta.stride)
+      tmp:fill(1)
+      tensorsCompact = tensorsCompact and isCompact(tmp)
    end
 
-   local maskParameters = flatParameters:float():clone()
-   local cumSumOfHoles = flatParameters:float():cumsum(1)
-   local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
-   local flatUsedParameters = Tensor(nUsedParameters)
-   local flatUsedStorage = flatUsedParameters:storage()
+   local maskParameters = flatParameters:byte():clone()
+   local compactOffsets = flatParameters:long():cumsum(1)
+   local nUsedParameters = compactOffsets[-1]
 
-   for k = 1,#parameters do
-      local offset = cumSumOfHoles[parameters[k]:storageOffset()]
-      parameters[k]:set(flatUsedStorage,
-                        parameters[k]:storageOffset() - offset,
-                        parameters[k]:size(),
-                        parameters[k]:stride())
+   -- 4. copy storages into the flattened parameter tensor
+   for _, storageAndOffset in pairs(storages) do
+      local storage, offset = table.unpack(storageAndOffset)
+      flatParameters[{{offset+1,offset+storage:size()}}]:copy(Tensor():set(storage))
    end
 
-   for _, storageAndOffset in pairs(storages) do
-      local k, v = table.unpack(storageAndOffset)
-      flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
+   -- 5. allow garbage collection
+   storages = nil
+   for k = 1,#parameters do
+      parameters[k]:set(Tensor())
    end
 
-   if cumSumOfHoles:sum() == 0 then
-      flatUsedParameters:copy(flatParameters)
-   else
-      local counter = 0
-      for k = 1,flatParameters:nElement() do
-         if maskParameters[k] == 0 then
-            counter = counter + 1
-            flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
-         end
+   -- 6. compact the flattened parameters if there were holes
+   if nUsedParameters ~= nParameters then
+      assert(tensorsCompact,
+         "Cannot gather tensors that are not compact")
+
+      flatParameters = TmpTensor(nUsedParameters):copy(
+            flatParameters:maskedSelect(maskParameters))
+      for k = 1,#parameters do
+         parameterMeta[k].storageOffset =
+               compactOffsets[parameterMeta[k].storageOffset]
       end
-      assert (counter == nUsedParameters)
    end
-   return flatUsedParameters
+
+   if TmpTensor ~= Tensor then
+      flatParameters = Tensor(flatParameters:nElement()):copy(flatParameters)
+   end
+
+   -- 7. fix up the parameter tensors to point at the flattened parameters
+   for k = 1,#parameters do
+      parameters[k]:set(flatParameters:storage(),
+          parameterMeta[k].storageOffset,
+          parameterMeta[k].size,
+          parameterMeta[k].stride)
+   end
+
+   return flatParameters
 end
 
 function Module:getParameters()
    -- get parameters
    local parameters,gradParameters = self:parameters()
-   -- flatten parameters and gradients
-   local flatParameters = nn.Module.flatten(parameters)
-   collectgarbage()
-   local flatGradParameters = nn.Module.flatten(gradParameters)
-   collectgarbage()
-
-   -- return new flat vector that contains all discrete parameters
-   return flatParameters, flatGradParameters
+   return Module.flatten(parameters), Module.flatten(gradParameters)
 end
 
 function Module:__call__(input, gradOutput)
@@ -2793,6 +2793,45 @@ function nntest.Module_getParameters_8()
 
 end
 
+function nntest.Module_getParameters_10()
+   -- tensors are non-contiguous but compact; they can be gathered
+   local L = nn.Linear(10,10)
+   L.weight = torch.Tensor(10,10):t():fill(1)
+   local tmp = torch.Tensor(10,10):fill(2)
+   L.bias = tmp:select(1,2)
+   local P = L:getParameters()
+   mytester:asserteq(L.weight:mean(), 1)
+   mytester:asserteq(L.bias:mean(), 2)
+   mytester:asserteq(L.weight:storage(), L.bias:storage())
+   mytester:asserteq(P:nElement(), 110)
+   mytester:asserteq(P:storage():size(), 110)
+   mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
+function nntest.Module_getParameters_11()
+   -- tensors are non-compact; they can't be gathered
+   local L = nn.Linear(10,10)
+   local tmp = torch.Tensor(10,10):fill(2)
+   L.bias = tmp:select(2,2)
+   local ok, err = pcall(L.getParameters, L)
+   mytester:assert(not ok)
+end
+
+function nntest.Module_getParameters_12()
+   -- tensors are expanded (i.e. have dimension 0)
+   local L = nn.Linear(10,10)
+   L.weight = torch.Tensor(10, 1):fill(1)
+   torch.expand(L.weight, 10, 10)
+   L.bias = torch.Tensor(10):fill(2)
+   local P = L:getParameters()
+   mytester:asserteq(L.weight:mean(), 1)
+   mytester:asserteq(L.bias:mean(), 2)
+   mytester:asserteq(L.weight:storage(), L.bias:storage())
+   mytester:asserteq(P:nElement(), 20)
+   mytester:asserteq(P:storage():size(), 20)
+   mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
 function nntest.Module_listModules()
    local batchSize = 4
    local inputSize, outputSize = 7, 6
```