Welcome to mirror list, hosted at ThFree Co, Russian Federation.

element.h « cpu « tensors « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: d98a0f9b1a9ccc86ca611d57b1fc6c4fb4ee6f84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#include <initializer_list>

#include "tensors/tensor.h"

namespace marian {
namespace cpu {

// Functions in this header execute element-wise operations (passed in as a
// Functor) on an arbitrary number of tensors. The templates implement correct
// broadcasting of operations across a fixed-at-compile-time, but in principle
// arbitrary, number of dimensions.

// @TODO: generalize to vector operations, possibly using specializations

// Single loop over the outer dimension that recursively creates nested loops
// down to the inner dimension and to single elements. Since this is based
// on strides, it correctly broadcasts across all dimensions without additional
// computation. The compiler is expected to optimize this into a single
// construct of nested loops.

// Short alias for marian::functional, which provides the Array, Tensor, Shape
// and apply names used by the element-wise templates below.
namespace F = marian::functional;

template <size_t I = 0>
struct E {
  // Visits every position of dimension I of the output shape (tensors[0])
  // and, for each position, recurses into dimension I + 1.
  // `indices` holds one flat offset per tensor and is taken by value on
  // purpose: every recursion level advances its own private copy.
  template <size_t numArg, class Functor, typename ElementType>
  static inline void element(
      const Functor& functor,
      F::Array<F::Tensor<ElementType>, numArg>& tensors,
      F::Array<int, numArg> indices) {
    const auto& outShape = tensors[0].shape();
    const auto extent = outShape[I];

    for(int pos = 0; pos < extent; ++pos) {
      // process the whole next-inner dimension at the current offsets
      E<I + 1>::element(functor, tensors, indices);

      // advance each tensor's offset along dimension I; bstride(I) is the
      // stride for a dimension larger than 1 and 0 for a broadcast
      // (size-1) dimension
      for(size_t arg = 0; arg < numArg; ++arg)
        indices[arg] += tensors[arg].shape().bstride(I);
    }
  }
};

// Recursion terminator: reached once all F::Shape::size() dimensions have
// been walked. At this point `indices` addresses exactly one element per
// tensor, so there is no loop, and the offsets are only read (const
// reference, no copy).
template <>
struct E<F::Shape::size()> {
  template <size_t numArg, class Functor, typename ElementType>
  static inline void element(
      const Functor& functor,
      F::Array<F::Tensor<ElementType>, numArg>& tensors,
      const F::Array<int, numArg>& indices) {
    // evaluate the functor over the addressed elements of all tensors and
    // store the result in the output (tensor 0)
    // @TODO: use converting operator[] on tensor
    auto dst = tensors[0].data();
    dst[indices[0]] = F::apply(functor, tensors, indices);
  }
};

// Applies `functor` element-wise with broadcasting and writes the result into
// `out`. The inputs are `tensors...`. ElementType is the element
// representation used while iterating (e.g. float, or a packed type such as
// float32x4/float32x8 as chosen by elementFloat).
template <typename ElementType, class Functor, class... Tensors>
void element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
  // the output tensor comes first, followed by every input tensor
  constexpr size_t argNum = 1 + sizeof...(tensors);
  F::Array<F::Tensor<ElementType>, argNum> gTensors = {out, tensors...};

  // each tensor starts at flat offset 0
  F::Array<int, argNum> indices;
  indices.fill(0);

  // descend from the outer-most dimension to individual elements
  E<0>::element(functor, gTensors, indices);
}

// Dispatch elementwise functions with float element type based on number of 
// elements. If dividable by 8 and AVX2 is available (TODO: check this?) use
// AVX2 specific intrinsics. Similar for 4 and AVX. TODO: Add AVX512 support.
template <class Functor, class... Tensors>
void elementFloat(const Functor& functor, marian::Tensor out, Tensors... tensors) {
#ifndef __CUDACC__
  std::vector<marian::Tensor> ts({tensors...});
  bool div8 = true;
  bool div4 = true;

  if(out->shape()[-1] % 8 != 0)
    div8 = false;
  if(out->shape()[-1] % 4 != 0)
    div4 = false;
  for(auto t : ts) {
    if(t->shape()[-1] % 8 != 0)
      div8 = false;
    if(t->shape()[-1] % 4 != 0)
      div4 = false;
  }

  if(div8) {
    // std::cerr << "8: " << functor.to_string() << std::endl;
#ifdef __AVX__
    element<float32x8>(functor, out, tensors...);
    return;
#endif
  }

  if(div4) {
    // std::cerr << "4: " << functor.to_string() << std::endl;
    element<float32x4>(functor, out, tensors...);
    return;
  }
#endif
  // std::cerr << "1: " << functor.to_string() << std::endl;
  element<float>(functor, out, tensors...);
}

// Entry point for CPU element-wise operations. Dispatches on the element type
// of `out`. Only Type::float32 is handled (via elementFloat); every other
// type aborts with an error message.
template <class Functor, class... Tensors>
void Element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
  const auto elemType = out->type();
  if(elemType == Type::float32) {
    elementFloat(functor, out, tensors...);
    return;
  }
  // @TODO: Type::uint32 could be routed to element<uint32_t>(functor, out, tensors...)
  ABORT("Unsupported type for element-wise operation: {}", elemType);
}

}  // namespace cpu
}  // namespace marian