github.com/torch/nn.git
author    Soumith Chintala <soumith@gmail.com>   2015-09-04 07:12:21 +0300
committer Soumith Chintala <soumith@gmail.com>   2015-09-04 07:12:21 +0300
commit    bab75d48973ce1767a575fa971a2fb9dc7363ff4 (patch)
tree      3ba71368a554420c2d71d035e04df64ac33506a2
parent    98ecaf98858ccb373c605437422949c1d0998d7c (diff)
parent    7105419c8382c3a9c5c257661c8ab49872fdcddc (diff)
Merge pull request #363 from adamlerer/getParameters
getParameters: improve memory use, fix bug with non-compact tensors
-rw-r--r--  Module.lua  | 159
-rw-r--r--  test.lua    |  39
2 files changed, 140 insertions(+), 58 deletions(-)
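
getParameters() returns flat tensors that alias every parameter and every gradient of a module, so one tensor operation (a momentum update, a copy, serialization) touches all of them at once; that is why the temporary memory used while flattening matters. A minimal usage sketch of that flat view, not part of this commit, assuming a plain nn setup:

require 'nn'

-- a small network; its parameters live in several separate tensors
local net = nn.Sequential()
net:add(nn.Linear(10, 20))
net:add(nn.Tanh())
net:add(nn.Linear(20, 2))

-- after this call the module's weights share storage with `params`
-- and its gradient tensors share storage with `gradParams`
local params, gradParams = net:getParameters()

local criterion = nn.MSECriterion()
local input, target = torch.randn(5, 10), torch.randn(5, 2)

gradParams:zero()
local output = net:forward(input)
local loss = criterion:forward(output, target)
net:backward(input, criterion:backward(output, target))
print(loss)

-- one call updates every parameter in the network: x := x - lr * dx
params:add(-0.01, gradParams)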
diff --git a/Module.lua b/Module.lua
index f65e1b1..4c8c21c 100644
--- a/Module.lua
+++ b/Module.lua
@@ -113,12 +113,15 @@ function Module:clone(...)
return clone
end
-function Module:type(type)
+function Module:type(type, tensorCache)
assert(type, 'Module: must provide a type to convert to')
+
-- find all tensors and convert them
for key,param in pairs(self) do
self[key] = nn.utils.recursiveType(param, type)
+
end
+
return self
end
@@ -137,95 +140,135 @@ end
function Module:reset()
end
--- this function flattens arbitrary lists of parameters,
--- even complex shared ones
+-- This function is not easy to understand. It works as follows:
+--
+-- - gather all parameter tensors for this module (and children);
+-- count all parameter values (floats)
+-- - create one ginormous memory area (Storage object) with room for all
+-- parameters
+-- - remap each parameter tensor to point to an area within the ginormous
+-- Storage, and copy it there
+--
+-- It has the effect of making all parameters point to the same memory area,
+-- which is then returned.
+--
+-- The purpose is to allow operations over all parameters (such as momentum
+-- updates and serialization), but it assumes that all parameters are of
+-- the same type (and, in the case of CUDA, on the same device), which
+-- is not always true. Use for_each() to iterate over this module and
+-- children instead.
+--
+-- Module._flattenTensorBuffer can be used by other packages (e.g. cunn)
+-- to specify the type of temporary buffers. For example, the temporary
+-- buffers for CudaTensor could be FloatTensor, to avoid GPU memory usage.
+--
+-- TODO: This logically belongs to torch.Tensor, not nn.
+Module._flattenTensorBuffer = {}
function Module.flatten(parameters)
- local function storageInSet(set, storage)
- local storageAndOffset = set[torch.pointer(storage)]
- if storageAndOffset == nil then
- return nil
- end
- local _, offset = table.unpack(storageAndOffset)
- return offset
+
+ -- returns true if tensor occupies a contiguous region of memory (no holes)
+ local function isCompact(tensor)
+ local sortedStride, perm = torch.sort(
+ torch.LongTensor(tensor:nDimension()):set(tensor:stride()), 1, true)
+ local sortedSize = torch.LongTensor(tensor:nDimension()):set(
+ tensor:size()):index(1, perm)
+ local nRealDim = torch.clamp(sortedStride, 0, 1):sum()
+ sortedStride = sortedStride:narrow(1, 1, nRealDim):clone()
+ sortedSize = sortedSize:narrow(1, 1, nRealDim):clone()
+ local t = tensor.new():set(tensor:storage(), 1,
+ sortedSize:storage(),
+ sortedStride:storage())
+ return t:isContiguous()
end
if not parameters or #parameters == 0 then
return torch.Tensor()
end
local Tensor = parameters[1].new
- local dtype = parameters[1]:type()
+ local TmpTensor = Module._flattenTensorBuffer[torch.type(parameters[1])] or Tensor
+ -- 1. construct the set of all unique storages referenced by parameter tensors
local storages = {}
local nParameters = 0
+ local parameterMeta = {}
for k = 1,#parameters do
- if parameters[k]:type() ~= dtype then
- error("Inconsistent parameter types. " .. parameters[k]:type() ..
- " ~= " .. dtype)
- end
+ local param = parameters[k]
local storage = parameters[k]:storage()
- if not storageInSet(storages, storage) then
- storages[torch.pointer(storage)] = {storage, nParameters}
+ local storageKey = torch.pointer(storage)
+
+ if not storages[storageKey] then
+ storages[storageKey] = {storage, nParameters}
nParameters = nParameters + storage:size()
end
+
+ parameterMeta[k] = {storageOffset = param:storageOffset() +
+ storages[storageKey][2],
+ size = param:size(),
+ stride = param:stride()}
end
- local flatParameters = Tensor(nParameters):fill(1)
- local flatStorage = flatParameters:storage()
+ -- 2. construct a single tensor that will hold all the parameters
+ local flatParameters = TmpTensor(nParameters):zero()
+ -- 3. determine if there are elements in the storage that none of the
+ -- parameter tensors reference ('holes')
+ local tensorsCompact = true
for k = 1,#parameters do
- local storageOffset = storageInSet(storages, parameters[k]:storage())
- parameters[k]:set(flatStorage,
- storageOffset + parameters[k]:storageOffset(),
- parameters[k]:size(),
- parameters[k]:stride())
- parameters[k]:zero()
+ local meta = parameterMeta[k]
+ local tmp = TmpTensor():set(
+ flatParameters:storage(), meta.storageOffset, meta.size, meta.stride)
+ tmp:fill(1)
+ tensorsCompact = tensorsCompact and isCompact(tmp)
end
- local maskParameters = flatParameters:float():clone()
- local cumSumOfHoles = flatParameters:float():cumsum(1)
- local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
- local flatUsedParameters = Tensor(nUsedParameters)
- local flatUsedStorage = flatUsedParameters:storage()
+ local maskParameters = flatParameters:byte():clone()
+ local compactOffsets = flatParameters:long():cumsum(1)
+ local nUsedParameters = compactOffsets[-1]
- for k = 1,#parameters do
- local offset = cumSumOfHoles[parameters[k]:storageOffset()]
- parameters[k]:set(flatUsedStorage,
- parameters[k]:storageOffset() - offset,
- parameters[k]:size(),
- parameters[k]:stride())
+ -- 4. copy storages into the flattened parameter tensor
+ for _, storageAndOffset in pairs(storages) do
+ local storage, offset = table.unpack(storageAndOffset)
+ flatParameters[{{offset+1,offset+storage:size()}}]:copy(Tensor():set(storage))
end
- for _, storageAndOffset in pairs(storages) do
- local k, v = table.unpack(storageAndOffset)
- flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
+ -- 5. allow garbage collection
+ storages = nil
+ for k = 1,#parameters do
+ parameters[k]:set(Tensor())
end
- if cumSumOfHoles:sum() == 0 then
- flatUsedParameters:copy(flatParameters)
- else
- local counter = 0
- for k = 1,flatParameters:nElement() do
- if maskParameters[k] == 0 then
- counter = counter + 1
- flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
- end
+ -- 6. compact the flattened parameters if there were holes
+ if nUsedParameters ~= nParameters then
+ assert(tensorsCompact,
+ "Cannot gather tensors that are not compact")
+
+ flatParameters = TmpTensor(nUsedParameters):copy(
+ flatParameters:maskedSelect(maskParameters))
+ for k = 1,#parameters do
+ parameterMeta[k].storageOffset =
+ compactOffsets[parameterMeta[k].storageOffset]
end
- assert (counter == nUsedParameters)
end
- return flatUsedParameters
+
+ if TmpTensor ~= Tensor then
+ flatParameters = Tensor(flatParameters:nElement()):copy(flatParameters)
+ end
+
+ -- 7. fix up the parameter tensors to point at the flattened parameters
+ for k = 1,#parameters do
+ parameters[k]:set(flatParameters:storage(),
+ parameterMeta[k].storageOffset,
+ parameterMeta[k].size,
+ parameterMeta[k].stride)
+ end
+
+ return flatParameters
end
function Module:getParameters()
-- get parameters
local parameters,gradParameters = self:parameters()
- -- flatten parameters and gradients
- local flatParameters = nn.Module.flatten(parameters)
- collectgarbage()
- local flatGradParameters = nn.Module.flatten(gradParameters)
- collectgarbage()
-
- -- return new flat vector that contains all discrete parameters
- return flatParameters, flatGradParameters
+ return Module.flatten(parameters), Module.flatten(gradParameters)
end
function Module:__call__(input, gradOutput)
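
The core of the new flatten() above is how holes (storage slots that no parameter references) are removed: used positions are marked in the temporary buffer, a cumulative sum turns old storage offsets into compact ones, and maskedSelect keeps only the used slots. A standalone sketch of that step with made-up offsets, an illustration rather than code from this commit:

-- two "parameters" occupy offsets 1..3 and 6..8 of an 8-slot storage,
-- leaving a hole at 4..5
local flat = torch.Tensor(8):zero()
flat:narrow(1, 1, 3):fill(1)
flat:narrow(1, 6, 3):fill(1)

local mask           = flat:byte()           -- 1 wherever a parameter element lives
local compactOffsets = flat:long():cumsum(1) -- maps old storage offset -> compact offset
local nUsed          = compactOffsets[-1]    -- 6 of the 8 slots are used

-- keep only the used slots, in their original order
local compact = torch.Tensor(nUsed):copy(flat:maskedSelect(mask))

-- a parameter that started at old offset 6 starts at compactOffsets[6] == 4
-- in the compacted storage
print(nUsed, compactOffsets[6], compact:nElement())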
diff --git a/test.lua b/test.lua
index 3911118..3fff151 100644
--- a/test.lua
+++ b/test.lua
@@ -2793,6 +2793,45 @@ function nntest.Module_getParameters_8()
end
+function nntest.Module_getParameters_10()
+ -- tensors are non-contiguous but compact; they can be gathered
+ local L = nn.Linear(10,10)
+ L.weight = torch.Tensor(10,10):t():fill(1)
+ local tmp = torch.Tensor(10,10):fill(2)
+ L.bias = tmp:select(1,2)
+ local P = L:getParameters()
+ mytester:asserteq(L.weight:mean(), 1)
+ mytester:asserteq(L.bias:mean(), 2)
+ mytester:asserteq(L.weight:storage(), L.bias:storage())
+ mytester:asserteq(P:nElement(), 110)
+ mytester:asserteq(P:storage():size(), 110)
+ mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
+function nntest.Module_getParameters_11()
+ -- tensors are non-compact; they can't be gathered
+ local L = nn.Linear(10,10)
+ local tmp = torch.Tensor(10,10):fill(2)
+ L.bias = tmp:select(2,2)
+ local ok, err = pcall(L.getParameters, L)
+ mytester:assert(not ok)
+end
+
+function nntest.Module_getParameters_12()
+ -- tensors are expanded (i.e. have dimension 0)
+ local L = nn.Linear(10,10)
+ L.weight = torch.Tensor(10, 1):fill(1)
+ torch.expand(L.weight, 10, 10)
+ L.bias = torch.Tensor(10):fill(2)
+ local P = L:getParameters()
+ mytester:asserteq(L.weight:mean(), 1)
+ mytester:asserteq(L.bias:mean(), 2)
+ mytester:asserteq(L.weight:storage(), L.bias:storage())
+ mytester:asserteq(P:nElement(), 20)
+ mytester:asserteq(P:storage():size(), 20)
+ mytester:assertlt(L.bias[{ {10} }]:storageOffset() - 1, L.bias:storage():size())
+end
+
function nntest.Module_listModules()
local batchSize = 4
local inputSize, outputSize = 7, 6
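
For reference on the cases the new tests cover: a transposed weight or a row slice is non-contiguous but still fills a dense block of its storage, so the merged flatten() can gather it, while a column slice skips storage elements and is rejected. A quick illustration, assumed behaviour mirroring the tests above:

local transposed = torch.Tensor(10, 10):t()           -- strides {1,10}: dense, just permuted
local row        = torch.Tensor(10, 10):select(1, 2)  -- one full row: a dense slice
local column     = torch.Tensor(10, 10):select(2, 2)  -- one column: stride 10, holes in between

print(transposed:isContiguous(), row:isContiguous(), column:isContiguous())
-- false  true  false; only `column` is non-compact, and only it makes getParameters() error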