[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_visitors.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
37 
38 #ifdef HasHDF5
39 # include "vigra/hdf5impex.hxx"
40 #endif // HasHDF5
41 #include <vigra/windows.h>
42 #include <iostream>
43 #include <iomanip>
44 #include <vigra/timing.hxx>
45 
46 namespace vigra
47 {
48 namespace rf
49 {
50 /** \addtogroup MachineLearning Machine Learning
51 **/
52 //@{
53 
54 /**
55  This namespace contains all classes and methods related to extracting information during
56  learning of the random forest. All Visitors share the same interface defined in
57  visitors::VisitorBase. The member methods are invoked at certain points of the main code in
58  the order they were supplied.
59 
60  For the Random Forest the Visitor concept is implemented as a statically linked list
61  (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The
62  VisitorNode object calls the Next Visitor after one of its visit() methods have terminated.
63 
64  To simplify usage create_visitor() factory methods are supplied.
65  Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
66  It is possible to supply more than one visitor. They will then be invoked in serial order.
67 
68  The calculated information are stored as public data members of the class. - see documentation
69  of the individual visitors
70 
71  While creating a new visitor the new class should therefore publicly inherit from this class
72  (i.e.: see visitors::OOB_Error).
73 
74  \code
75 
 76  typedef xxx feature_t // replace xxx with whichever type your features have
 77  typedef yyy label_t // likewise, replace yyy with your label type
78  MultiArrayView<2, feature_t> f = get_some_features();
79  MultiArrayView<2, label_t> l = get_some_labels();
80  RandomForest<> rf()
81 
82  //calculate OOB Error
83  visitors::OOB_Error oob_v;
84  //calculate Variable Importance
85  visitors::VariableImportanceVisitor varimp_v;
86 
 87  double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
88  //the data can be found in the attributes of oob_v and varimp_v now
89 
90  \endcode
91 */
92 namespace visitors
93 {
94 
95 
96 /** Base Class from which all Visitors derive. Can be used as a template to create new
97  * Visitors.
98  */
100 {
101  public:
102  bool active_;
103  bool is_active()
104  {
105  return active_;
106  }
107 
108  bool has_value()
109  {
110  return false;
111  }
112 
113  VisitorBase()
114  : active_(true)
115  {}
116 
117  void deactivate()
118  {
119  active_ = false;
120  }
121  void activate()
122  {
123  active_ = true;
124  }
125 
126  /** do something after the the Split has decided how to process the Region
127  * (Stack entry)
128  *
129  * \param tree reference to the tree that is currently being learned
130  * \param split reference to the split object
131  * \param parent current stack entry which was used to decide the split
132  * \param leftChild left stack entry that will be pushed
133  * \param rightChild
134  * right stack entry that will be pushed.
135  * \param features features matrix
136  * \param labels label matrix
137  * \sa RF_Traits::StackEntry_t
138  */
139  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
140  void visit_after_split( Tree & tree,
141  Split & split,
142  Region & parent,
143  Region & leftChild,
144  Region & rightChild,
145  Feature_t & features,
146  Label_t & labels)
147  {}
148 
149  /** do something after each tree has been learned
150  *
151  * \param rf reference to the random forest object that called this
152  * visitor
153  * \param pr reference to the preprocessor that processed the input
154  * \param sm reference to the sampler object
155  * \param st reference to the first stack entry
156  * \param index index of current tree
157  */
158  template<class RF, class PR, class SM, class ST>
159  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
160  {}
161 
162  /** do something after all trees have been learned
163  *
164  * \param rf reference to the random forest object that called this
165  * visitor
166  * \param pr reference to the preprocessor that processed the input
167  */
168  template<class RF, class PR>
169  void visit_at_end(RF const & rf, PR const & pr)
170  {}
171 
172  /** do something before learning starts
173  *
174  * \param rf reference to the random forest object that called this
175  * visitor
176  * \param pr reference to the Processor class used.
177  */
178  template<class RF, class PR>
179  void visit_at_beginning(RF const & rf, PR const & pr)
180  {}
181  /** do some thing while traversing tree after it has been learned
182  * (external nodes)
183  *
184  * \param tr reference to the tree object that called this visitor
185  * \param index index in the topology_ array we currently are at
186  * \param node_t type of node we have (will be e_.... - )
187  * \param weight Node weight of current node.
188  * \sa NodeTags;
189  *
190  * you can create the node by using a switch on node_tag and using the
191  * corresponding Node objects. Or - if you do not care about the type
192  * use the Nodebase class.
193  */
194  template<class TR, class IntT, class TopT,class Feat>
195  void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
196  {}
197 
198  /** do something when visiting a internal node after it has been learned
199  *
200  * \sa visit_external_node
201  */
202  template<class TR, class IntT, class TopT,class Feat>
203  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
204  {}
205 
206  /** return a double value. The value of the first
207  * visitor encountered that has a return value is returned with the
208  * RandomForest::learn() method - or -1.0 if no return value visitor
209  * existed. This functionality basically only exists so that the
210  * OOB - visitor can return the oob error rate like in the old version
211  * of the random forest.
212  */
213  double return_val()
214  {
215  return -1.0;
216  }
217 };
218 
219 
220 /** Last Visitor that should be called to stop the recursion.
221  */
223 {
224  public:
225  bool has_value()
226  {
227  return true;
228  }
229  double return_val()
230  {
231  return -1.0;
232  }
233 };
234 namespace detail
235 {
236 /** Container elements of the statically linked Visitor list.
237  *
238  * use the create_visitor() factory functions to create visitors up to size 10;
239  *
240  */
241 template <class Visitor, class Next = StopVisiting>
243 {
244  public:
245 
246  StopVisiting stop_;
247  Next next_;
248  Visitor & visitor_;
249  VisitorNode(Visitor & visitor, Next & next)
250  :
251  next_(next), visitor_(visitor)
252  {}
253 
254  VisitorNode(Visitor & visitor)
255  :
256  next_(stop_), visitor_(visitor)
257  {}
258 
259  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
260  void visit_after_split( Tree & tree,
261  Split & split,
262  Region & parent,
263  Region & leftChild,
264  Region & rightChild,
265  Feature_t & features,
266  Label_t & labels)
267  {
268  if(visitor_.is_active())
269  visitor_.visit_after_split(tree, split,
270  parent, leftChild, rightChild,
271  features, labels);
272  next_.visit_after_split(tree, split, parent, leftChild, rightChild,
273  features, labels);
274  }
275 
276  template<class RF, class PR, class SM, class ST>
277  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
278  {
279  if(visitor_.is_active())
280  visitor_.visit_after_tree(rf, pr, sm, st, index);
281  next_.visit_after_tree(rf, pr, sm, st, index);
282  }
283 
284  template<class RF, class PR>
285  void visit_at_beginning(RF & rf, PR & pr)
286  {
287  if(visitor_.is_active())
288  visitor_.visit_at_beginning(rf, pr);
289  next_.visit_at_beginning(rf, pr);
290  }
291  template<class RF, class PR>
292  void visit_at_end(RF & rf, PR & pr)
293  {
294  if(visitor_.is_active())
295  visitor_.visit_at_end(rf, pr);
296  next_.visit_at_end(rf, pr);
297  }
298 
299  template<class TR, class IntT, class TopT,class Feat>
300  void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
301  {
302  if(visitor_.is_active())
303  visitor_.visit_external_node(tr, index, node_t,features);
304  next_.visit_external_node(tr, index, node_t,features);
305  }
306  template<class TR, class IntT, class TopT,class Feat>
307  void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
308  {
309  if(visitor_.is_active())
310  visitor_.visit_internal_node(tr, index, node_t,features);
311  next_.visit_internal_node(tr, index, node_t,features);
312  }
313 
314  double return_val()
315  {
316  if(visitor_.is_active() && visitor_.has_value())
317  return visitor_.return_val();
318  return next_.return_val();
319  }
320 };
321 
322 } //namespace detail
323 
324 //////////////////////////////////////////////////////////////////////////////
325 // Visitor Factory function up to 10 visitors //
326 //////////////////////////////////////////////////////////////////////////////
327 
328 /** factory method to to be used with RandomForest::learn()
329  */
330 template<class A>
333 {
334  typedef detail::VisitorNode<A> _0_t;
335  _0_t _0(a);
336  return _0;
337 }
338 
339 
340 /** factory method to to be used with RandomForest::learn()
341  */
342 template<class A, class B>
343 detail::VisitorNode<A, detail::VisitorNode<B> >
344 create_visitor(A & a, B & b)
345 {
346  typedef detail::VisitorNode<B> _1_t;
347  _1_t _1(b);
348  typedef detail::VisitorNode<A, _1_t> _0_t;
349  _0_t _0(a, _1);
350  return _0;
351 }
352 
353 
354 /** factory method to to be used with RandomForest::learn()
355  */
356 template<class A, class B, class C>
357 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
358 create_visitor(A & a, B & b, C & c)
359 {
360  typedef detail::VisitorNode<C> _2_t;
361  _2_t _2(c);
362  typedef detail::VisitorNode<B, _2_t> _1_t;
363  _1_t _1(b, _2);
364  typedef detail::VisitorNode<A, _1_t> _0_t;
365  _0_t _0(a, _1);
366  return _0;
367 }
368 
369 
370 /** factory method to to be used with RandomForest::learn()
371  */
372 template<class A, class B, class C, class D>
373 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
374  detail::VisitorNode<D> > > >
375 create_visitor(A & a, B & b, C & c, D & d)
376 {
377  typedef detail::VisitorNode<D> _3_t;
378  _3_t _3(d);
379  typedef detail::VisitorNode<C, _3_t> _2_t;
380  _2_t _2(c, _3);
381  typedef detail::VisitorNode<B, _2_t> _1_t;
382  _1_t _1(b, _2);
383  typedef detail::VisitorNode<A, _1_t> _0_t;
384  _0_t _0(a, _1);
385  return _0;
386 }
387 
388 
389 /** factory method to to be used with RandomForest::learn()
390  */
391 template<class A, class B, class C, class D, class E>
392 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
393  detail::VisitorNode<D, detail::VisitorNode<E> > > > >
394 create_visitor(A & a, B & b, C & c,
395  D & d, E & e)
396 {
397  typedef detail::VisitorNode<E> _4_t;
398  _4_t _4(e);
399  typedef detail::VisitorNode<D, _4_t> _3_t;
400  _3_t _3(d, _4);
401  typedef detail::VisitorNode<C, _3_t> _2_t;
402  _2_t _2(c, _3);
403  typedef detail::VisitorNode<B, _2_t> _1_t;
404  _1_t _1(b, _2);
405  typedef detail::VisitorNode<A, _1_t> _0_t;
406  _0_t _0(a, _1);
407  return _0;
408 }
409 
410 
411 /** factory method to to be used with RandomForest::learn()
412  */
413 template<class A, class B, class C, class D, class E,
414  class F>
415 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
416  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
417 create_visitor(A & a, B & b, C & c,
418  D & d, E & e, F & f)
419 {
420  typedef detail::VisitorNode<F> _5_t;
421  _5_t _5(f);
422  typedef detail::VisitorNode<E, _5_t> _4_t;
423  _4_t _4(e, _5);
424  typedef detail::VisitorNode<D, _4_t> _3_t;
425  _3_t _3(d, _4);
426  typedef detail::VisitorNode<C, _3_t> _2_t;
427  _2_t _2(c, _3);
428  typedef detail::VisitorNode<B, _2_t> _1_t;
429  _1_t _1(b, _2);
430  typedef detail::VisitorNode<A, _1_t> _0_t;
431  _0_t _0(a, _1);
432  return _0;
433 }
434 
435 
436 /** factory method to to be used with RandomForest::learn()
437  */
438 template<class A, class B, class C, class D, class E,
439  class F, class G>
440 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
441  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
442  detail::VisitorNode<G> > > > > > >
443 create_visitor(A & a, B & b, C & c,
444  D & d, E & e, F & f, G & g)
445 {
446  typedef detail::VisitorNode<G> _6_t;
447  _6_t _6(g);
448  typedef detail::VisitorNode<F, _6_t> _5_t;
449  _5_t _5(f, _6);
450  typedef detail::VisitorNode<E, _5_t> _4_t;
451  _4_t _4(e, _5);
452  typedef detail::VisitorNode<D, _4_t> _3_t;
453  _3_t _3(d, _4);
454  typedef detail::VisitorNode<C, _3_t> _2_t;
455  _2_t _2(c, _3);
456  typedef detail::VisitorNode<B, _2_t> _1_t;
457  _1_t _1(b, _2);
458  typedef detail::VisitorNode<A, _1_t> _0_t;
459  _0_t _0(a, _1);
460  return _0;
461 }
462 
463 
464 /** factory method to to be used with RandomForest::learn()
465  */
466 template<class A, class B, class C, class D, class E,
467  class F, class G, class H>
468 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
469  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
470  detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
471 create_visitor(A & a, B & b, C & c,
472  D & d, E & e, F & f,
473  G & g, H & h)
474 {
475  typedef detail::VisitorNode<H> _7_t;
476  _7_t _7(h);
477  typedef detail::VisitorNode<G, _7_t> _6_t;
478  _6_t _6(g, _7);
479  typedef detail::VisitorNode<F, _6_t> _5_t;
480  _5_t _5(f, _6);
481  typedef detail::VisitorNode<E, _5_t> _4_t;
482  _4_t _4(e, _5);
483  typedef detail::VisitorNode<D, _4_t> _3_t;
484  _3_t _3(d, _4);
485  typedef detail::VisitorNode<C, _3_t> _2_t;
486  _2_t _2(c, _3);
487  typedef detail::VisitorNode<B, _2_t> _1_t;
488  _1_t _1(b, _2);
489  typedef detail::VisitorNode<A, _1_t> _0_t;
490  _0_t _0(a, _1);
491  return _0;
492 }
493 
494 
495 /** factory method to to be used with RandomForest::learn()
496  */
497 template<class A, class B, class C, class D, class E,
498  class F, class G, class H, class I>
499 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
500  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
501  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
502 create_visitor(A & a, B & b, C & c,
503  D & d, E & e, F & f,
504  G & g, H & h, I & i)
505 {
506  typedef detail::VisitorNode<I> _8_t;
507  _8_t _8(i);
508  typedef detail::VisitorNode<H, _8_t> _7_t;
509  _7_t _7(h, _8);
510  typedef detail::VisitorNode<G, _7_t> _6_t;
511  _6_t _6(g, _7);
512  typedef detail::VisitorNode<F, _6_t> _5_t;
513  _5_t _5(f, _6);
514  typedef detail::VisitorNode<E, _5_t> _4_t;
515  _4_t _4(e, _5);
516  typedef detail::VisitorNode<D, _4_t> _3_t;
517  _3_t _3(d, _4);
518  typedef detail::VisitorNode<C, _3_t> _2_t;
519  _2_t _2(c, _3);
520  typedef detail::VisitorNode<B, _2_t> _1_t;
521  _1_t _1(b, _2);
522  typedef detail::VisitorNode<A, _1_t> _0_t;
523  _0_t _0(a, _1);
524  return _0;
525 }
526 
527 /** factory method to to be used with RandomForest::learn()
528  */
529 template<class A, class B, class C, class D, class E,
530  class F, class G, class H, class I, class J>
531 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
532  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
533  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
534  detail::VisitorNode<J> > > > > > > > > >
535 create_visitor(A & a, B & b, C & c,
536  D & d, E & e, F & f,
537  G & g, H & h, I & i,
538  J & j)
539 {
540  typedef detail::VisitorNode<J> _9_t;
541  _9_t _9(j);
542  typedef detail::VisitorNode<I, _9_t> _8_t;
543  _8_t _8(i, _9);
544  typedef detail::VisitorNode<H, _8_t> _7_t;
545  _7_t _7(h, _8);
546  typedef detail::VisitorNode<G, _7_t> _6_t;
547  _6_t _6(g, _7);
548  typedef detail::VisitorNode<F, _6_t> _5_t;
549  _5_t _5(f, _6);
550  typedef detail::VisitorNode<E, _5_t> _4_t;
551  _4_t _4(e, _5);
552  typedef detail::VisitorNode<D, _4_t> _3_t;
553  _3_t _3(d, _4);
554  typedef detail::VisitorNode<C, _3_t> _2_t;
555  _2_t _2(c, _3);
556  typedef detail::VisitorNode<B, _2_t> _1_t;
557  _1_t _1(b, _2);
558  typedef detail::VisitorNode<A, _1_t> _0_t;
559  _0_t _0(a, _1);
560  return _0;
561 }
562 
563 //////////////////////////////////////////////////////////////////////////////
564 // Visitors of communal interest. //
565 //////////////////////////////////////////////////////////////////////////////
566 
567 
568 /** Visitor to gain information, later needed for online learning.
569  */
570 
572 {
573 public:
574  //Set if we adjust thresholds
575  bool adjust_thresholds;
576  //Current tree id
577  int tree_id;
578  //Last node id for finding parent
579  int last_node_id;
580  //Need to now the label for interior node visiting
581  vigra::Int32 current_label;
582  //marginal distribution for interior nodes
583  struct MarginalDistribution
584  {
585  ArrayVector<Int32> leftCounts;
586  Int32 leftTotalCounts;
587  ArrayVector<Int32> rightCounts;
588  Int32 rightTotalCounts;
589  double gap_left;
590  double gap_right;
591  };
593 
594  //All information for one tree
595  struct TreeOnlineInformation
596  {
597  std::vector<MarginalDistribution> mag_distributions;
598  std::vector<IndexList> index_lists;
599  //map for linear index of mag_distiributions
600  std::map<int,int> interior_to_index;
601  //map for linear index of index_lists
602  std::map<int,int> exterior_to_index;
603  };
604 
605  //All trees
606  std::vector<TreeOnlineInformation> trees_online_information;
607 
608  /** Initilize, set the number of trees
609  */
610  template<class RF,class PR>
611  void visit_at_beginning(RF & rf,const PR & pr)
612  {
613  tree_id=0;
614  trees_online_information.resize(rf.options_.tree_count_);
615  }
616 
617  /** Reset a tree
618  */
619  void reset_tree(int tree_id)
620  {
621  trees_online_information[tree_id].mag_distributions.clear();
622  trees_online_information[tree_id].index_lists.clear();
623  trees_online_information[tree_id].interior_to_index.clear();
624  trees_online_information[tree_id].exterior_to_index.clear();
625  }
626 
627  /** simply increase the tree count
628  */
629  template<class RF, class PR, class SM, class ST>
630  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
631  {
632  tree_id++;
633  }
634 
635  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
636  void visit_after_split( Tree & tree,
637  Split & split,
638  Region & parent,
639  Region & leftChild,
640  Region & rightChild,
641  Feature_t & features,
642  Label_t & labels)
643  {
644  int linear_index;
645  int addr=tree.topology_.size();
646  if(split.createNode().typeID() == i_ThresholdNode)
647  {
648  if(adjust_thresholds)
649  {
650  //Store marginal distribution
651  linear_index=trees_online_information[tree_id].mag_distributions.size();
652  trees_online_information[tree_id].interior_to_index[addr]=linear_index;
653  trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
654 
655  trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
656  trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
657 
658  trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
659  trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
660  //Store the gap
661  double gap_left,gap_right;
662  int i;
663  gap_left=features(leftChild[0],split.bestSplitColumn());
664  for(i=1;i<leftChild.size();++i)
665  if(features(leftChild[i],split.bestSplitColumn())>gap_left)
666  gap_left=features(leftChild[i],split.bestSplitColumn());
667  gap_right=features(rightChild[0],split.bestSplitColumn());
668  for(i=1;i<rightChild.size();++i)
669  if(features(rightChild[i],split.bestSplitColumn())<gap_right)
670  gap_right=features(rightChild[i],split.bestSplitColumn());
671  trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
672  trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
673  }
674  }
675  else
676  {
677  //Store index list
678  linear_index=trees_online_information[tree_id].index_lists.size();
679  trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
680 
681  trees_online_information[tree_id].index_lists.push_back(IndexList());
682 
683  trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
684  std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
685  }
686  }
687  void add_to_index_list(int tree,int node,int index)
688  {
689  if(!this->active_)
690  return;
691  TreeOnlineInformation &ti=trees_online_information[tree];
692  ti.index_lists[ti.exterior_to_index[node]].push_back(index);
693  }
694  void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
695  {
696  if(!this->active_)
697  return;
698  trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
699  trees_online_information[src_tree].exterior_to_index.erase(src_index);
700  }
701  /** do something when visiting a internal node during getToLeaf
702  *
703  * remember as last node id, for finding the parent of the last external node
704  * also: adjust class counts and borders
705  */
706  template<class TR, class IntT, class TopT,class Feat>
707  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
708  {
709  last_node_id=index;
710  if(adjust_thresholds)
711  {
712  vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
713  //Check if we are in the gap
714  double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
715  TreeOnlineInformation &ti=trees_online_information[tree_id];
716  MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
717  if(value>m.gap_left && value<m.gap_right)
718  {
719  //Check which site we want to go
720  if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
721  {
722  //We want to go left
723  m.gap_left=value;
724  }
725  else
726  {
727  //We want to go right
728  m.gap_right=value;
729  }
730  Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
731  }
732  //Adjust class counts
733  if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
734  {
735  ++m.rightTotalCounts;
736  ++m.rightCounts[current_label];
737  }
738  else
739  {
740  ++m.leftTotalCounts;
741  ++m.rightCounts[current_label];
742  }
743  }
744  }
745  /** do something when visiting a extern node during getToLeaf
746  *
747  * Store the new index!
748  */
749 };
750 
751 //////////////////////////////////////////////////////////////////////////////
752 // Out of Bag Error estimates //
753 //////////////////////////////////////////////////////////////////////////////
754 
755 
756 /** Visitor that calculates the oob error of each individual randomized
757  * decision tree.
758  *
759  * After training a tree, all those samples that are OOB for this particular tree
760  * are put down the tree and the error estimated.
761  * the per tree oob error is the average of the individual error estimates.
762  * (oobError = average error of one randomized tree)
763  * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
764  * visitor)
765  */
767 {
768 public:
769  /** Average error of one randomized decision tree
770  */
771  double oobError;
772 
773  int totalOobCount;
774  ArrayVector<int> oobCount,oobErrorCount;
775 
777  : oobError(0.0),
778  totalOobCount(0)
779  {}
780 
781 
782  bool has_value()
783  {
784  return true;
785  }
786 
787 
788  /** does the basic calculation per tree*/
789  template<class RF, class PR, class SM, class ST>
790  void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index)
791  {
792  //do the first time called.
793  if(int(oobCount.size()) != rf.ext_param_.row_count_)
794  {
795  oobCount.resize(rf.ext_param_.row_count_, 0);
796  oobErrorCount.resize(rf.ext_param_.row_count_, 0);
797  }
798  // go through the samples
799  for(int l = 0; l < rf.ext_param_.row_count_; ++l)
800  {
801  // if the lth sample is oob...
802  if(!sm.is_used()[l])
803  {
804  ++oobCount[l];
805  if( rf.tree(index)
806  .predictLabel(rowVector(pr.features(), l))
807  != pr.response()(l,0))
808  {
809  ++oobErrorCount[l];
810  }
811  }
812 
813  }
814  }
815 
816  /** Does the normalisation
817  */
818  template<class RF, class PR>
819  void visit_at_end(RF & rf, PR & pr)
820  {
821  // do some normalisation
822  for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
823  {
824  if(oobCount[l])
825  {
826  oobError += double(oobErrorCount[l]) / oobCount[l];
827  ++totalOobCount;
828  }
829  }
830  oobError/=totalOobCount;
831  }
832 
833 };
834 
835 /** Visitor that calculates the oob error of the ensemble
836  * This rate should be used to estimate the crossvalidation
837  * error rate.
838  * Here each sample is put down those trees, for which this sample
839  * is OOB i.e. if sample #1 is OOB for trees 1, 3 and 5 we calculate
840  * the output using the ensemble consisting only of trees 1 3 and 5.
841  *
842  * Using normal bagged sampling each sample is OOB for approx. 33% of trees
843  * The error rate obtained as such therefore corresponds to crossvalidation
844  * rate obtained using a ensemble containing 33% of the trees.
845  */
846 class OOB_Error : public VisitorBase
847 {
849  int class_count;
850  bool is_weighted;
851  MultiArray<2,double> tmp_prob;
852  public:
853 
854  MultiArray<2, double> prob_oob;
855  /** Ensemble oob error rate
856  */
857  double oob_breiman;
858 
859  MultiArray<2, double> oobCount;
860  ArrayVector< int> indices;
861  OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
862 
863  void save(std::string filen, std::string pathn)
864  {
865  if(*(pathn.end()-1) != '/')
866  pathn += "/";
867  const char* filename = filen.c_str();
868  MultiArray<2, double> temp(Shp(1,1), 0.0);
869  temp[0] = oob_breiman;
870  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
871  }
872  // negative value if sample was ib, number indicates how often.
873  // value >=0 if sample was oob, 0 means fail 1, corrrect
874 
875  template<class RF, class PR>
876  void visit_at_beginning(RF & rf, PR & pr)
877  {
878  class_count = rf.class_count();
879  tmp_prob.reshape(Shp(1, class_count), 0);
880  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
881  is_weighted = rf.options().predict_weighted_;
882  indices.resize(rf.ext_param().row_count_);
883  if(int(oobCount.size()) != rf.ext_param_.row_count_)
884  {
885  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
886  }
887  for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
888  {
889  indices[ii] = ii;
890  }
891  }
892 
893  template<class RF, class PR, class SM, class ST>
894  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
895  {
896  // go through the samples
897  int total_oob =0;
898  int wrong_oob =0;
899  // FIXME: magic number 10000: invoke special treatment when when msample << sample_count
900  // (i.e. the OOB sample ist very large)
901  // 40000: use at most 40000 OOB samples per class for OOB error estimate
902  if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
903  {
904  ArrayVector<int> oob_indices;
905  ArrayVector<int> cts(class_count, 0);
906  std::random_shuffle(indices.begin(), indices.end());
907  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
908  {
909  if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
910  {
911  oob_indices.push_back(indices[ii]);
912  ++cts[pr.response()(indices[ii], 0)];
913  }
914  }
915  for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
916  {
917  // update number of trees in which current sample is oob
918  ++oobCount[oob_indices[ll]];
919 
920  // update number of oob samples in this tree.
921  ++total_oob;
922  // get the predicted votes ---> tmp_prob;
923  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
924  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
925  rf.tree(index).parameters_,
926  pos);
927  tmp_prob.init(0);
928  for(int ii = 0; ii < class_count; ++ii)
929  {
930  tmp_prob[ii] = node.prob_begin()[ii];
931  }
932  if(is_weighted)
933  {
934  for(int ii = 0; ii < class_count; ++ii)
935  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
936  }
937  rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
938  int label = argMax(tmp_prob);
939 
940  }
941  }else
942  {
943  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
944  {
945  // if the lth sample is oob...
946  if(!sm.is_used()[ll])
947  {
948  // update number of trees in which current sample is oob
949  ++oobCount[ll];
950 
951  // update number of oob samples in this tree.
952  ++total_oob;
953  // get the predicted votes ---> tmp_prob;
954  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
955  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
956  rf.tree(index).parameters_,
957  pos);
958  tmp_prob.init(0);
959  for(int ii = 0; ii < class_count; ++ii)
960  {
961  tmp_prob[ii] = node.prob_begin()[ii];
962  }
963  if(is_weighted)
964  {
965  for(int ii = 0; ii < class_count; ++ii)
966  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
967  }
968  rowVector(prob_oob, ll) += tmp_prob;
969  int label = argMax(tmp_prob);
970 
971  }
972  }
973  }
974  // go through the ib samples;
975  }
976 
977  /** Normalise variable importance after the number of trees is known.
978  */
979  template<class RF, class PR>
980  void visit_at_end(RF & rf, PR & pr)
981  {
982  // ullis original metric and breiman style stuff
983  int totalOobCount =0;
984  int breimanstyle = 0;
985  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
986  {
987  if(oobCount[ll])
988  {
989  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
990  ++breimanstyle;
991  ++totalOobCount;
992  }
993  }
994  oob_breiman = double(breimanstyle)/totalOobCount;
995  }
996 };
997 
998 
999 /** Visitor that calculates different OOB error statistics
1000  */
1002 {
1003  typedef MultiArrayShape<2>::type Shp;
1004  int class_count;
1005  bool is_weighted;
1006  MultiArray<2,double> tmp_prob;
1007  public:
1008 
1009  /** OOB Error rate of each individual tree
1010  */
1012  /** Mean of oob_per_tree
1013  */
1014  double oob_mean;
1015  /**Standard deviation of oob_per_tree
1016  */
1017  double oob_std;
1018 
1019  MultiArray<2, double> prob_oob;
1020  /** Ensemble OOB error
1021  *
1022  * \sa OOB_Error
1023  */
1024  double oob_breiman;
1025 
1026  MultiArray<2, double> oobCount;
1027  MultiArray<2, double> oobErrorCount;
1028  /** Per Tree OOB error calculated as in OOB_PerTreeError
1029  * (Ulli's version)
1030  */
1032 
1033  /**Column containing the development of the Ensemble
1034  * error rate with increasing number of trees
1035  */
1037  /** 4 dimensional array containing the development of confusion matrices
1038  * with number of trees - can be used to estimate ROC curves etc.
1039  *
1040  * oobroc_per_tree(ii,jj,kk,ll)
1041  * corresponds true label = ii
1042  * predicted label = jj
1043  * confusion matrix after ll trees
1044  *
 * explanation of the third index:
1046  *
1047  * Two class case:
1048  * kk = 0 - (treeCount-1)
1049  * Threshold is on Probability for class 0 is kk/(treeCount-1);
1050  * More classes:
1051  * kk = 0. Threshold on probability set by argMax of the probability array.
1052  */
1054 
1056 
1057  /** save to HDF5 file
1058  */
1059  void save(std::string filen, std::string pathn)
1060  {
1061  if(*(pathn.end()-1) != '/')
1062  pathn += "/";
1063  const char* filename = filen.c_str();
1064  MultiArray<2, double> temp(Shp(1,1), 0.0);
1065  writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1066  writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1067  writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1068  temp[0] = oob_mean;
1069  writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1070  temp[0] = oob_std;
1071  writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1072  temp[0] = oob_breiman;
1073  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1074  temp[0] = oob_per_tree2;
1075  writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1076  }
 // negative value if the sample was in-bag (ib); the magnitude indicates how often.
 // value >= 0 if the sample was out-of-bag (oob): 0 means misclassified, 1 correct.
1079 
/** Allocate and reset all accumulators before learning starts.
 *  Called once per RandomForest::learn() run.
 */
template<class RF, class PR>
void visit_at_beginning(RF & rf, PR & pr)
{
    class_count = rf.class_count();
    // two-class case: one confusion matrix per probability threshold
    // (third index runs over tree_count thresholds); multi-class case:
    // a single argMax-based confusion matrix per tree.
    if(class_count == 2)
        oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
    else
        oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
    tmp_prob.reshape(Shp(1, class_count), 0);
    // accumulated ensemble votes, one row per training sample
    prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
    is_weighted = rf.options().predict_weighted_;
    oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
    breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
    //do the first time called.
    if(int(oobCount.size()) != rf.ext_param_.row_count_)
    {
        oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
        oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
    }
}
1100 
/** Update all per-tree OOB statistics with the tree just learned.
 *
 *  For every sample that is OOB w.r.t. this tree: accumulate the
 *  tree's leaf probabilities into prob_oob, count per-sample and
 *  per-tree errors, then update the running Breiman error and the
 *  confusion matrices / ROC data for this tree index.
 */
template<class RF, class PR, class SM, class ST>
void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
{
    // go through the samples
    int total_oob =0;
    int wrong_oob =0;
    for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
    {
        // if the lth sample is oob...
        if(!sm.is_used()[ll])
        {
            // update number of trees in which current sample is oob
            ++oobCount[ll];

            // update number of oob samples in this tree.
            ++total_oob;
            // get the predicted votes ---> tmp_prob;
            int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
            Node<e_ConstProbNode> node ( rf.tree(index).topology_,
                                         rf.tree(index).parameters_,
                                         pos);
            tmp_prob.init(0);
            for(int ii = 0; ii < class_count; ++ii)
            {
                tmp_prob[ii] = node.prob_begin()[ii];
            }
            if(is_weighted)
            {
                // scale by the leaf weight stored just before the
                // probabilities (prob_begin()-1)
                for(int ii = 0; ii < class_count; ++ii)
                    tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
            }
            rowVector(prob_oob, ll) += tmp_prob;
            int label = argMax(tmp_prob);

            if(label != pr.response()(ll, 0))
            {
                // update number of wrong oob samples in this tree.
                ++wrong_oob;
                // update number of trees in which current sample is wrong oob
                ++oobErrorCount[ll];
            }
        }
    }
    // ensemble error after 'index'+1 trees (Breiman style: majority
    // vote over accumulated prob_oob)
    int breimanstyle = 0;
    int totalOobCount = 0;
    for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
    {
        if(oobCount[ll])
        {
            if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
                ++breimanstyle;
            ++totalOobCount;
            // multi-class case (third axis has extent 1): plain
            // argMax confusion matrix for this tree index
            if(oobroc_per_tree.shape(2) == 1)
            {
                oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
            }
        }
    }
    if(oobroc_per_tree.shape(2) == 1)
        oobroc_per_tree.bindOuter(index)/=totalOobCount;
    // two-class case: one confusion matrix per probability threshold
    // gg/(#thresholds) on class-1 probability
    if(oobroc_per_tree.shape(2) > 1)
    {
        MultiArrayView<3, double> current_roc
            = oobroc_per_tree.bindOuter(index);
        for(int gg = 0; gg < current_roc.shape(2); ++gg)
        {
            for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
            {
                if(oobCount[ll])
                {
                    int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
                                    1 : 0;
                    current_roc(pr.response()(ll, 0), pred, gg)+= 1;
                }
            }
            current_roc.bindOuter(gg)/= totalOobCount;
        }
    }
    breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
    oob_per_tree[index] = double(wrong_oob)/double(total_oob);
    // go through the ib samples;
}
1183 
1184  /** Normalise variable importance after the number of trees is known.
1185  */
1186  template<class RF, class PR>
1187  void visit_at_end(RF & rf, PR & pr)
1188  {
1189  // ullis original metric and breiman style stuff
1190  oob_per_tree2 = 0;
1191  int totalOobCount =0;
1192  int breimanstyle = 0;
1193  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1194  {
1195  if(oobCount[ll])
1196  {
1197  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1198  ++breimanstyle;
1199  oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1200  ++totalOobCount;
1201  }
1202  }
1203  oob_per_tree2 /= totalOobCount;
1204  oob_breiman = double(breimanstyle)/totalOobCount;
1205  // mean error of each tree
1206  MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
1207  MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1208  rowStatistics(oob_per_tree, mean, stdDev);
1209  }
1210 };
1211 
1212 /** calculate variable importance while learning.
1213  */
1215 {
1216  public:
1217 
1218  /** This Array has the same entries as the R - random forest variable
1219  * importance.
1220  * Matrix is featureCount by (classCount +2)
1221  * variable_importance_(ii,jj) is the variable importance measure of
1222  * the ii-th variable according to:
1223  * jj = 0 - (classCount-1)
1224  * classwise permutation importance
1225  * jj = rowCount(variable_importance_) -2
1226  * permutation importance
1227  * jj = rowCount(variable_importance_) -1
1228  * gini decrease importance.
1229  *
1230  * permutation importance:
1231  * The difference between the fraction of OOB samples classified correctly
1232  * before and after permuting (randomizing) the ii-th column is calculated.
1233  * The ii-th column is permuted rep_cnt times.
1234  *
1235  * class wise permutation importance:
1236  * same as permutation importance. We only look at those OOB samples whose
1237  * response corresponds to class jj.
1238  *
1239  * gini decrease importance:
1240  * row ii corresponds to the sum of all gini decreases induced by variable ii
1241  * in each node of the random forest.
1242  */
1244  int repetition_count_;
1245  bool in_place_;
1246 
#ifdef HasHDF5
    /** save the variable importance matrix to a HDF5 file
     * \param filename name of the HDF5 file
     * \param prefix   dataset name; stored as "variable_importance_<prefix>"
     */
    void save(std::string filename, std::string prefix)
    {
        prefix = "variable_importance_" + prefix;
        // NOTE(review): the third argument had been lost in this copy of
        // the file; restored as the class's result matrix.
        writeHDF5(filename.c_str(),
                  prefix.c_str(),
                  variable_importance_);
    }
#endif
1256  /** Constructor
1257  * \param rep_cnt (defautl: 10) how often should
1258  * the permutation take place. Set to 1 to make calculation faster (but
1259  * possibly more instable)
1260  */
1261  VariableImportanceVisitor(int rep_cnt = 10)
1262  : repetition_count_(rep_cnt)
1263 
1264  {}
1265 
1266  /** calculates impurity decrease based variable importance after every
1267  * split.
1268  */
1269  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1270  void visit_after_split( Tree & tree,
1271  Split & split,
1272  Region & parent,
1273  Region & leftChild,
1274  Region & rightChild,
1275  Feature_t & features,
1276  Label_t & labels)
1277  {
1278  //resize to right size when called the first time
1279 
1280  Int32 const class_count = tree.ext_param_.class_count_;
1281  Int32 const column_count = tree.ext_param_.column_count_;
1282  if(variable_importance_.size() == 0)
1283  {
1284 
1286  .reshape(MultiArrayShape<2>::type(column_count,
1287  class_count+2));
1288  }
1289 
1290  if(split.createNode().typeID() == i_ThresholdNode)
1291  {
1292  Node<i_ThresholdNode> node(split.createNode());
1293  variable_importance_(node.column(),class_count+1)
1294  += split.region_gini_ - split.minGini();
1295  }
1296  }
1297 
1298  /**compute permutation based var imp.
1299  * (Only an Array of size oob_sample_count x 1 is created.
1300  * - apposed to oob_sample_count x feature_count in the other method.
1301  *
1302  * \sa FieldProxy
1303  */
1304  template<class RF, class PR, class SM, class ST>
1305  void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index)
1306  {
1307  typedef MultiArrayShape<2>::type Shp_t;
1308  Int32 column_count = rf.ext_param_.column_count_;
1309  Int32 class_count = rf.ext_param_.class_count_;
1310 
1311  /* This solution saves memory uptake but not multithreading
1312  * compatible
1313  */
1314  // remove the const cast on the features (yep , I know what I am
1315  // doing here.) data is not destroyed.
1316  //typename PR::Feature_t & features
1317  // = const_cast<typename PR::Feature_t &>(pr.features());
1318 
1319  typename PR::FeatureWithMemory_t features = pr.features();
1320 
1321  //find the oob indices of current tree.
1322  ArrayVector<Int32> oob_indices;
1324  iter;
1325  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1326  if(!sm.is_used()[ii])
1327  oob_indices.push_back(ii);
1328 
1329  //create space to back up a column
1330  std::vector<double> backup_column;
1331 
1332  // Random foo
1333 #ifdef CLASSIFIER_TEST
1334  RandomMT19937 random(1);
1335 #else
1336  RandomMT19937 random(RandomSeed);
1337 #endif
1339  randint(random);
1340 
1341 
1342  //make some space for the results
1344  oob_right(Shp_t(1, class_count + 1));
1346  perm_oob_right (Shp_t(1, class_count + 1));
1347 
1348 
1349  // get the oob success rate with the original samples
1350  for(iter = oob_indices.begin();
1351  iter != oob_indices.end();
1352  ++iter)
1353  {
1354  if(rf.tree(index)
1355  .predictLabel(rowVector(features, *iter))
1356  == pr.response()(*iter, 0))
1357  {
1358  //per class
1359  ++oob_right[pr.response()(*iter,0)];
1360  //total
1361  ++oob_right[class_count];
1362  }
1363  }
1364  //get the oob rate after permuting the ii'th dimension.
1365  for(int ii = 0; ii < column_count; ++ii)
1366  {
1367  perm_oob_right.init(0.0);
1368  //make backup of orinal column
1369  backup_column.clear();
1370  for(iter = oob_indices.begin();
1371  iter != oob_indices.end();
1372  ++iter)
1373  {
1374  backup_column.push_back(features(*iter,ii));
1375  }
1376 
1377  //get the oob rate after permuting the ii'th dimension.
1378  for(int rr = 0; rr < repetition_count_; ++rr)
1379  {
1380  //permute dimension.
1381  int n = oob_indices.size();
1382  for(int jj = 1; jj < n; ++jj)
1383  std::swap(features(oob_indices[jj], ii),
1384  features(oob_indices[randint(jj+1)], ii));
1385 
1386  //get the oob sucess rate after permuting
1387  for(iter = oob_indices.begin();
1388  iter != oob_indices.end();
1389  ++iter)
1390  {
1391  if(rf.tree(index)
1392  .predictLabel(rowVector(features, *iter))
1393  == pr.response()(*iter, 0))
1394  {
1395  //per class
1396  ++perm_oob_right[pr.response()(*iter, 0)];
1397  //total
1398  ++perm_oob_right[class_count];
1399  }
1400  }
1401  }
1402 
1403 
1404  //normalise and add to the variable_importance array.
1405  perm_oob_right /= repetition_count_;
1406  perm_oob_right -=oob_right;
1407  perm_oob_right *= -1;
1408  perm_oob_right /= oob_indices.size();
1410  .subarray(Shp_t(ii,0),
1411  Shp_t(ii+1,class_count+1)) += perm_oob_right;
1412  //copy back permuted dimension
1413  for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1414  features(oob_indices[jj], ii) = backup_column[jj];
1415  }
1416  }
1417 
/** calculate permutation based importance after every tree has been
 * learned; the default behaviour is that this happens out of place.
 * If you have very big data sets and want to avoid copying of data
 * set the in_place_ flag to true.
 *
 * NOTE(review): as visible here, the in_place_ flag is not consulted --
 * the out-of-place implementation is always used; confirm whether an
 * in-place path exists elsewhere.
 */
template<class RF, class PR, class SM, class ST>
void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
{
    after_tree_ip_impl(rf, pr, sm, st, index);
}
1428 
/** Normalise variable importance after the number of trees is known.
 */
template<class RF, class PR>
void visit_at_end(RF & rf, PR & pr)
{
    // average the importances accumulated per tree over all trees
    variable_importance_ /= rf.trees_.size();
}
1436 };
1437 
1438 /** Verbose output
1439  */
1441  public:
1443 
/** Print a progress line after each learned tree; the final tree
 *  prints a completed 100% line and a newline instead.
 */
