Welcome to mirror list, hosted at ThFree Co, Russian Federation.

thread_pool.h « ctranslate2 « include - github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 826b7e57f67cb5a7bddba5bf88e642a58440bdba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#pragma once

#include <atomic>
#include <condition_variable>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

namespace ctranslate2 {

  // Base class for asynchronous jobs.
  //
  // Subclasses implement run(). A job can be attached to a job counter that
  // tracks the number of active jobs (queued and currently processed).
  class Job {
  public:
    Job() = default;

    // Jobs are non-copyable. A copy would carry the same _counter pointer,
    // so a single set_job_counter() registration would end up held by two
    // objects. The counter bookkeeping presumably performed by ~Job() would
    // then run twice (NOTE(review): confirm against thread_pool.cc).
    // Declaring the copy operations deleted also suppresses the implicit
    // move operations, which would have the same aliasing problem.
    Job(const Job&) = delete;
    Job& operator=(const Job&) = delete;

    virtual ~Job();

    // Executes the job; implemented by subclasses.
    virtual void run() = 0;

    // The job counter is used to track the number of active jobs (queued and currently processed).
    void set_job_counter(std::atomic<size_t>& counter);

  private:
    // Non-owning pointer to a counter owned elsewhere (it is received by
    // reference). Null until a counter is attached, presumably by set_job_counter().
    std::atomic<size_t>* _counter = nullptr;
  };

  // A thread-safe, bounded FIFO queue of jobs (backed by std::queue).
  //
  // State is guarded by _mutex. put() blocks while the queue is full and get()
  // blocks while it is empty. The two condition variables presumably signal
  // those transitions (NOTE(review): confirm against thread_pool.cc).
  class JobQueue {
  public:
    // Creates a queue that holds at most maximum_size jobs.
    // explicit: a bare size should not silently convert into a queue.
    explicit JobQueue(size_t maximum_size);
    ~JobQueue();

    // Number of jobs currently stored in the queue.
    size_t size() const;

    // Puts a job in the queue. The method blocks until a free slot is available.
    void put(std::unique_ptr<Job> job);

    // Gets a job from the queue. The method blocks until a job is available.
    // If the queue is closed, the method returns a null pointer.
    // before_wait, when non-null, is presumably invoked just before the
    // calling thread blocks (TODO confirm in thread_pool.cc).
    std::unique_ptr<Job> get(const std::function<void()>& before_wait = nullptr);

    // Closes the queue. After this, get() returns a null pointer per its contract above.
    void close();

  private:
    bool can_get_job() const;

    // mutable so that the const members (size(), can_get_job()) can lock it.
    mutable std::mutex _mutex;
    std::queue<std::unique_ptr<Job>> _queue;
    std::condition_variable _can_put_job;
    std::condition_variable _can_get_job;
    size_t _maximum_size;
    // Close request flag. It is initialized in-class so it is never read
    // uninitialized, even if a constructor forgets to set it.
    bool _request_end = false;
  };

  // Processes jobs on a dedicated thread.
  //
  // Derived classes may customize the thread's lifecycle by overriding the
  // protected hooks; every hook does nothing by default.
  //
  // NOTE(review): the thread is stored in a std::thread member. Destroying a
  // still-joinable std::thread calls std::terminate, so join() presumably
  // must be called before a started Worker is destroyed. Confirm the
  // intended ownership in thread_pool.cc.
  class Worker {
  public:
    virtual ~Worker() = default;

    // Starts the worker's thread, which consumes jobs from job_queue.
    // thread_affinity defaults to -1; presumably that means "no pinning" (confirm).
    void start(JobQueue& job_queue, int thread_affinity = -1);

    // Joins the worker's thread (defined out of line).
    void join();

  protected:
    // Hook: runs ahead of the work loop.
    virtual void initialize()
    {
    }

    // Hook: runs once the work loop is over.
    virtual void finalize()
    {
    }

    // Hook: runs before the worker waits for new jobs.
    virtual void idle()
    {
    }

  private:
    // Thread body (defined out of line).
    void run(JobQueue& job_queue);

    std::thread _thread;
  };

  // A pool of worker threads consuming jobs from a shared JobQueue.
  //
  // Jobs are submitted with post(). The pool counts active jobs (queued plus
  // currently processed) in _num_active_jobs.
  class ThreadPool {
  public:
    // Default thread workers: creates num_threads workers.
    // maximum_queue_size bounds the job queue. The default is effectively
    // unbounded, and post() blocks when the queue is full.
    // core_offset defaults to -1; presumably it selects the first CPU core for
    // thread pinning, with -1 disabling pinning (TODO confirm in thread_pool.cc).
    // explicit: prevents a bare size_t from implicitly converting into a pool.
    explicit ThreadPool(size_t num_threads,
                        size_t maximum_queue_size = std::numeric_limits<size_t>::max(),
                        int core_offset = -1);

    // User-defined thread workers. Ownership of the workers moves into the pool.
    // explicit: prevents an implicit vector -> pool conversion.
    explicit ThreadPool(std::vector<std::unique_ptr<Worker>> workers,
                        size_t maximum_queue_size = std::numeric_limits<size_t>::max(),
                        int core_offset = -1);

    ~ThreadPool();

    // Posts a new job. The method blocks if the job queue is full.
    void post(std::unique_ptr<Job> job);

    size_t num_threads() const;

    // Number of jobs in the queue.
    size_t num_queued_jobs() const;

    // Number of jobs in the queue and currently processed by a worker.
    size_t num_active_jobs() const;

    // Returns the worker at position index (bounds checking, if any, is done out of line).
    Worker& get_worker(size_t index);

    // Presumably returns the Worker owning the calling thread (TODO confirm behavior
    // when called from a non-pool thread).
    static Worker& get_local_worker();

  private:
    void start_workers(int core_offset);

    JobQueue _queue;
    std::vector<std::unique_ptr<Worker>> _workers;
    // Explicitly zero-initialized: before C++20 a default-constructed
    // std::atomic holds an indeterminate value.
    std::atomic<size_t> _num_active_jobs{0};
  };

}