Welcome to mirror list, hosted at ThFree Co, Russian Federation.

utils.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 4d89568e51425e6804ed606572fd1eb23c834188 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
-- Namespace for the recursive tensor/table helpers defined in this file.
-- NOTE(review): assigning a fresh table replaces any existing `nn.utils`;
-- assumes the global `nn` table was created before this file is loaded.
nn.utils = {}

--- Convert every tensor, module and criterion reachable from `param` to `type_str`.
-- Plain tables are walked recursively and updated in place: each entry is
-- overwritten with its converted value and the same table is returned.
-- nn.Module / nn.Criterion objects are converted through their own
-- `:type(type_str)` method; that call's result is ignored and the same object
-- is returned. Tensors are replaced by whatever `param:type(type_str)` returns.
-- Any other value is returned untouched.
-- @param param tensor, module, criterion, (nested) table of these, or any value
-- @param type_str type name forwarded unchanged to every `:type()` call
-- @return the converted value (same table/module object for those cases)
function nn.utils.recursiveType(param, type_str)
   if torch.type(param) == 'table' then
      for key, value in pairs(param) do
         param[key] = nn.utils.recursiveType(value, type_str)
      end
      return param
   end

   local isModuleLike = torch.isTypeOf(param, 'nn.Module')
      or torch.isTypeOf(param, 'nn.Criterion')
   if isModuleLike then
      -- Conversion happens on the object itself; keep the original reference.
      param:type(type_str)
      return param
   end

   if torch.isTensor(param) then
      return param:type(type_str)
   end

   return param
end

--- Make `t1` mirror the structure of `t2`, resizing tensor leaves to match.
-- When `t2` is a table, `t1` becomes a table: a non-table `t1` is wrapped as
-- `{t1}`. Each entry of `t2` is mirrored recursively, and any key of `t1` that
-- `t2` does not have is removed, so both tables end up with the same keys.
-- When `t2` is a tensor, `t1` is reused if it is a tensor. Otherwise a fresh
-- tensor is created with `t2.new()`. Either way it is then resized with
-- `t1:resizeAs(t2)`. No element values are copied from `t2`.
-- @param t1 destination: tensor, (nested) table of tensors, or nil
-- @param t2 reference structure: tensor or (nested) table of tensors
-- @return t1 the (possibly newly created) destination
-- @return t2 the reference, unchanged
-- @raise error if a leaf of `t2` is neither a table nor a tensor
function nn.utils.recursiveResizeAs(t1,t2)
   if torch.type(t2) == 'table' then
      t1 = (torch.type(t1) == 'table') and t1 or {t1}
      for key,_ in pairs(t2) do
         t1[key], t2[key] = nn.utils.recursiveResizeAs(t1[key], t2[key])
      end
      -- Drop entries t2 does not have. Otherwise a t1 reused from a larger
      -- structure keeps stale tensors. Assigning nil to existing fields while
      -- traversing with pairs is permitted by the Lua reference manual.
      for key,_ in pairs(t1) do
         if t2[key] == nil then
            t1[key] = nil
         end
      end
   elseif torch.isTensor(t2) then
      t1 = torch.isTensor(t1) and t1 or t2.new()
      t1:resizeAs(t2)
   else
      error("expecting nested tensors or tables. Got "..
            torch.type(t1).." and "..torch.type(t2).." instead")
   end
   return t1, t2
end

--- Fill every tensor in `t2` (a tensor or nested table of tensors) with `val`.
-- Tensors are filled in place via `:fill(val)`. Table entries are reassigned
-- with the result of the recursive call, which is the same tensor/table.
-- @param t2 tensor or (nested) table of tensors
-- @param val value passed to each tensor's `:fill()`
-- @return t2
-- @raise error if a leaf is neither a table nor a tensor
function nn.utils.recursiveFill(t2, val)
   local kind = torch.type(t2)

   if kind == 'table' then
      for k in pairs(t2) do
         t2[k] = nn.utils.recursiveFill(t2[k], val)
      end
      return t2
   end

   if not torch.isTensor(t2) then
      error("expecting tensor or table thereof. Got "..kind.." instead")
   end

   t2:fill(val)
   return t2
