[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_earlystopping.hxx
1 #ifndef RF_EARLY_STOPPING_P_HXX
2 #define RF_EARLY_STOPPING_P_HXX
3 #include <cmath>
4 #include "rf_common.hxx"
5 
6 namespace vigra
7 {
8 
9 #if 0
10 namespace es_detail
11 {
12  template<class T>
13  T power(T const & in, int n)
14  {
15  T result = NumericTraits<T>::one();
16  for(int ii = 0; ii < n ;++ii)
17  result *= in;
18  return result;
19  }
20 }
21 #endif
22 
23 /**Base class from which all EarlyStopping Functors derive.
24  */
25 class StopBase
26 {
27 protected:
28  ProblemSpec<> ext_param_;
29  int tree_count_ ;
30  bool is_weighted_;
31 
32 public:
33  template<class T>
34  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
35  {
36  ext_param_ = prob;
37  is_weighted_ = is_weighted;
38  tree_count_ = tree_count;
39  }
40 
41  /** called after the prediction of a tree was added to the total prediction
42  * \param WeightIter Iterator to the weights delivered by current tree.
43  * \param k after kth tree
44  * \param prob Total probability array
45  * \param totalCt sum of probability array.
46  */
47  template<class WeightIter, class T, class C>
48  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
49  {return false;}
50 };
51 
52 
53 /**Stop predicting after a set number of trees
54  */
55 class StopAfterTree : public StopBase
56 {
57 public:
58  double max_tree_p;
59  int max_tree_;
60  typedef StopBase SB;
61 
62  ArrayVector<double> depths;
63 
64  /** Constructor
65  * \param max_tree number of trees to be used for prediction
66  */
67  StopAfterTree(double max_tree)
68  :
69  max_tree_p(max_tree)
70  {}
71 
72  template<class T>
73  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
74  {
75  max_tree_ = ceil(max_tree_p * tree_count);
76  SB::set_external_parameters(prob, tree_count, is_weighted);
77  }
78 
79  template<class WeightIter, class T, class C>
80  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
81  {
82  if(k == SB::tree_count_ -1)
83  {
84  depths.push_back(double(k+1)/double(SB::tree_count_));
85  return false;
86  }
87  if(k < max_tree_)
88  return false;
89  depths.push_back(double(k+1)/double(SB::tree_count_));
90  return true;
91  }
92 };
93 
94 /** Stop predicting after a certain amount of votes exceed certain proportion.
95  * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_
96  * case weighted votion: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
97  * (maximal number of votes possible in both cases)
98  */
100 {
101 public:
102  double proportion_;
103  typedef StopBase SB;
104  ArrayVector<double> depths;
105 
106  /** Constructor
107  * \param proportion specify proportion to be used.
108  */
109  StopAfterVoteCount(double proportion)
110  :
111  proportion_(proportion)
112  {}
113 
114  template<class WeightIter, class T, class C>
115  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
116  {
117  if(k == SB::tree_count_ -1)
118  {
119  depths.push_back(double(k+1)/double(SB::tree_count_));
120  return false;
121  }
122 
123 
124  if(SB::is_weighted_)
125  {
126  if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
127  {
128  depths.push_back(double(k+1)/double(SB::tree_count_));
129  return true;
130  }
131  }
132  else
133  {
134  if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
135  {
136  depths.push_back(double(k+1)/double(SB::tree_count_));
137  return true;
138  }
139  }
140  return false;
141  }
142 
143 };
144 
145 
146 /** Stop predicting if the 2norm of the probabilities does not change*/
148 
149 {
150 public:
151  double thresh_;
152  int num_;
153  MultiArray<2, double> last_;
155  ArrayVector<double> depths;
156  typedef StopBase SB;
157 
158  /** Constructor
159  * \param thresh: If the two norm of the probabilites changes less then thresh then stop
160  * \param num : look at atleast num trees before stopping
161  */
162  StopIfConverging(double thresh, int num = 10)
163  :
164  thresh_(thresh),
165  num_(num)
166  {}
167 
168  template<class T>
169  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
170  {
171  last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
172  cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
173  SB::set_external_parameters(prob, tree_count, is_weighted);
174  }
175  template<class WeightIter, class T, class C>
176  bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> const & prob, double totalCt)
177  {
178  if(k == SB::tree_count_ -1)
179  {
180  depths.push_back(double(k+1)/double(SB::tree_count_));
181  return false;
182  }
183  if(k <= num_)
184  {
185  last_ = prob;
186  last_/= last_.norm(1);
187  return false;
188  }
189  else
190  {
191  cur_ = prob;
192  cur_ /= cur_.norm(1);
193  last_ -= cur_;
194  double nrm = last_.norm();
195  if(nrm < thresh_)
196  {
197  depths.push_back(double(k+1)/double(SB::tree_count_));
198  return true;
199  }
200  else
201  {
202  last_ = cur_;
203  }
204  }
205  return false;
206  }
207 };
208 
209 
210 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
211  * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_
212  * case weighted votion: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
213  * (maximal number of votes possible in both cases)
214  */
215 class StopIfMargin : public StopBase
216 {
217 public:
218  double proportion_;
219  typedef StopBase SB;
220  ArrayVector<double> depths;
221 
222  /** Constructor
223  * \param proportion specify proportion to be used.
224  */
225  StopIfMargin(double proportion)
226  :
227  proportion_(proportion)
228  {}
229 
230  template<class WeightIter, class T, class C>
231  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
232  {
233  if(k == SB::tree_count_ -1)
234  {
235  depths.push_back(double(k+1)/double(SB::tree_count_));
236  return false;
237  }
238  int index = argMax(prob);
239  double a = prob[argMax(prob)];
240  prob[argMax(prob)] = 0;
241  double b = prob[argMax(prob)];
242  prob[index] = a;
243  double margin = a - b;
244  if(SB::is_weighted_)
245  {
246  if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
247  {
248  depths.push_back(double(k+1)/double(SB::tree_count_));
249  return true;
250  }
251  }
252  else
253  {
254  if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
255  {
256  depths.push_back(double(k+1)/double(SB::tree_count_));
257  return true;
258  }
259  }
260  return false;
261  }
262 };
263 
264 
265 /**Probabilistic Stopping criterion (binomial test)
266  *
267  * Can only be used in a two class setting
268  *
269  * Stop if the Parameters estimated for the underlying binomial distribution
270  * can be estimated with certainty over 1-alpha.
271  * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion
272  */
273 class StopIfBinTest : public StopBase
274 {
275 public:
276  double alpha_;
277  MultiArrayView<2, double> n_choose_k;
278  /** Constructor
279  * \param proportion specify alpha value for binomial test.
280  * \param nck_ Matrix with precomputed values for n choose k
281  * nck_(n, k) is n choose k.
282  */
284  :
285  alpha_(alpha),
286  n_choose_k(nck_)
287  {}
288  typedef StopBase SB;
289 
290  /**ArrayVector that will contain the fraction of trees that was visited before terminating
291  */
293 
294  double binomial(int N, int k, double p)
295  {
296 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
297  return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
298  }
299 
300  template<class WeightIter, class T, class C>
301  bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt)
302  {
303  if(k == SB::tree_count_ -1)
304  {
305  depths.push_back(double(k+1)/double(SB::tree_count_));
306  return false;
307  }
308  if(k < 10)
309  {
310  return false;
311  }
312  int index = argMax(prob);
313  int n_a = prob[index];
314  int n_b = prob[(index+1)%2];
315  int n_tilde = (SB::tree_count_ - n_a + n_b);
316  double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
317  vigra_precondition(p_a <= 1, "probability should be smaller than 1");
318  double cum_val = 0;
319  int c = 0;
320  // std::cerr << "prob: " << p_a << std::endl;
321  if(n_a <= 0)n_a = 0;
322  if(n_b <= 0)n_b = 0;
323  for(int ii = 0; ii <= n_b + n_a;++ii)
324  {
325 // std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
326  cum_val += binomial(n_b + n_a, ii, p_a);
327  if(cum_val >= 1 -alpha_)
328  {
329  c = ii;
330  break;
331  }
332  }
333 // std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl;
334  if(c < n_a)
335  {
336  depths.push_back(double(k+1)/double(SB::tree_count_));
337  return true;
338  }
339 
340  return false;
341  }
342 };
343 
344 /**Probabilistic Stopping criteria. (toChange)
345  *
346  * Can only be used in a two class setting
347  *
348  * Stop if the probability that the decision will change after seeing all trees falls under
349  * a specified value alpha.
350  * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion
351  */
352 class StopIfProb : public StopBase
353 {
354 public:
355  double alpha_;
356  MultiArrayView<2, double> n_choose_k;
357 
358 
359  /** Constructor
360  * \param proportion specify alpha value
361  * \param nck_ Matrix with precomputed values for n choose k
362  * nck_(n, k) is n choose k.
363  */
365  :
366  alpha_(alpha),
367  n_choose_k(nck_)
368  {}
369  typedef StopBase SB;
370  /**ArrayVector that will contain the fraction of trees that was visited before terminating
371  */
373 
374  double binomial(int N, int k, double p)
375  {
376 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
377  return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
378  }
379 
380  template<class WeightIter, class T, class C>
381  bool after_prediction(WeightIter iter, int k, MultiArrayView<2, T, C> prob, double totalCt)
382  {
383  if(k == SB::tree_count_ -1)
384  {
385  depths.push_back(double(k+1)/double(SB::tree_count_));
386  return false;
387  }
388  if(k <= 10)
389  {
390  return false;
391  }
392  int index = argMax(prob);
393  int n_a = prob[index];
394  int n_b = prob[(index+1)%2];
395  int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
396  int n_tilde = SB::tree_count_ - (n_a +n_b);
397  if(n_tilde <= 0) n_tilde = 0;
398  if(n_needed <= 0) n_needed = 0;
399  double p = 0;
400  for(int ii = n_needed; ii < n_tilde; ++ii)
401  p += binomial(n_tilde, ii, 0.5);
402 
403  if(p >= 1-alpha_)
404  {
405  depths.push_back(double(k+1)/double(SB::tree_count_));
406  return true;
407  }
408 
409  return false;
410  }
411 };
412 } //namespace vigra;
413 #endif //RF_EARLY_STOPPING_P_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Wed Mar 12 2014)