Welcome to mirror list, hosted at ThFree Co, Russian Federation.

parallel_for_cxx.cc « ceres « internal « ceres « extern - git.blender.org/blender.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 5b78db19a444e36ccd8127cc0450dcf382f4e5ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2018 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: vitus@google.com (Michael Vitus)

// This include must come before any #ifndef check on Ceres compile options.
#include "ceres/internal/config.h"

#ifdef CERES_USE_CXX_THREADS

#include <cmath>
#include <condition_variable>
#include <memory>
#include <mutex>

#include "ceres/concurrent_queue.h"
#include "ceres/parallel_for.h"
#include "ceres/scoped_thread_token.h"
#include "ceres/thread_token_provider.h"
#include "glog/logging.h"

namespace ceres {
namespace internal {
namespace {
// This class creates a thread safe barrier which will block until a
// pre-specified number of threads call Finished.  This allows us to block the
// main thread until all the parallel threads are finished processing all the
// work.
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total)
      : num_finished_(0), num_total_(num_total) {}

  // Increment the number of jobs that have finished and signal the blocking
  // thread if all jobs have finished.
  void Finished() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++num_finished_;
    CHECK_LE(num_finished_, num_total_);
    if (num_finished_ == num_total_) {
      condition_.notify_one();
    }
  }

  // Block until all threads have signaled they are finished.
  void Block() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [&]() { return num_finished_ == num_total_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  // The current number of jobs finished.
  int num_finished_;
  // The total number of jobs.
  int num_total_;
};

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct SharedState {
  SharedState(int start, int end, int num_work_items)
      : start(start),
        end(end),
        num_work_items(num_work_items),
        i(0),
        thread_token_provider(num_work_items),
        block_until_finished(num_work_items) {}

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_items;

  // The next block of work to be assigned to a worker.  The parallel for loop
  // range is split into num_work_items blocks of work, i.e. a single block of
  // work is:
  //  for (int j = start + i; j < end; j += num_work_items) { ... }.
  int i;
  std::mutex mutex_i;

  // Provides a unique thread ID among all active threads working on the same
  // group of tasks.  Thread-safe.
  ThreadTokenProvider thread_token_provider;

  // Used to signal when all the work has been completed.  Thread safe.
  BlockUntilFinished block_until_finished;
};

}  // namespace

int MaxNumThreadsAvailable() { return ThreadPool::MaxNumThreadsAvailable(); }

// See ParallelFor (below) for more details.
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != nullptr);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    for (int i = start; i < end; ++i) {
      function(i);
    }
    return;
  }

  ParallelFor(
      context, start, end, num_threads, [&function](int /*thread_id*/, int i) {
        function(i);
      });
}

// This implementation uses a fixed size max worker pool with a shared task
// queue. The problem of executing the function for the interval of [start, end)
// is broken up into at most num_threads blocks and added to the thread pool. To
// avoid deadlocks, the calling thread is allowed to steal work from the worker
// pool. This is implemented via a shared state between the tasks. In order for
// the calling thread or thread pool to get a block of work, it will query the
// shared state for the next block of work to be done. If there is nothing left,
// it will return. We will exit the ParallelFor call when all of the work has
// been done, not when all of the tasks have been popped off the task queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work.  This avoids the significant performance penalty for acquiring
// it on every iteration of the for loop. The thread ID is guaranteed to be in
// [0, num_threads).
//
// A performance analysis has shown this implementation is onpar with OpenMP and
// TBB.
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int thread_id, int i)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != nullptr);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    // Even though we only have one thread, use the thread token provider to
    // guarantee the exact same behavior when running with multiple threads.
    ThreadTokenProvider thread_token_provider(num_threads);
    const ScopedThreadToken scoped_thread_token(&thread_token_provider);
    const int thread_id = scoped_thread_token.token();
    for (int i = start; i < end; ++i) {
      function(thread_id, i);
    }
    return;
  }

  // We use a std::shared_ptr because the main thread can finish all
  // the work before the tasks have been popped off the queue.  So the
  // shared state needs to exist for the duration of all the tasks.
  const int num_work_items = std::min((end - start), num_threads);
  std::shared_ptr<SharedState> shared_state(
      new SharedState(start, end, num_work_items));

  // A function which tries to perform a chunk of work. This returns false if
  // there is no work to be done.
  auto task_function = [shared_state, &function]() {
    int i = 0;
    {
      // Get the next available chunk of work to be performed. If there is no
      // work, return false.
      std::lock_guard<std::mutex> lock(shared_state->mutex_i);
      if (shared_state->i >= shared_state->num_work_items) {
        return false;
      }
      i = shared_state->i;
      ++shared_state->i;
    }

    const ScopedThreadToken scoped_thread_token(
        &shared_state->thread_token_provider);
    const int thread_id = scoped_thread_token.token();

    // Perform each task.
    for (int j = shared_state->start + i; j < shared_state->end;
         j += shared_state->num_work_items) {
      function(thread_id, j);
    }
    shared_state->block_until_finished.Finished();
    return true;
  };

  // Add all the tasks to the thread pool.
  for (int i = 0; i < num_work_items; ++i) {
    // Note we are taking the task_function as value so the shared_state
    // shared pointer is copied and the ref count is increased. This is to
    // prevent it from being deleted when the main thread finishes all the
    // work and exits before the threads finish.
    context->thread_pool.AddTask([task_function]() { task_function(); });
  }

  // Try to do any available work on the main thread. This may steal work from
  // the thread pool, but when there is no work left the thread pool tasks
  // will be no-ops.
  while (task_function()) {
  }

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}

}  // namespace internal
}  // namespace ceres

#endif  // CERES_USE_CXX_THREADS