Welcome to mirror list, hosted at ThFree Co, Russian Federation.

sampling_util.cc « multi_functions « intern « functions « blender « source - git.blender.org/blender.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 29c0d52c00c6eda4bbf5612afe587439104634f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include "sampling_util.h"

#include "BLI_vector_adaptor.h"

namespace FN {

using BLI::VectorAdaptor;

float compute_cumulative_distribution(ArrayRef<float> weights,
                                      MutableArrayRef<float> r_cumulative_weights)
{
  BLI_assert(weights.size() + 1 == r_cumulative_weights.size());

  r_cumulative_weights[0] = 0;
  for (uint i : weights.index_range()) {
    r_cumulative_weights[i + 1] = r_cumulative_weights[i] + weights[i];
  }

  float weight_sum = r_cumulative_weights.last();
  return weight_sum;
}

static void sample_cumulative_distribution__recursive(RNG *rng,
                                                      uint amount,
                                                      uint start,
                                                      uint one_after_end,
                                                      ArrayRef<float> cumulative_weights,
                                                      VectorAdaptor<uint> &sampled_indices)
{
  BLI_assert(start <= one_after_end);
  uint size = one_after_end - start;
  if (size == 0) {
    BLI_assert(amount == 0);
  }
  else if (amount == 0) {
    return;
  }
  else if (size == 1) {
    sampled_indices.append_n_times(start, amount);
  }
  else {
    uint middle = start + size / 2;
    float left_weight = cumulative_weights[middle] - cumulative_weights[start];
    float right_weight = cumulative_weights[one_after_end] - cumulative_weights[middle];
    BLI_assert(left_weight >= 0.0f && right_weight >= 0.0f);
    float weight_sum = left_weight + right_weight;
    BLI_assert(weight_sum > 0.0f);

    float left_factor = left_weight / weight_sum;
    float right_factor = right_weight / weight_sum;

    uint left_amount = amount * left_factor;
    uint right_amount = amount * right_factor;

    if (left_amount + right_amount < amount) {
      BLI_assert(left_amount + right_amount + 1 == amount);
      float weight_per_item = weight_sum / amount;
      float total_remaining_weight = weight_sum - (left_amount + right_amount) * weight_per_item;
      float left_remaining_weight = left_weight - left_amount * weight_per_item;
      float left_remaining_factor = left_remaining_weight / total_remaining_weight;
      if (BLI_rng_get_float(rng) < left_remaining_factor) {
        left_amount++;
      }
      else {
        right_amount++;
      }
    }

    sample_cumulative_distribution__recursive(
        rng, left_amount, start, middle, cumulative_weights, sampled_indices);
    sample_cumulative_distribution__recursive(
        rng, right_amount, middle, one_after_end, cumulative_weights, sampled_indices);
  }
}

void sample_cumulative_distribution(RNG *rng,
                                    ArrayRef<float> cumulative_weights,
                                    MutableArrayRef<uint> r_sampled_indices)
{
  BLI_assert(r_sampled_indices.size() == 0 || cumulative_weights.last() > 0.0f);

  uint amount = r_sampled_indices.size();
  VectorAdaptor<uint> sampled_indices(r_sampled_indices.begin(), amount);
  sample_cumulative_distribution__recursive(
      rng, amount, 0, cumulative_weights.size() - 1, cumulative_weights, sampled_indices);
  BLI_assert(sampled_indices.is_full());
}

}  // namespace FN