ExRandom  3.0
discrete_normal_dist.hpp
Go to the documentation of this file.
1 /**
2  * @file discrete_normal_dist.hpp
3  * @author Charles Karney <charles.karney@sri.com>
4  * @brief Definition of discrete_normal_dist
5  *
6  * Copyright (c) Charles Karney (2014-2020) and licensed under the MIT/X11
7  * License. For more information, see http://exrandom.sourceforge.net/
8  */
9 
10 #if !defined(EXRANDOM_DISCRETE_NORMAL_DIST_HPP)
11 #define EXRANDOM_DISCRETE_NORMAL_DIST_HPP 1
12 
#include <algorithm> // for std::min, std::max
#include <cstdlib> // for std::abs
#include <limits> // for std::numeric_limits
#include <stdexcept> // for std::runtime_error
#include <iostream> // for std::ostream, etc.
17 
18 #include <exrandom/i_rand.hpp>
19 #include <exrandom/u_rand.hpp>
20 
21 namespace exrandom {
22 
23  /**
24  * @brief Partially sample exactly from the discrete normal distribution.
25  *
26  * This samples from the discrete normal distribution P<sub>i</sub> &prop;
27  * exp[&minus; ((i &minus; &mu;)/&sigma;)<sup>2</sup>/2]. This implements
28  * Algorithm D with improvements due to Du et al. (2020).
29  *
30  * @tparam digit_gen the type of digit generator.
31  *
32  * This class allows a i_rand to be returned via the
33  * discrete_normal_dist::generate member function or an int
34  * result via the discrete_normal_dist::operator()() member function.
35  *
36  * See discrete_normal_distribution for a simpler interface to sampling
37  * discrete normal deviates. (But, you can't obtain an i_rand with
38  * this class.)
39  */
40  template<typename digit_gen> class discrete_normal_dist {
41  public:
42  /**
43  * @brief Hold the parameters of discrete_normal_dist.
44  */
45  struct param_type {
46  /**
47  * Construct from the individual parameters.
48  *
49  * @param mu_num the numerator of &mu;.
50  * @param mu_den the denominator of &mu;.
51  * @param sigma_num the numerator of &sigma;.
52  * @param sigma_den the denominator of &sigma;.
53  *
54  * Sets &mu; = @e mu_num / @e mu_den and &sigma; = @e sigma_num
55  * / @e sigma_den.
56  */
57  explicit param_type(int mu_num, int mu_den,
58  int sigma_num, int sigma_den)
59  { param_init(mu_num, mu_den, sigma_num, sigma_den); }
60 
61  /**
62  * The default constructor.
63  *
64  * Sets &mu; = 0 and &sigma; = 1.
65  */
66  explicit param_type()
67  : _mu_num(0), _mu_den(1), _sigma_num(1), _sigma_den(1) {}
68 
69  /**
70  * Construct with integer parameters.
71  *
72  * @param mu the value of &mu;.
73  * @param sigma the value of &sigma;.
74  *
75  * Sets &mu; = @e mu and &sigma; = @e sigma.
76  */
77  explicit param_type(int mu, int sigma)
78  { param_init(mu, 1, sigma, 1); }
79 
80  /**
81  * Construct with parameters with a common denominator.
82  *
83  * @param mu_num the numerator of &mu;.
84  * @param sigma_num the numerator of &sigma;.
85  * @param den the common denominator.
86  *
87  * Sets &mu; = @e mu_num / @e den and &sigma; = @e sigma_num / @e den.
88  */
89  param_type(int mu_num, int sigma_num, int den)
90  { param_init(mu_num, den, sigma_num, den); }
91 
92  /**
93  * @return the numerator of &mu;.
94  */
95  int mu_num() const { return _mu_num; }
96  /**
97  * @return the denominator of &mu;.
98  */
99  int mu_den() const { return _mu_den; }
100  /**
101  * @return the numerator of &sigma;.
102  */
103  int sigma_num() const { return _sigma_num; }
104  /**
105  * @return the denominator of &sigma;.
106  */
107  int sigma_den() const { return _sigma_den; }
108 
109  /**
110  * Test for equality.
111  *
112  * @param p1
113  * @param p2
114  * @return p1 == p2.
115  */
116  friend bool
117  operator==(const param_type& p1, const param_type& p2) {
118  return
119  p1._mu_num == p2._mu_num &&
120  p1._mu_den == p2._mu_den &&
121  p1._sigma_num == p2._sigma_num &&
122  p1._sigma_den == p2._sigma_den;
123  }
124 
125  /**
126  * Inserts a param_type @e x into the output stream @e os.
127  *
128  * @param os an output stream..
129  * @param x a param_type.
130  * @return os.
131  */
132  friend std::ostream& operator<<(std::ostream& os, const param_type& x) {
133  const auto flags = os.flags();
134  os.flags(std::ios::dec);
135  os << x.mu_num() << ' ' << x.mu_den() << ' '
136  << x.sigma_num() << ' ' << x.sigma_den();
137  os.flags(flags);
138  return os;
139  }
140 
141  /**
142  * Extracts a param_type @e x from the input stream @e is.
143  *
144  * @param is an input stream.
145  * @param x a param_type.
146  * @return is.
147  */
148  friend std::istream& operator>>(std::istream& is, param_type& x) {
149  const auto flags = is.flags();
150  is.flags(std::ios::dec | std::ios::skipws);
152  if (is >> mu_num >> mu_den >> sigma_num >> sigma_den)
153  x.param_init(mu_num, mu_den, sigma_num, sigma_den);
154  is.flags(flags);
155  return is;
156  }
157  private:
158  int _mu_num, _mu_den, _sigma_num, _sigma_den;
159  void param_init(int mu_num, int mu_den,
160  int sigma_num, int sigma_den) {
161  if (!( sigma_num > 0 && sigma_den > 0 && mu_den > 0 &&
162  mu_num > std::numeric_limits<int>::min()))
163  throw std::runtime_error("discrete_normal_dist: need sigma > 0");
164  int l;
165  l = gcd(mu_num, mu_den);
166  _mu_num = mu_num/l; _mu_den = mu_den/l;
167  l = gcd(sigma_num, sigma_den);
168  _sigma_num = sigma_num/l; _sigma_den = sigma_den/l;
169  }
170  };
171 
172  /**
173  * The default constructor.
174  *
175  * @param D a reference to the digit generator to be used.
176  *
177  * Sets &mu; = 0 and &sigma; = 1.
178  */
179  discrete_normal_dist(digit_gen& D)
180  : _D(D), _y(D), _z(D), _j(D), _param()
181  { init(); }
182 
183  /**
184  * Construct from a param_type.
185  *
186  * @param D a reference to the digit generator to be used.
187  * @param p the param_type.
188  */
189  discrete_normal_dist(digit_gen& D, const param_type& p)
190  : _D(D), _y(D), _z(D), _j(D), _param(p)
191  { init(); }
192 
193  /**
194  * Construct with integer parameters.
195  *
196  * @param D a reference to the digit generator to be used.
197  * @param mu the value of &mu;.
198  * @param sigma the value of &sigma;.
199  *
200  * Sets &mu; = @e mu and &sigma; = @e sigma.
201  */
202  discrete_normal_dist(digit_gen& D, int mu, int sigma)
203  : _D(D), _y(D), _z(D), _j(D), _param(mu, sigma)
204  { init(); }
205 
206  /**
207  * Construct with parameters with a common denominator.
208  *
209  * @param D a reference to the digit generator to be used.
210  * @param mu_num the numerator of &mu;.
211  * @param sigma_num the numerator of &sigma;.
212  * @param den the common denominator.
213  *
214  * Sets &mu; = @e mu_num / @e den and &sigma; = @e sigma_num / @e den.
215  */
216  discrete_normal_dist(digit_gen& D, int mu_num, int sigma_num, int den)
217  : _D(D), _y(D), _z(D), _j(D), _param(mu_num, den, sigma_num, den)
218  { init(); }
219 
220  /**
221  * Construct from the individual parameters.
222  *
223  * @param D a reference to the digit generator to be used.
224  * @param mu_num the numerator of &mu;.
225  * @param mu_den the denominator of &mu;.
226  * @param sigma_num the numerator of &sigma;.
227  * @param sigma_den the denominator of &sigma;.
228  *
229  * Sets &mu; = @e mu_num / @e mu_den and &sigma; = @e sigma_num / @e
230  * sigma_den.
231  */
232  discrete_normal_dist(digit_gen& D,
233  int mu_num, int mu_den,
234  int sigma_num, int sigma_den)
235  : _D(D), _y(D), _z(D), _j(D)
236  , _param(mu_num, mu_den, sigma_num, sigma_den)
237  { init(); }
238 
239  /**
240  * @return the numerator of &mu;.
241  */
242  int mu_num() const { return _param.mu_num(); }
243  /**
244  * @return the denominator of &mu;.
245  */
246  int mu_den() const { return _param.mu_den(); }
247  /**
248  * @return the numerator of &sigma;.
249  */
250  int sigma_num() const { return _param.sigma_num(); }
251  /**
252  * @return the denominator of &sigma;.
253  */
254  int sigma_den() const { return _param.sigma_den(); }
255 
256  /**
257  * Return a deviate as a i_rand.
258  *
259  * @tparam Generator the type of g.
260  * @param g the random generator engine.
261  * @param j the i_rand to set.
262  */
263  template<typename Generator>
264  void generate(Generator& g, i_rand<digit_gen>& j) {
265  for (;;) {
266  int k = G(g); // step 1
267  k = S(k); if (k < 0) continue; // step 2
268  // Explanation of Steps 3 & 5. The scheme for unit_normal samples k,
269  // samples x in [0,1], and (unless rejected) returns s*(k+x). For the
270  // discrete case, we sample x in [0,1) such that s*(k+x) = (i-mu)/sigma
271  // or
272  //
273  // x = s*(i - mu)/sigma - k
274  //
275  // The value of i which results in the smallest x >= 0 is
276  //
277  // s*i0 = s*ceil(sigma*k + s*mu)
278  //
279  // so sample
280  //
281  // i = s * (i0 + j)
282  //
283  // where j is uniformly distributed in [0, ceil(sigma)). The
284  // corresponding value of x is
285  //
286  // x = (i0 - (sigma*k + s*mu))/sigma + j/sigma
287  // = x0 + j/sigma
288  // x0 = (ceil(sigma*k + s*mu) - (sigma*k + s*mu))/sigma
289  int s = j.init(g,2)(g) ? -1 : 1; // step 6
290  long long xn0 = _sig * k + s * _mu;
291  int i0 = int(iceil(xn0, _d)); // step 6
292  xn0 = i0 * _d - xn0; // step 3, xn = xn0 + j * _d
293  j.init(g, _isig); // i = s * (i0 + j)
294  // If sigma is not an integer, this may result (with j = _isig-1) in x
295  // >= 1. Reject such samples. Reject also the case s = -1, k = 0, and
296  // x == 0 (since this is treated by the case s = 1, k = 0, x = 0).
297  if (!j.less_than(g, _sig - xn0, _d) ||
298  (k == 0 && s < 0 && !j.greater_than(g, -xn0, _d)))
299  continue;
300  int h = k; while (h-- && E(g, xn0, j)) {}; // step 4
301  if (!(h < 0)) continue;
302  if (!B(g, xn0, j)) continue; // step 5
303  j.add(i0 + s*_imu); // step 6
304  if (s < 0) j.negate(); // step 7
305  return; // step 8
306  }
307  }
308 
309  /**
310  * Return a deviate.
311  *
312  * @tparam Generator the type of g.
313  * @param g the random generator engine.
314  * @return the random deviate.
315  */
316  template<typename Generator>
317  int operator()(Generator& g) {
318  generate(g, _j);
319  return _j(g);
320  }
321  /**
322  * @return a reference to the digit generator used in the constructor.
323  */
324  digit_gen& digit_generator() const { return _D; }
325  /**
326  * @return the parameters.
327  */
328  const param_type& param() const { return _param; }
329  /**
330  * Set new parameters.
331  *
332  * @param param the new parameters.
333  */
334  void init(const param_type& param) { _param = param; init(); }
335  private:
336  static const int b = digit_gen::base;
337  // Allow base in [2,2^24]. Need digit to be representable as an int. This
338  // also allows kmax * base to be representable as in int.
339  static_assert(digit_gen::bits <= 24, "base must be in [2,2^24]");
340  // Disable copy assignment
341  discrete_normal_dist& operator=(const discrete_normal_dist&);
342  digit_gen& _D;
343  u_rand<digit_gen> _y, _z; // temporary storage
344  i_rand<digit_gen> _j; // temporary storage
345  param_type _param;
346  long long _sig, _mu, _d; // sigma = _sig/_d, mu = _imu + _mu/_d
347  int _imu, _isig; // _isig = ceil(sigma)
348 
349  static long long iceil(long long n, long long d) // ceil(n/d) for d > 0
350  { long long k = n / d; return k + (k * d < n ? 1 : 0); }
351 
352  // Knuth, TAOCP, vol 2, 4.5.2, Algorithm A
353  static int gcd(int u, int v) {
354  u = std::abs(u); v = std::abs(v);
355  while (v > 0) { int r = u % v; u = v; v = r; }
356  return u;
357  }
358 
359  void init() {
360  static const long long maxll = std::numeric_limits<long long>::max();
361  static const int maxint = std::numeric_limits<int>::max();
362  _imu = int(_param.mu_num() / _param.mu_den());
363  int fmu_num = _param.mu_num() - _imu * _param.mu_den();
364  _isig = int(iceil(_param.sigma_num(), _param.sigma_den()));
365  long long l = gcd(_param.sigma_den(), _param.mu_den());
366  if (!( _param.mu_den() / l <= maxll / _param.sigma_num() &&
367  std::abs(fmu_num) <= maxll / (_param.sigma_den() / l) &&
368  _param.mu_den() / l <= maxll / _param.sigma_den() ))
369  throw std::runtime_error("discrete_normal_dist: sigma or mu overflow");
370  _sig = _param.sigma_num() * (_param.mu_den() / l);
371  _mu = fmu_num * (_param.sigma_den() / l);
372  _d = _param.sigma_den() * (_param.mu_den() / l);
373  // sigma = _sig / _d; _isig = ceil(sigma); check _isig * _d is
374  // representable as a long long (in i_rand.less_than)
375  if (!(_isig <= maxll / _d))
376  throw std::runtime_error("discrete_normal_dist: sigma or mu overflow");
377  // The rest of the constructor tests for possible overflow
378  // The probability that k = kmax is about 10^-543.
379  int kmax = 50 + 1;
380  // Check that max plausible result fits in an int
381  if (!(_isig <= maxint / kmax))
382  throw std::runtime_error("discrete_normal_dist: possible overflow a");
383  if (!(std::abs(_imu) <= maxint - _isig * kmax))
384  throw std::runtime_error("discrete_normal_dist: possible overflow b");
385  // Need to represent
386  // _sig * kmax as long long -- xn0 = _sig * k ...)
387  // # no longer relevant -- compare no longer called
388  // # base * 2 * kmax as long long -- in compare(g, 1, 2, m)
389  // _isig * base as long long -- in i_rand::start
390  // _sig * kmax * base as long long -- in u_rand::less_than
391  // _sig * base as long long -- in u_rand::less_than
392  // Combine requirements as
393  // max(2,_sig) * base * kmax
394  if (!((std::max)(2LL, _sig) <= maxll / (b * kmax)))
395  throw std::runtime_error("discrete_normal_dist: possible overflow c");
396  }
397 
398  // Algorithm H: true with probability exp(-1/2).
399  template<typename Generator>
400  bool H(Generator& g) {
401  if (!_y.init().less_than_half(g)) return true;
402  for (;;) {
403  if (!_z.init().less_than(g, _y)) return false;
404  if (!_y.init().less_than(g, _z)) return true;
405  }
406  }
407 
408  // Step N1: return n >= 0 with prob. exp(-n/2) * (1 - exp(-1/2)).
409  template<typename Generator>
410  int G(Generator& g)
411  { int n = 0; while (H(g)) ++n; return n; }
412 
413  // return the square root of n >= 0 if it's a perfect square else -1
414  int S(int n) const {
415  for (int k = 0, k2 = 0; k2 <= n; ++k, k2 += 2*k - 1) {
416  // Here k2 = k * k; note that k^2 - (k - 1)^2 = 2*k - 1
417  if (n == k2) return k;
418  }
419  return -1;
420  }
421 
422  // Algorithm E: true with prob exp(-x) for x in (0, 1), where
423  // x = (xn0 + _d * j) / _sig.
424  template<typename Generator>
425  bool E(Generator& g, long long xn0, i_rand<digit_gen>& j) {
426  if (!_y.init().less_than(g, xn0, _d, _sig, j)) return true;
427  for (;;) {
428  if (!_z.init().less_than(g, _y)) return false;
429  if (!_y.init().less_than(g, _z)) return true;
430  }
431  }
432 
433  // Algorithm B: true with prob exp(-x^2/2), where
434  // x = (xn0 + _d * j) / _sig
435  template<typename Generator>
436  bool B(Generator& g, long long xn0,
437  i_rand<digit_gen>& j) {
438  int n = 0;
439  for (;; ++n) {
440  if (_z.init().less_than_half(g)) break;
441  _z.init();
442  if (!(n ? _z.less_than(g, _y) : _z.less_than(g, xn0, _d, _sig, j)))
443  break;
444  if (!_y.init().less_than(g, xn0, _d, _sig, j)) break;
445  _y.swap(_z); // an efficient way of doing y = z
446  }
447  return (n % 2) == 0;
448  }
449 
450  };
451 
452 }
453 
454 #endif // EXRANDOM_DISCRETE_NORMAL_DIST_HPP
bool greater_than(Generator &g, long long m, long long n=1)
Definition: i_rand.hpp:166
friend std::istream & operator>>(std::istream &is, param_type &x)
void add(int c)
Definition: i_rand.hpp:120
discrete_normal_dist(digit_gen &D, int mu_num, int sigma_num, int den)
discrete_normal_dist(digit_gen &D, const param_type &p)
discrete_normal_dist(digit_gen &D, int mu_num, int mu_den, int sigma_num, int sigma_den)
param_type(int mu_num, int sigma_num, int den)
bool less_than(Generator &g, u_rand &t)
Definition: u_rand.hpp:196
discrete_normal_dist(digit_gen &D, int mu, int sigma)
A class to sample integers [0,m).
Definition: i_rand.hpp:44
Definition of u_rand.
void swap(u_rand &t)
Definition: u_rand.hpp:134
u_rand & init()
Definition: u_rand.hpp:123
Partially sample exactly from the discrete normal distribution.
const param_type & param() const
param_type(int mu_num, int mu_den, int sigma_num, int sigma_den)
friend std::ostream & operator<<(std::ostream &os, const param_type &x)
The common namespace.
Definition: aux_info.hpp:18
Hold the parameters of discrete_normal_dist.
The machinery to handle u-rands, arbitrary precision random deviates.
Definition: u_rand.hpp:105
void generate(Generator &g, i_rand< digit_gen > &j)
i_rand & init(Generator &g, int m)
Definition: i_rand.hpp:67
Definition of i_rand.
friend bool operator==(const param_type &p1, const param_type &p2)
bool less_than(Generator &g, long long m, long long n=1)
Definition: i_rand.hpp:133
void init(const param_type &param)