Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THStorage.c « TH « lib - github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: cea0f950e6ec0d21abd16a4e351f13e921d1da93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include "THAtomic.h"
#include "THStorage.h"

#include "generic/THStorage.c"
#include "THGenerateAllTypes.h"

#include "generic/THStorage.c"
#include "THGenerateHalfType.h"

#include "generic/THStorageCopy.c"
#include "THGenerateAllTypes.h"

#include "generic/THStorageCopy.c"
#include "THGenerateHalfType.h"


/*
 * Format the entries of `size` as a bracketed list separated by " x "
 * (for example "[2 x 3 x 4]") and return the text in a THDescBuff by value.
 *
 * At most TH_DESC_BUFF_LEN bytes are produced (buf.str is assumed to be at
 * least that large). When the full text does not fit, the last five bytes of
 * the buffer are overwritten with "...]" plus the terminator.
 */
THDescBuff THLongStorage_sizeDesc(const THLongStorage *size) {
  const int L = TH_DESC_BUFF_LEN;
  THDescBuff buf;
  char *str = buf.str;
  int n = 0;
  n += snprintf(str, L-n, "[");
  ptrdiff_t i;
  /* snprintf returns the length the text WOULD have had, so n may reach or
   * pass L after any call. Every further write is guarded by n < L: without
   * the guard, str+n points outside the buffer and L-n becomes negative,
   * which turns into a huge size_t limit. */
  for(i = 0; i < size->size && n < L; i++) {
    n += snprintf(str+n, L-n, "%ld", size->data[i]);
    if(i < size->size-1 && n < L) {
      n += snprintf(str+n, L-n, " x ");
    }
  }
  if(n < L - 2) {
    snprintf(str+n, L-n, "]");
  } else {
    /* Not enough room left: mark the truncation at the very end of the buffer. */
    snprintf(str+L-5, 5, "...]");
  }
  return buf;
}

/*
 * Return a new THLongStorage holding a copy of `size` in which a single -1
 * entry, if present, is replaced by the value that makes the product of all
 * entries equal to nElement.
 *
 * Checked with THArgCheck:
 *   - at most one entry of `size` may be -1;
 *   - when an entry is -1, the product of the other entries must be positive
 *     and must divide nElement exactly;
 *   - when no entry is -1, the product of all entries must equal nElement
 *     (an empty `size` is treated as a product of 0).
 *
 * The result is created with THLongStorage_newWithSize; the caller presumably
 * owns it and is responsible for releasing it.
 */
TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement)
{
  ptrdiff_t total_size = (size->size > 0 ? 1 : 0);
  ptrdiff_t dim_infer = -1;
  ptrdiff_t i;
  for (i = 0; i < size->size; i++) {
    if (size->data[i] == -1) {
      THArgCheck(dim_infer == -1, 1, "only one dimension can be inferred");
      dim_infer = i;
    } else {
      total_size *= size->data[i];
    }
  }

  /* The textual form of `size` is only used in the error messages below. */
  THDescBuff buf = THLongStorage_sizeDesc(size);
  if (dim_infer != -1) {
    /* total_size > 0 is tested first so the modulo never divides by zero. */
    THArgCheck(total_size > 0 && nElement % total_size == 0, 2,
        "size '%s' is invalid for input with %td elements", buf.str, nElement);
  } else {
    THArgCheck(nElement == total_size, 2,
        "size '%s' is invalid for input with %td elements", buf.str, nElement);
  }

  THLongStorage* copy = THLongStorage_newWithSize(size->size);
  THLongStorage_copy(copy, size);
  if (dim_infer != -1) {
    copy->data[dim_infer] = nElement / total_size;
  }
  return copy;
}

/*
 * Compute the sizes and strides obtained by expanding a tensor described by
 * (tensorSizes, tensorStrides, tensorDim) to the target shape in `sizes`.
 *
 * The tensor's dimensions are aligned with the trailing entries of `sizes`;
 * the missing leading dimensions are first filled in with size 1. Every
 * dimension of size 1 whose target size differs is then given the target size
 * and stride 0. A dimension whose size is neither 1 nor the target size is an
 * error: both work arrays are freed and THError is called.
 *
 * On success *esz and *est receive arrays of THLongStorage_size(sizes) longs
 * obtained from THAlloc; the caller presumably releases them with THFree.
 *
 * NOTE(review): the loops below index out of range unless
 * 1 <= tensorDim <= number of entries in `sizes`; confirm callers enforce it.
 */
TH_API void THLongStorage_calculateExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est) {
  ptrdiff_t ndim = THLongStorage_size(sizes);
  long numUnsqueezed = ndim - tensorDim;

  long *expandedSizes = THAlloc(sizeof(long)*ndim);
  long *expandedStrides = THAlloc(sizeof(long)*ndim);

  /* Trailing dimensions are copied straight from the tensor. */
  for (long i = numUnsqueezed; i < ndim; ++i) {
    expandedSizes[i] = tensorSizes[i - numUnsqueezed];
    expandedStrides[i] = tensorStrides[i - numUnsqueezed];
  }

  /* Leading dimensions get size 1, filled back to front because each stride
   * is derived from the dimension to its right. */
  for (long i = numUnsqueezed - 1; i > -1; --i) {
    expandedSizes[i] = 1;
    expandedStrides[i] = expandedSizes[i+1] * expandedStrides[i+1];
  }

  // create a new geometry for the tensor
  for (long i = 0; i < ndim; ++i) {
    long size = expandedSizes[i];
    long targetSize = THLongStorage_data(sizes)[i];
    if (size == 1) {
      if (targetSize != 1) {
        expandedSizes[i] = targetSize;
        expandedStrides[i] = 0;
      }
    } else if (size != targetSize) {
      THFree(expandedSizes);
      THFree(expandedStrides);
      /* All three arguments are long, so each needs %ld. The message is built
       * from adjacent literals so no source indentation ends up in the text. */
      THError("The expanded size of the tensor (%ld) must match the existing size (%ld) at "
              "non-singleton dimension %ld.", targetSize, size, i);
    }
  }
  *esz = expandedSizes;
  *est = expandedStrides;
}