37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 SamplerOptions return_opt;
86 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
126 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
140 typedef LabelType LabelT;
215 template<
class TopologyIterator,
class ParameterIterator>
217 TopologyIterator topology_begin,
218 ParameterIterator parameter_begin,
223 ext_param_(problem_spec),
226 for(
unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
228 trees_[k].topology_ = *topology_begin;
229 trees_[k].parameters_ = *parameter_begin;
248 vigra_precondition(ext_param_.used() ==
true,
249 "RandomForest::ext_param(): "
250 "Random forest has not been trained yet.");
266 vigra_precondition(ext_param_.used() ==
false,
267 "RandomForest::set_ext_param():"
268 "Random forest has been trained! Call reset()"
269 "before specifying new extrinsic parameters.");
295 return trees_[index];
302 return trees_[index];
312 return ext_param_.column_count_;
323 return ext_param_.column_count_;
331 return ext_param_.class_count_;
338 return options_.tree_count_;
343 template<
class U,
class C1,
356 bool adjust_thresholds=
false);
358 template <
class U,
class C1,
class U2,
class C2>
363 onlineLearn(features,
373 template<
class U,
class C1,
379 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
380 MultiArrayView<2,U2,C2>
const & response,
387 template<
class U,
class C1,
class U2,
class C2>
388 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
389 MultiArrayView<2, U2, C2>
const & labels,
392 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
437 template <
class U,
class C1,
443 void learn( MultiArrayView<2, U, C1>
const & features,
444 MultiArrayView<2, U2,C2>
const & response,
448 Random_t
const & random);
450 template <
class U,
class C1,
455 void learn( MultiArrayView<2, U, C1>
const & features,
456 MultiArrayView<2, U2,C2>
const & response,
462 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
471 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
472 void learn( MultiArrayView<2, U, C1>
const & features,
473 MultiArrayView<2, U2,C2>
const & labels,
483 template <
class U,
class C1,
class U2,
class C2,
484 class Visitor_t,
class Split_t>
485 void learn( MultiArrayView<2, U, C1>
const & features,
486 MultiArrayView<2, U2,C2>
const & labels,
515 template <
class U,
class C1,
class U2,
class C2>
543 template <
class U,
class C,
class Stop>
546 template <
class U,
class C>
557 template <
class U,
class C>
558 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
559 ArrayVectorView<double> prior)
const;
569 template <
class U,
class C1,
class T,
class C2>
573 vigra_precondition(features.
shape(0) == labels.
shape(0),
574 "RandomForest::predictLabels(): Label array has wrong size.");
575 for(
int k=0; k<features.
shape(0); ++k)
579 template <
class U,
class C1,
class T,
class C2,
class Stop>
584 vigra_precondition(features.
shape(0) == labels.
shape(0),
585 "RandomForest::predictLabels(): Label array has wrong size.");
586 for(
int k=0; k<features.
shape(0); ++k)
597 template <
class U,
class C1,
class T,
class C2,
class Stop>
599 MultiArrayView<2, T, C2> & prob,
601 template <
class T1,
class T2,
class C>
603 MultiArrayView<2, T2, C> & prob);
611 template <
class U,
class C1,
class T,
class C2>
624 template <
class LabelType,
class PreprocessorTag>
625 template<
class U,
class C1,
631 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
632 MultiArrayView<2,U2,C2>
const & response,
638 bool adjust_thresholds)
640 online_visitor_.activate();
641 online_visitor_.adjust_thresholds=adjust_thresholds;
645 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
646 typedef UniformIntRandomFunctor<Random_t>
653 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
654 Default_Stop_t default_stop(options_);
655 typename RF_CHOOSER(Stop_t)::type stop
656 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
657 Default_Split_t default_split;
658 typename RF_CHOOSER(Split_t)::type split
659 = RF_CHOOSER(Split_t)::choose(split_, default_split);
660 rf::visitors::StopVisiting stopvisiting;
661 typedef rf::visitors::detail::VisitorNode
662 <rf::visitors::OnlineLearnVisitor,
663 typename RF_CHOOSER(Visitor_t)::type>
666 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
673 ext_param_.class_count_=0;
674 Preprocessor_t preprocessor( features, response,
675 options_, ext_param_);
678 RandFunctor_t randint ( random);
681 split.set_external_parameters(ext_param_);
682 stop.set_external_parameters(ext_param_);
686 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
692 for(
int ii = 0; ii < (int)trees_.size(); ++ii)
694 online_visitor_.tree_id=ii;
695 poisson_sampler.sample();
696 std::map<int,int> leaf_parents;
697 leaf_parents.clear();
699 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
701 int sample=poisson_sampler[s];
702 online_visitor_.current_label=preprocessor.response()(sample,0);
703 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
704 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
708 online_visitor_.add_to_index_list(ii,leaf,sample);
711 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
713 leaf_parents[leaf]=online_visitor_.last_node_id;
718 std::map<int,int>::iterator leaf_iterator;
719 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
721 int leaf=leaf_iterator->first;
722 int parent=leaf_iterator->second;
723 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
724 ArrayVector<Int32> indeces;
726 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
727 StackEntry_t stack_entry(indeces.begin(),
729 ext_param_.class_count_);
734 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
736 stack_entry.leftParent=parent;
740 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
741 stack_entry.rightParent=parent;
745 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
747 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
760 online_visitor_.deactivate();
763 template<
class LabelType,
class PreprocessorTag>
764 template<
class U,
class C1,
785 ext_param_.class_count_=0;
786 typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;
793 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
795 typename RF_CHOOSER(Stop_t)::type stop
796 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
798 typename RF_CHOOSER(Split_t)::type split
799 = RF_CHOOSER(Split_t)::choose(split_, default_split);
803 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
805 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
807 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
808 online_visitor_.activate();
811 RandFunctor_t randint ( random);
817 Preprocessor_t preprocessor( features, response,
818 options_, ext_param_);
821 split.set_external_parameters(ext_param_);
822 stop.set_external_parameters(ext_param_);
829 preprocessor.strata().end(),
830 detail::make_sampler_opt(options_)
831 .sampleSize(ext_param().actual_msample_),
838 first_stack_entry( sampler.sampledIndices().begin(),
839 sampler.sampledIndices().end(),
840 ext_param_.class_count_);
842 .set_oob_range( sampler.oobIndices().begin(),
843 sampler.oobIndices().end());
844 online_visitor_.reset_tree(treeId);
845 online_visitor_.tree_id=treeId;
846 trees_[treeId].reset();
848 .learn( preprocessor.features(),
849 preprocessor.response(),
856 .visit_after_tree( *
this,
862 online_visitor_.deactivate();
865 template <
class LabelType,
class PreprocessorTag>
866 template <
class U,
class C1,
878 Random_t
const & random)
887 typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;
894 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
896 typename RF_CHOOSER(Stop_t)::type stop
897 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
899 typename RF_CHOOSER(Split_t)::type split
900 = RF_CHOOSER(Split_t)::choose(split_, default_split);
904 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
906 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
908 if(options_.prepare_online_learning_)
909 online_visitor_.activate();
911 online_visitor_.deactivate();
915 RandFunctor_t randint ( random);
922 Preprocessor_t preprocessor( features, response,
923 options_, ext_param_);
926 split.set_external_parameters(ext_param_);
927 stop.set_external_parameters(ext_param_);
934 preprocessor.strata().end(),
935 detail::make_sampler_opt(options_)
936 .sampleSize(ext_param().actual_msample_),
939 visitor.visit_at_beginning(*
this, preprocessor);
942 for(
int ii = 0; ii < (int)trees_.size(); ++ii)
948 first_stack_entry( sampler.sampledIndices().begin(),
949 sampler.sampledIndices().end(),
950 ext_param_.class_count_);
952 .set_oob_range( sampler.oobIndices().begin(),
953 sampler.oobIndices().end());
955 .learn( preprocessor.features(),
956 preprocessor.response(),
963 .visit_after_tree( *
this,
970 visitor.visit_at_end(*
this, preprocessor);
972 online_visitor_.deactivate();
978 template <
class LabelType,
class Tag>
979 template <
class U,
class C,
class Stop>
983 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
984 "RandomForestn::predictLabel():"
985 " Too few columns in feature matrix.");
986 vigra_precondition(
rowCount(features) == 1,
987 "RandomForestn::predictLabel():"
988 " Feature matrix must have a singlerow.");
990 garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
992 predictProbabilities(features, garbage_prediction_, stop);
993 ext_param_.to_classlabel(
argMax(garbage_prediction_), d);
999 template <
class LabelType,
class PreprocessorTag>
1000 template <
class U,
class C>
1005 using namespace functor;
1006 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1007 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1008 vigra_precondition(
rowCount(features) == 1,
1009 "RandomForestn::predictLabel():"
1010 " Feature matrix must have a single row.");
1011 Matrix<double> prob(1,ext_param_.class_count_);
1012 predictProbabilities(features, prob);
1013 std::transform( prob.begin(), prob.end(),
1014 priors.
begin(), prob.begin(),
1017 ext_param_.to_classlabel(
argMax(prob), d);
1021 template<
class LabelType,
class PreprocessorTag>
1022 template <
class T1,
class T2,
class C>
1031 "RandomFroest::predictProbabilities():"
1032 " Feature matrix and probability matrix size misnmatch.");
1035 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1036 "RandomForestn::predictProbabilities():"
1037 " Too few columns in feature matrix.");
1040 "RandomForestn::predictProbabilities():"
1041 " Probability matrix must have as many columns as there are classes.");
1044 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1047 for(
int k=0; k<options_.tree_count_; ++k)
1049 set_id=(set_id+1) % predictionSet.indices[0].size();
1050 typedef std::set<SampleRange<T1> > my_set;
1051 typedef typename my_set::iterator set_it;
1054 std::vector<std::pair<int,set_it> > stack;
1057 for(i=predictionSet.ranges[set_id].begin();i!=predictionSet.ranges[set_id].end();++i)
1058 stack.push_back(std::pair<int,set_it>(2,i));
1060 int num_decisions=0;
1061 while(!stack.empty())
1063 set_it range=stack.back().second;
1064 int index=stack.back().first;
1068 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1071 trees_[k].parameters_,
1072 index).prob_begin();
1073 for(
int i=range->start;i!=range->end;++i)
1076 for(
int l=0; l<ext_param_.class_count_; ++l)
1078 prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
1080 totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
1087 if(trees_[k].topology_[index]!=i_ThresholdNode)
1089 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1091 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1092 if(range->min_boundaries[node.column()]>=node.threshold())
1095 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1098 if(range->max_boundaries[node.column()]<node.threshold())
1101 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1105 SampleRange<T1> new_range=*range;
1106 new_range.min_boundaries[node.column()]=FLT_MAX;
1107 range->max_boundaries[node.column()]=-FLT_MAX;
1108 new_range.start=new_range.end=range->end;
1110 while(i!=range->end)
1113 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1115 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1116 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1119 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1124 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1125 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1130 if(range->start==range->end)
1132 predictionSet.ranges[set_id].erase(range);
1136 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1139 if(new_range.start!=new_range.end)
1141 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1142 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1146 predictionSet.cumulativePredTime[k]=num_decisions;
1148 for(
unsigned int i=0;i<totalWeights.size();++i)
1152 for(
int l=0; l<ext_param_.class_count_; ++l)
1155 prob(i, l) /= totalWeights[i];
1157 assert(test==totalWeights[i]);
1158 assert(totalWeights[i]>0.0);
1162 template <
class LabelType,
class PreprocessorTag>
1163 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1166 MultiArrayView<2, T, C2> & prob,
1167 Stop_t & stop_)
const
1173 "RandomForestn::predictProbabilities():"
1174 " Feature matrix and probability matrix size mismatch.");
1178 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1179 "RandomForestn::predictProbabilities():"
1180 " Too few columns in feature matrix.");
1183 "RandomForestn::predictProbabilities():"
1184 " Probability matrix must have as many columns as there are classes.");
1186 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1187 Default_Stop_t default_stop(options_);
1188 typename RF_CHOOSER(Stop_t)::type & stop
1189 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1191 stop.set_external_parameters(ext_param_, tree_count());
1192 prob.init(NumericTraits<T>::zero());
1202 for(
int row=0; row <
rowCount(features); ++row)
1204 ArrayVector<double>::const_iterator weights;
1207 double totalWeight = 0.0;
1210 for(
int k=0; k<options_.tree_count_; ++k)
1213 weights = trees_[k ].predict(
rowVector(features, row));
1216 int weighted = options_.predict_weighted_;
1217 for(
int l=0; l<ext_param_.class_count_; ++l)
1219 double cur_w = weights[l] * (weighted * (*(weights-1))
1221 prob(row, l) += (T)cur_w;
1223 totalWeight += cur_w;
1225 if(stop.after_prediction(weights,
1235 for(
int l=0; l< ext_param_.class_count_; ++l)
1237 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1247 #include "random_forest/rf_algorithm.hxx"
1248 #endif // VIGRA_RANDOM_FOREST_HXX