end

--- Accumulate `t2`, scaled by `val`, into `t1`, recursing through nested tables.
-- Call forms: `recursiveAdd(t1, val, t2)` or `recursiveAdd(t1, t2)`. In the
-- second form the scale `val` defaults to 1.
-- Tensor leaves are updated in place with `t1:add(val, t2)`, which in torch
-- adds `val * t2` to `t1`. When `t2` is a table, a non-table `t1` is wrapped
-- as `{t1}` and each key of `t2` is processed recursively.
-- @param t1 destination tensor or (nested) table of tensors
-- @param val scale factor, or the source when only two arguments are given
-- @param t2 source tensor or (nested) table of tensors
-- @return t1 the destination
-- @return t2 the source, unchanged
-- @raise error if a leaf pair is not two tensors (e.g. t1 lacks a matching entry)
function nn.utils.recursiveAdd(t1, val, t2)
   if not t2 then
      -- Two-argument form: recursiveAdd(t1, t2).
      assert(val, "expecting at least two arguments")
      t2 = val
      val = 1
   end
   val = val or 1
   if torch.type(t2) == 'table' then
      t1 = (torch.type(t1) == 'table') and t1 or {t1}
      for key,_ in pairs(t2) do
         t1[key], t2[key] = nn.utils.recursiveAdd(t1[key], val, t2[key])
      end
   elseif torch.isTensor(t1) and torch.isTensor(t2) then
      -- Both operands must be tensors. A missing/non-tensor t1 falls through
      -- to the explicit error below instead of crashing on t1:add.
      t1:add(val, t2)
   else
      error("expecting nested tensors or tables. Got "..
            torch.type(t1).." and "..torch.type(t2).." instead")
   end
   return t1, t2
end

--- Return a view of `t` with an extra size-1 dimension inserted at `dim`.
-- The view is built with `view:set(t:storage(), t:storageOffset(), ...)`, so
-- it shares `t`'s storage: writes through either tensor are visible in the
-- other. Dimensions before `dim` keep their position; those from `dim` on
-- shift one place to the right.
-- @param t input tensor with at least one dimension
-- @param[opt=1] dim position of the new dimension, 1 <= dim <= t:dim() + 1;
--   `t:dim() + 1` appends a trailing singleton dimension
-- @return a new tensor object (created with `t.new()`) viewing `t`'s data
-- @raise assertion error if `t` is not a tensor, has no dimensions, or `dim`
--   is out of range
function nn.utils.addSingletonDimension(t, dim)
   assert(torch.isTensor(t), "input tensor expected")
   dim = dim or 1
   local ndim = t:dim()
   -- ndim > 0: an empty (0-dim) tensor has nothing to view and is rejected,
   -- as before. dim == ndim + 1 is allowed; the loops below handle it.
   assert(ndim > 0 and dim > 0 and dim <= ndim + 1,
          "invalid dimension: " .. dim .. ". Tensor has " .. ndim
          .. " dimensions.")

   local size = torch.LongStorage(ndim + 1)
   local stride = torch.LongStorage(ndim + 1)

   for d = 1, dim - 1 do
      size[d] = t:size(d)
      stride[d] = t:stride(d)
   end
   -- Only index 1 exists along a size-1 dimension, so its stride never
   -- contributes to an element's offset; 1 is an arbitrary valid choice.
   size[dim] = 1
   stride[dim] = 1
   for d = dim + 1, ndim + 1 do
      size[d] = t:size(d - 1)
      stride[d] = t:stride(d - 1)
   end

   local view = t.new()
   view:set(t:storage(), t:storageOffset(), size, stride)
   return view
end


-- Compatibility shim: Lua 5.1 provides `unpack` as a global instead of
-- `table.unpack`. Alias it so this package can call `table.unpack` everywhere.
if not table.unpack then
   table.unpack = unpack
end