template<class RF, class PR, class SM, class ST>
void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){
    const int total = rf.options().tree_count_;
    if(index == total - 1) {
        // last tree: finish the progress display
        std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
    }
    else {
        const double percent = (index+1)/static_cast<double>(total)*100;
        std::cout << "\r[" << std::setw(10) << percent << "%]"
                  << " (" << index+1 << " of " << total << ") done" << std::flush;
    }
}
1454 
/** Report total learning time once all trees are done. */
template<class RF, class PR>
void visit_at_end(RF const & rf, PR const & pr) {
    // TOCS (tic-toc macro pair) yields the elapsed time as a string;
    // the matching TIC is issued in visit_at_beginning()
    std::string a = TOCS;
    std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a  << std::endl;
}
1460 
/** Announce the run and start the timer. */
template<class RF, class PR>
void visit_at_beginning(RF const & rf, PR const & pr) {
    // start the USETICTOC timer; read back via TOCS in visit_at_end()
    TIC;
    std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
}
1466 
1467  private:
1468  USETICTOC;
1469 };
1470 
1471 
1472 /** Computes Correlation/Similarity Matrix of features while learning
1473  * random forest.
1474  */
1476 {
1477  public:
1478  /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1479  * created on variable ii(when variable ii was chosen)
1480  */
1482  MultiArray<2, int> tmp_labels;
1483  /** additional noise features.
1484  */
1486  MultiArray<2, double> noise_l;
1487  /** how well can a noise column describe a partition created on variable ii.
1488  */
1490  MultiArray<2, double> corr_l;
1491 
1492  /** Similarity Matrix
1493  *
1494  * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
1495  * gini_missc
1496  * - row normalized by the number of times the column was chosen
1497  * - mean of corr_noise subtracted
1498  * - and symmetrised.
1499  *
1500  */
1502  /** Distance Matrix 1-similarity
1503  */
1505  ArrayVector<int> tmp_cc;
1506 
1507  /** How often was variable ii chosen
1508  */
/** save the accumulated matrices to a HDF5 file.
 *  NOTE(review): the implementation is entirely commented out in this
 *  version -- calling save() is a no-op.
 */
void save(std::string file, std::string prefix)
{
    /*
    std::string tmp;
#define VAR_WRITE(NAME) \
    tmp = #NAME;\
    tmp += "_";\
    tmp += prefix;\
    vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
    VAR_WRITE(gini_missc);
    VAR_WRITE(corr_noise);
    VAR_WRITE(distance);
    VAR_WRITE(similarity);
    vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
#undef VAR_WRITE
*/
}
/** Allocate accumulators and generate the random baseline columns
 *  before learning starts.
 */
template<class RF, class PR>
void visit_at_beginning(RF const & rf, PR & pr)
{
    typedef MultiArrayShape<2>::type Shp;
    int n = rf.ext_param_.column_count_;
    // the extra (n+1)th row/column accumulates statistics for the
    // class label itself
    gini_missc.reshape(Shp(n +1,n+ 1));
    corr_noise.reshape(Shp(n + 1, 10));
    corr_l.reshape(Shp(n +1, 10));

    // 10 columns of uniform noise and 10 of random binary labels,
    // used as a baseline for the correlation measure
    noise.reshape(Shp(pr.features().shape(0), 10));
    noise_l.reshape(Shp(pr.features().shape(0), 10));
    RandomMT19937 random(RandomSeed);
    for(int ii = 0; ii < noise.size(); ++ii)
    {
        noise[ii] = random.uniform53();
        noise_l[ii] = random.uniform53() > 0.5;
    }
    bgfunc = ColumnDecisionFunctor( rf.ext_param_);
    tmp_labels.reshape(pr.response().shape());
    tmp_cc.resize(2);
    numChoices.resize(n+1);
    // look at allaxes
}
/** Turn the accumulated gini_missc matrix into the normalised,
 *  noise-corrected, symmetrised similarity matrix and the derived
 *  distance matrix.
 *
 *  NOTE(review): two spans of this function were lost in this copy of
 *  the file (presumably including "similarity = gini_missc;" near the
 *  top) -- verify against the upstream sources before relying on it.
 */
template<class RF, class PR>
void visit_at_end(RF const & rf, PR const & pr)
{
    typedef MultiArrayShape<2>::type Shp;
    // mean response of the noise baseline per row, normalised by how
    // often each variable was chosen
    MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
    rowStatistics(corr_noise, mean_noise);
    mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
    int rC = similarity.shape(0);
    // normalise each row by the choice count and subtract the noise mean
    for(int jj = 0; jj < rC-1; ++jj)
    {
        rowVector(similarity, jj) /= numChoices[jj];
        rowVector(similarity, jj) -= mean_noise(jj, 0);
    }
    for(int jj = 0; jj < rC; ++jj)
    {
        similarity(rC -1, jj) /= numChoices[jj];
    }
    rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
    FindMinMax<double> minmax;
    inspectMultiArray(srcMultiArrayRange(similarity), minmax);

    // put the global maximum on the diagonal before symmetrising
    for(int jj = 0; jj < rC; ++jj)
        similarity(jj, jj) = minmax.max;

    // symmetrise: average with the transpose (last row/column mirrored)
    similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
        += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
    similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
    columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
    for(int jj = 0; jj < rC; ++jj)
        similarity(jj, jj) = 0;

    // distance = (new global max) - similarity, with max diagonal
    FindMinMax<double> minmax2;
    inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
    for(int jj = 0; jj < rC; ++jj)
        similarity(jj, jj) = minmax2.max;
    distance.reshape(gini_missc.shape(), minmax2.max);
    distance -= similarity;
}
1593 
/** After each split: measure how well every other variable (and the
 *  noise baselines) could reproduce the partition induced by the
 *  chosen split variable, and accumulate the results.
 *
 *  NOTE(review): several identifier-bearing lines were lost in this
 *  copy of the file (accumulation targets such as "gini_missc(...)"
 *  and the declaration of "sorter"); the body as shown does not
 *  compile -- verify every flagged spot against the upstream sources.
 */
template<class Tree, class Split, class Region, class Feature_t, class Label_t>
void visit_after_split( Tree          & tree,
                        Split         & split,
                        Region        & parent,
                        Region        & leftChild,
                        Region        & rightChild,
                        Feature_t     & features,
                        Label_t       & labels)
{
    if(split.createNode().typeID() == i_ThresholdNode)
    {
        double wgini;
        // binarise the parent region by the chosen split: label 1 if
        // the sample falls below the threshold
        tmp_cc.init(0);
        for(int ii = 0; ii < parent.size(); ++ii)
        {
            tmp_labels[parent[ii]]
                = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
            ++tmp_cc[tmp_labels[parent[ii]]];
        }
        double region_gini = bgfunc.loss_of_region(tmp_labels,
                                                   parent.begin(),
                                                   parent.end(),
                                                   tmp_cc);

        int n = split.bestSplitColumn();
        ++numChoices[n];
        ++(*(numChoices.end()-1));
        //this functor does all the work
        for(int k = 0; k < features.shape(1); ++k)
        {
            bgfunc(columnVector(features, k),
                   0,
                   tmp_labels,
                   parent.begin(), parent.end(),
                   tmp_cc);
            wgini = (region_gini - bgfunc.min_gini_);
            gini_missc(n, k)
                += wgini;
        }
        // same measurement against the 10 uniform-noise columns...
        for(int k = 0; k < 10; ++k)
        {
            bgfunc(columnVector(noise, k),
                   0,
                   tmp_labels,
                   parent.begin(), parent.end(),
                   tmp_cc);
            wgini = (region_gini - bgfunc.min_gini_);
            corr_noise(n, k)
                += wgini;
        }

        // ...and against the 10 random-binary-label columns
        for(int k = 0; k < 10; ++k)
        {
            bgfunc(columnVector(noise_l, k),
                   0,
                   tmp_labels,
                   parent.begin(), parent.end(),
                   tmp_cc);
            wgini = (region_gini - bgfunc.min_gini_);
            corr_l(n, k)
                += wgini;
        }
        bgfunc(labels,0, tmp_labels, parent.begin(), parent.end(),tmp_cc);
        wgini = (region_gini - bgfunc.min_gini_);
        // NOTE(review): accumulation target lost in this copy
            += wgini;

        region_gini = split.region_gini_;
#if 1
        Node<i_ThresholdNode> node(split.createNode());
        // NOTE(review): receiver lost in this copy before "node.column())"
            node.column())
            +=split.region_gini_ - split.minGini();
#endif
        for(int k = 0; k < 10; ++k)
        {
            split.bgfunc(columnVector(noise, k),
                         0,
                         labels,
                         parent.begin(), parent.end(),
                         parent.classCounts());
            // NOTE(review): line(s) lost in this copy before "k)"
                         k)
                += wgini;
        }
#if 0
        for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
        {
            wgini = region_gini - split.min_gini_[k];

            split.splitColumns[k])
                += wgini;
        }

        for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
        {
            split.bgfunc(columnVector(features, split.splitColumns[k]),
                         labels,
                         parent.begin(), parent.end(),
                         parent.classCounts());
            wgini = region_gini - split.bgfunc.min_gini_;
                         split.splitColumns[k]) += wgini;
        }
#endif
        // remember to partition the data according to the best.
        // NOTE(review): receiver lost in this copy before
        // "columnCount(gini_missc)-1)", and the declaration of "sorter"
        // (presumably a SortSamplesByDimensions functor) is missing.
            columnCount(gini_missc)-1)
            += region_gini;
            sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
        std::partition(parent.begin(), parent.end(), sorter);
    }
}
1709 };
1710 
1711 
1712 } // namespace visitors
1713 } // namespace rf
1714 } // namespace vigra
1715 
1716 //@}
1717 #endif // RF_VISITORS_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Wed Mar 12 2014)