diff options
-rw-r--r-- | base/base.pro | 2 | ||||
-rw-r--r-- | base/random.cpp | 22 | ||||
-rw-r--r-- | base/random.hpp | 9 | ||||
-rw-r--r-- | search/pre_ranker.cpp | 23 |
4 files changed, 37 insertions, 19 deletions
diff --git a/base/base.pro b/base/base.pro index 7eda31055f..71073004a6 100644 --- a/base/base.pro +++ b/base/base.pro @@ -18,6 +18,7 @@ SOURCES += \ logging.cpp \ lower_case.cpp \ normalize_unicode.cpp \ + random.cpp \ shared_buffer_manager.cpp \ src_point.cpp \ string_format.cpp \ @@ -59,6 +60,7 @@ HEADERS += \ mutex.hpp \ newtype.hpp \ observer_list.hpp \ + random.hpp \ range_iterator.hpp \ ref_counted.hpp \ regexp.hpp \ diff --git a/base/random.cpp b/base/random.cpp new file mode 100644 index 0000000000..a6decf8498 --- /dev/null +++ b/base/random.cpp @@ -0,0 +1,22 @@ +#include "base/random.hpp" + +#include <algorithm> +#include <numeric> + +namespace my +{ +std::vector<size_t> RandomSample(size_t n, size_t k, std::minstd_rand & rng) +{ + std::vector<size_t> result(std::min(k, n)); + std::iota(result.begin(), result.end(), 0); + + for (size_t i = k; i < n; ++i) + { + size_t const j = rng() % (i + 1); + if (j < k) + result[j] = i; + } + + return result; +} +} // my diff --git a/base/random.hpp b/base/random.hpp new file mode 100644 index 0000000000..b17189e36c --- /dev/null +++ b/base/random.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include <random> +#include <vector> + +namespace my +{ +std::vector<size_t> RandomSample(size_t n, size_t k, std::minstd_rand & rng); +} // my diff --git a/search/pre_ranker.cpp b/search/pre_ranker.cpp index 70c849df4e..ba0b629513 100644 --- a/search/pre_ranker.cpp +++ b/search/pre_ranker.cpp @@ -9,6 +9,7 @@ #include "indexer/rank_table.hpp" #include "indexer/scales.hpp" +#include "base/random.hpp" #include "base/stl_helpers.hpp" #include "std/iterator.hpp" @@ -46,22 +47,6 @@ struct ComparePreResult1 } }; -// Selects a fair random subset of size min(|n|, |k|) from [0, 1, 2, ..., n - 1]. -vector<size_t> RandomSample(size_t n, size_t k, minstd_rand & rng) -{ - vector<size_t> result(std::min(k, n)); - iota(result.begin(), result.end(), 0); - - for (size_t i = k; i < n; ++i) - { - size_t const j = rng() % (i + 1); - if (j < k) - result[j] = i; - } - - return result; -} - void SweepNearbyResults(double eps, vector<PreResult1> & results) { NearbyPointsSweeper sweeper(eps); @@ -281,7 +266,7 @@ void PreRanker::FilterForViewportSearch() if (m <= old) { - for (size_t i : RandomSample(old, m, m_rng)) + for (size_t i : my::RandomSample(old, m, m_rng)) results.push_back(m_results[bucket[i]]); } else @@ -289,7 +274,7 @@ void PreRanker::FilterForViewportSearch() for (size_t i = 0; i < old; ++i) results.push_back(m_results[bucket[i]]); - for (size_t i : RandomSample(bucket.size() - old, m - old, m_rng)) + for (size_t i : my::RandomSample(bucket.size() - old, m - old, m_rng)) results.push_back(m_results[bucket[old + i]]); } } @@ -301,7 +286,7 @@ void PreRanker::FilterForViewportSearch() else { m_results.clear(); - for (size_t i : RandomSample(results.size(), BatchSize(), m_rng)) + for (size_t i : my::RandomSample(results.size(), BatchSize(), m_rng)) m_results.push_back(results[i]); } } |