35 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
36 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
42 #include "../mathutil.hxx"
43 #include "../array_vector.hxx"
44 #include "../sized_int.hxx"
45 #include "../matrix.hxx"
46 #include "../random.hxx"
47 #include "../functorexpression.hxx"
48 #include "rf_nodeproxy.hxx"
50 #include "rf_region.hxx"
59 class CompileTimeError;
69 static void exec(Iter begin, Iter end)
// Specialisation for ClassificationTag: exec() divides every element of
// [begin, end) by the sum of the range.
74 class Normalise<ClassificationTag>
78 static void exec (Iter begin, Iter end)
// Sum of the whole range, accumulated as double (the 0.0 seed fixes the type).
// NOTE(review): there is no visible guard against a zero sum; the division
// below would then divide by zero - confirm callers never pass an all-zero range.
80 double bla = std::accumulate(begin, end, 0.0);
// `end - begin` requires random-access iterators; the index is a plain int.
81 for(
int ii = 0; ii < end - begin; ++ii)
// In-place scaling of each element by 1/sum.
82 begin[ii] = begin[ii]/bla ;
115 t_data.push_back(in.column_count_);
116 t_data.push_back(in.class_count_);
// Returns entry 1 of t_data converted to int.
// NOTE(review): presumably the class count - elsewhere in this file
// in.class_count_ is the second value pushed onto t_data; confirm t_data is
// empty before those pushes.
124 int classCount()
const
126 return int(t_data[1]);
// Returns entry 0 of t_data converted to int.
// NOTE(review): presumably the feature (column) count - elsewhere in this file
// in.column_count_ is the first value pushed onto t_data; confirm t_data is
// empty before those pushes.
129 int featureCount()
const
131 return int(t_data[0]);
// Member template whose visible body only declares an object of type
// CompileTimeError (the signature lines are not visible in this excerpt).
// CompileTimeError is only forward-declared in the visible source.
// NOTE(review): presumably this is a deliberate trap that makes instantiation
// fail when a split functor does not provide its own findBestSplit - confirm.
149 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
156 CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
// Fills the probability range of `ret` from the class counts of `region` and
// returns e_ConstProbNode. The signature and the declaration of `ret` are not
// visible in this excerpt.
164 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
// If the number of class weights does not match the number of class counts,
// the raw class counts are copied (the destination argument is not visible).
172 if(ext_param_.class_weights_.
size() != region.classCounts().size())
174 std::copy( region.classCounts().begin(),
175 region.classCounts().end(),
// Element-wise product classCounts[i] * class_weights_[i], written to ret's
// probability range.
// NOTE(review): presumably the else-branch of the size check above; the
// branch structure is not visible here.
180 std::transform( region.classCounts().begin(),
181 region.classCounts().end(),
182 ext_param_.class_weights_.
begin(),
183 ret.prob_begin(), std::multiplies<double>());
// Post-process the probabilities with the Normalise variant selected by RF_Tag.
185 detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
187 return e_ConstProbNode;
// Comparison functor over rows of a data matrix, restricted to one column
// (sortColumn_). Only fragments of the class are visible in this excerpt.
195 template<
class DataMatrix>
// The matrix is held by const reference, so it must outlive this functor.
198 DataMatrix
const & data_;
// Constructor tail: the threshold argument defaults to 0.0.
205 double thresVal = 0.0)
207 sortColumn_(sortColumn),
// Setter body: replaces the column used by both comparisons below.
213 sortColumn_ = sortColumn;
215 void setThreshold(
double value)
// Two-row comparison: true when row l is smaller than row r in sortColumn_.
222 return data_(l, sortColumn_) < data_(r, sortColumn_);
// One-row comparison: true when row l in sortColumn_ is below thresVal_.
226 return data_(l, sortColumn_) < thresVal_;
// Binary predicate over rows of a data matrix: reports whether two rows differ
// in one column (sortColumn_).
230 template<
class DataMatrix>
231 class DimensionNotEqual
// The matrix is held by const reference, so it must outlive this predicate.
233 DataMatrix
const & data_;
238 DimensionNotEqual(DataMatrix
const & data,
241 sortColumn_(sortColumn)
// Setter body: replaces the column that is compared.
246 sortColumn_ = sortColumn;
// True when rows l and r hold different values in sortColumn_ (exact !=).
251 return data_(l, sortColumn_) != data_(r, sortColumn_);
// Orders rows of a data matrix by a linear function taken from a
// Node<i_HyperplaneNode>: intercept, selected columns and per-column weights.
255 template<
class DataMatrix>
256 class SortSamplesByHyperplane
// Both the matrix and the node are held by const reference and must outlive
// this functor.
258 DataMatrix
const & data_;
259 Node<i_HyperplaneNode>
const & node_;
263 SortSamplesByHyperplane(DataMatrix
const & data,
264 Node<i_HyperplaneNode>
const & node)
// Per-row value: starts at -intercept and adds weight[ii] * feature for each
// of the node's columns.
274 double result_l = -1 * node_.intercept();
275 for(
int ii = 0; ii < node_.columns_size(); ++ii)
277 result_l +=
rowVector(data_, l)[node_.columns_begin()[ii]]
278 * node_.weights()[ii];
// Comparison: row l sorts before row r when its operator[] value is smaller.
285 return (*
this)[l] < (*this)[r];
// Functor that tallies labels into a caller-supplied count array.
299 template <
class DataSource,
class CountArray>
// Labels are read through a const reference.
302 DataSource
const & labels_;
// Counts are a non-const reference: increments land in the caller's array,
// including when the functor object itself is copied.
303 CountArray & counts_;
// Per-sample action: the label of sample l indexes the slot to increment.
322 counts_[labels_[l]] +=1;
337 double operator[](
size_t)
const
// Weighted call operator (its first parameter line is not visible in this
// excerpt). Forwards hist, weights and total unchanged to impurity();
// total defaults to 1.0.
357 template<
class Array,
class Array2>
359 Array2
const & weights,
360 double total = 1.0)
const
362 return impurity(hist, weights, total);
367 template<
class Array>
368 double operator()(Array
const & hist,
double total = 1.0)
const
// Unweighted overload: delegates to the weighted impurity() with a
// default-constructed detail::ConstArr<1> as the weights argument.
375 template<
class Array>
376 static double impurity(Array
const & hist,
double total)
378 return impurity(hist, detail::ConstArr<1>(), total);
// Weighted entropy-style impurity of a class histogram. Only part of the body
// is visible in this excerpt; the accumulation statements are missing.
383 template<
class Array,
class Array2>
385 Array2
const & weights,
// The bin count is narrowed to int.
389 int class_count = hist.size();
390 double entropy = 0.0;
// Fractions of `total` held by the first two bins.
// NOTE(review): presumably a two-class special case; the selecting condition
// is not visible here.
393 double p0 = (hist[0]/total);
394 double p1 = (hist[1]/total);
// General path: visit every bin, reading its weight and its fraction of total.
399 for(
int ii = 0; ii < class_count; ++ii)
401 double w = weights[ii];
402 double pii = hist[ii]/total;
// The accumulated value is scaled by `total` before it is used further.
406 entropy = total * entropy;
// Weighted call operator (its first parameter line is not visible in this
// excerpt). Forwards hist, weights and total unchanged to impurity();
// total defaults to 1.0.
419 template<
class Array,
class Array2>
421 Array2
const & weights,
422 double total = 1.0)
const
424 return impurity(hist, weights, total);
429 template<
class Array>
430 double operator()(Array
const & hist,
double total = 1.0)
const
// Unweighted overload: delegates to the weighted impurity() with a
// default-constructed detail::ConstArr<1> as the weights argument.
437 template<
class Array>
438 static double impurity(Array
const & hist,
double total)
440 return impurity(hist, detail::ConstArr<1>(), total);
// Weighted gini-style impurity of a class histogram. Only part of the body is
// visible in this excerpt.
445 template<
class Array,
class Array2>
447 Array2
const & weights,
// The bin count is narrowed to int.
451 int class_count = hist.size();
// Closed form using only bins 0 and 1: w0*w1 * hist0*hist1 / total.
// NOTE(review): presumably the two-class special case; the selecting condition
// is not visible here.
455 double w = weights[0] * weights[1];
456 gini = w * (hist[0] * hist[1] / total);
// General path: each bin contributes w * hist * (1 - w * hist / total).
460 for(
int ii = 0; ii < class_count; ++ii)
462 double w = weights[ii];
463 gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
// Loss accumulator holding a per-class histogram (counts_) and its running
// total. The Impurity policy defaults to GiniCriterion.
471 template <
class DataSource,
class Impurity= GiniCriterion>
// Labels are held by const reference and must outlive this object.
475 DataSource
const & labels_;
// Current per-class histogram.
476 ArrayVector<double> counts_;
// Private const copy of the class weights, taken at construction.
477 ArrayVector<double>
const class_weights_;
478 double total_counts_;
// Constructor: the histogram gets ext_.class_count_ zero-initialised bins and
// the weights are copied from ext_.class_weights_.
484 ImpurityLoss(DataSource
const & labels,
485 ProblemSpec<T>
const & ext_)
487 counts_(ext_.class_count_, 0.0),
488 class_weights_(ext_.class_weights_),
// Adds a whole histogram to counts_ element-wise, recomputes the total, and
// returns the impurity of the updated histogram.
498 template<
class Counts>
499 double increment_histogram(Counts
const & counts)
// counts_[i] = counts[i] + counts_[i] for every element of `counts`.
// counts_ must be at least as long as `counts`.
501 std::transform(counts.begin(), counts.end(),
502 counts_.begin(), counts_.begin(),
503 std::plus<double>());
// The total is recomputed from counts_ (the rest of this call is not visible).
504 total_counts_ = std::accumulate( counts_.begin(),
507 return impurity_(counts_, class_weights_, total_counts_);
510 template<
class Counts>
511 double decrement_histogram(Counts
const & counts)
513 std::transform(counts.begin(), counts.end(),
514 counts_.begin(), counts_.begin(),
515 std::minus<double>());
516 total_counts_ = std::accumulate( counts_.begin(),
519 return impurity_(counts_, class_weights_, total_counts_);
// Adds the samples in [begin, end) to the histogram, one count per sample, and
// returns the impurity of the updated histogram.
523 double increment(Iter begin, Iter end)
525 for(Iter iter = begin; iter != end; ++iter)
// *iter is a sample index; column 0 of labels_ supplies the class bin.
527 counts_[labels_(*iter, 0)] +=1.0;
// NOTE(review): the update of total_counts_ is not visible in this excerpt -
// confirm it is kept in step with counts_ before this call.
530 return impurity_(counts_, class_weights_, total_counts_);
// Removes the samples in [begin, end) from the histogram, one count per
// sample, and returns the impurity of the updated histogram.
534 double decrement(Iter
const & begin, Iter
const & end)
536 for(Iter iter = begin; iter != end; ++iter)
// *iter is a sample index; column 0 of labels_ supplies the class bin.
538 counts_[labels_(*iter,0)] -=1.0;
// NOTE(review): the update of total_counts_ is not visible in this excerpt -
// confirm it is kept in step with counts_ before this call.
541 return impurity_(counts_, class_weights_, total_counts_);
// Resets the histogram from a precomputed response array and returns its
// impurity. begin/end are not used by the visible statements; resp is taken
// by value.
544 template<
class Iter,
class Resp_t>
545 double init (Iter begin, Iter end, Resp_t resp)
// Overwrites the leading entries of counts_ with resp; counts_ must be at
// least as long as resp.
548 std::copy(resp.begin(), resp.end(), counts_.begin());
// The 0.0 seed makes the sum accumulate as double.
549 total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
550 return impurity_(counts_,class_weights_, total_counts_);
553 ArrayVector<double>
const & response()
// Loss accumulator for regression: keeps a per-response-dimension mean_ and
// variance_ that are updated one sample at a time. Only fragments of the
// class are visible in this excerpt.
559 template <
class DataSource>
560 class RegressionForestCounter
562 typedef MultiArrayShape<2>::type Shp;
// Labels are held by const reference and must outlive this object.
563 DataSource
const & labels_;
564 ArrayVector <double> mean_;
565 ArrayVector <double> variance_;
// Scratch buffer for the per-dimension difference (label - current mean).
566 ArrayVector <double> tmp_;
// Constructor: all three arrays get ext_.response_size entries; mean_ and
// variance_ start at 0.0.
570 RegressionForestCounter(DataSource
const & labels,
571 ProblemSpec<T>
const & ext_)
574 mean_(ext_.response_size, 0.0),
575 variance_(ext_.response_size, 0.0),
576 tmp_(ext_.response_size),
// Adds the samples in [begin, end) one by one.
583 double increment (Iter begin, Iter end)
585 for(Iter iter = begin; iter != end; ++iter)
// Step 1: difference between this sample's label and the current mean.
// The int index is compared against mean_.size() (signed/unsigned mix).
588 for(
int ii = 0; ii < mean_.size(); ++ii)
589 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
// f is the reciprocal of count_; the update of count_ and the definition of
// f1 are on lines not visible in this excerpt.
590 double f = 1.0 / count_,
// Step 2: shift the mean by f * difference.
592 for(
int ii = 0; ii < mean_.size(); ++ii)
593 mean_[ii] += f*tmp_[ii];
// Step 3: grow the variance accumulator by f1 * difference^2.
594 for(
int ii = 0; ii < mean_.size(); ++ii)
595 variance_[ii] += f1*
sq(tmp_[ii]);
// Return value is built from a sum over variance_ (rest of the call not visible).
597 return std::accumulate(variance_.begin(),
// Removes the samples in [begin, end) one by one; mirrors increment() with
// subtraction in steps 2 and 3.
605 double decrement (Iter begin, Iter end)
607 for(Iter iter = begin; iter != end; ++iter)
610 for(
int ii = 0; ii < mean_.size(); ++ii)
611 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
// NOTE(review): the update of count_ is not visible here; if count_ can reach
// 0 before this line, 1.0 / count_ divides by zero - confirm.
612 double f = 1.0 / count_,
614 for(
int ii = 0; ii < mean_.size(); ++ii)
615 mean_[ii] -= f*tmp_[ii];
616 for(
int ii = 0; ii < mean_.size(); ++ii)
617 variance_[ii] -= f1*
sq(tmp_[ii]);
619 return std::accumulate(variance_.begin(),
// init(): returns increment(begin, end); the visible statement does not use resp.
// NOTE(review): any reset of mean_/variance_/count_ would be on lines not
// visible here - confirm.
626 template<
class Iter,
class Resp_t>
627 double init (Iter begin, Iter end, Resp_t resp)
630 return increment(begin, end);
// Accessor returning a const reference to an ArrayVector<double> (body not visible).
634 ArrayVector<double>
const & response()
647 template<
class Tag,
class Datatyp>
// GiniCriterion tag: the loss accumulator is ImpurityLoss with the gini policy.
653 template<
class Datatype>
654 struct LossTraits<GiniCriterion, Datatype>
656 typedef ImpurityLoss<Datatype, GiniCriterion> type;
// EntropyCriterion tag: the loss accumulator is ImpurityLoss with the entropy policy.
659 template<
class Datatype>
660 struct LossTraits<EntropyCriterion, Datatype>
662 typedef ImpurityLoss<Datatype, EntropyCriterion> type;
// LSQLoss tag: the loss accumulator is RegressionForestCounter.
665 template<
class Datatype>
666 struct LossTraits<LSQLoss, Datatype>
668 typedef RegressionForestCounter<Datatype> type;
// Per-column split search, parameterised by the loss tag that LossTraits maps
// to a concrete loss accumulator. Only member and setup fragments are visible
// in this excerpt.
673 template<
class LineSearchLossTag>
// Results of the last search: split position within the index range and the
// matching threshold.
680 ptrdiff_t min_index_;
681 double min_threshold_;
// Constructor fragment: copies the class weights and sizes both child
// histograms to ext.class_count_ bins.
690 class_weights_(ext.class_weights_),
693 bestCurrentCounts[0].resize(ext.class_count_);
694 bestCurrentCounts[1].resize(ext.class_count_);
// Re-configuration fragment: same setup, done by assignment.
699 class_weights_ = ext.class_weights_;
701 bestCurrentCounts[0].resize(ext.class_count_);
702 bestCurrentCounts[1].resize(ext.class_count_);
// Searches one feature column for the split with the smallest combined loss.
// Results go into min_gini_, min_index_, min_threshold_ and bestCurrentCounts.
// Several signature and loop-control lines are not visible in this excerpt.
730 template<
class DataSourceF_t,
736 DataSource_t
const & labels,
739 Array
const & region_response)
// Sort the sample indices in [begin, end) (comparator argument not visible).
741 std::sort(begin, end,
// `right` is initialised below with the whole region; `left` is only
// constructed here.
744 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
745 LineSearchLoss left(labels, ext_param_);
746 LineSearchLoss right(labels, ext_param_);
// The loss of the unsplit region is the value to beat.
750 min_gini_ = right.init(begin, end, region_response);
// NOTE(review): *begin is a sample index (it is used as a row index into
// `column` below), not a feature value - presumably a placeholder until a
// better split is found; confirm.
751 min_threshold_ = *begin;
// Predicate: two neighbouring samples differ in feature g.
753 DimensionNotEqual<DataSourceF_t> comp(column, g);
// First position whose successor has a different feature value.
756 I_Iter next = std::adjacent_find(iter, end, comp);
// Move samples [iter, next] from right to left and sum both losses.
760 double loss = right.decrement(iter, next + 1)
761 + left.increment(iter , next + 1);
762 #ifdef CLASSIFIER_TEST
765 if(loss < min_gini_ )
// Record the child histograms of the best split so far.
768 bestCurrentCounts[0] = left.response();
769 bestCurrentCounts[1] = right.response();
770 #ifdef CLASSIFIER_TEST
771 min_gini_ = loss < min_gini_? loss : min_gini_;
// The split position is one past `next`; the threshold is the midpoint of the
// two differing neighbouring feature values.
775 min_index_ = next - begin +1 ;
776 min_threshold_ = (double(column(*next,g)) + double(column(*(next +1), g)))/2.0;
// Look for the next change point (the update of `iter` is not visible here).
779 next = std::adjacent_find(iter, end, comp);
// Builds a fresh loss accumulator for `labels` and initialises it with
// [begin, end) and region_response.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably the value of init() is what gets returned - confirm.
783 template<
class DataSource_t,
class Iter,
class Array>
784 double loss_of_region(DataSource_t
const & labels,
787 Array
const & region_response)
const
790 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
791 LineSearchLoss region_loss(labels, ext_param_);
793 region_loss.init(begin, end, region_response);
// Split functor parameterised by the per-column search functor; Tag defaults
// to ClassificationTag.
802 template<
class ColumnDecisionFunctor,
class Tag = ClassificationTag>
// Instance of the per-column search functor.
811 ColumnDecisionFunctor bgfunc;
// Loss recorded for the entry at bestSplitIndex. The index is not
// range-checked here.
820 double minGini()
const
822 return min_gini_[bestSplitIndex];
// Feature column stored in splitColumns at bestSplitIndex. The index is not
// range-checked here.
824 int bestSplitColumn()
const
826 return splitColumns[bestSplitIndex];
// Threshold recorded for the entry at bestSplitIndex. The index is not
// range-checked here.
828 double bestSplitThreshold()
const
830 return min_thresholds_[bestSplitIndex];
837 bgfunc.set_external_parameters( SB::ext_param_);
838 int featureCount_ = SB::ext_param_.column_count_;
839 splitColumns.resize(featureCount_);
840 for(
int k=0; k<featureCount_; ++k)
842 min_gini_.resize(featureCount_);
843 min_indices_.resize(featureCount_);
844 min_thresholds_.resize(featureCount_);
// Picks the best threshold split for `region`, partitions its samples into
// two child regions, and returns i_ThresholdNode. The signature and several
// control-flow lines are not visible in this excerpt.
848 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
856 typedef typename Region::IndexIterator IndexIterator;
// An empty region is only reported on stderr; the visible lines show no early exit.
857 if(region.size() == 0)
859 std::cerr <<
"SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
860 "continuing learning process....";
// Recount the class histogram when its sum does not match the region size.
// NOTE(review): the accumulate seed is the int literal 0, so the sum is
// carried in int even if classCounts() holds doubles - confirm this is intended.
864 if(std::accumulate(region.classCounts().begin(),
865 region.classCounts().end(), 0) != region.size())
867 RandomForestClassCounter< MultiArrayView<2,T2, C2>,
868 ArrayVector<double> >
869 counter(labels, region.classCounts());
870 std::for_each( region.begin(), region.end(), counter);
871 region.classCountsIsValid =
true;
// Loss of the unsplit region (middle arguments not visible).
875 region_gini_ = bgfunc.loss_of_region(labels,
878 region.classCounts());
// Regions whose loss is within precision_ take a separate path (body not visible).
879 if(region_gini_ <= SB::ext_param_.precision_)
// Randomise the first actual_mtry_ entries of splitColumns by swapping entry
// ii with an entry at offset randint(shape(1) - ii) from it.
// NOTE(review): presumably randint(n) returns a value in [0, n) - confirm.
883 for(
int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
884 std::swap(splitColumns[ii],
885 splitColumns[ii+ randint(features.
shape(1) - ii)]);
// Evaluate candidates; the bound starts at the full feature count and is
// lowered to actual_mtry_ further down.
889 double current_min_gini = region_gini_;
890 int num2try = features.
shape(1);
891 for(
int k=0; k<num2try; ++k)
897 region.begin(), region.end(),
898 region.classCounts());
// Store the per-candidate result of the column search.
899 min_gini_[k] = bgfunc.min_gini_;
900 min_indices_[k] = bgfunc.min_index_;
901 min_thresholds_[k] = bgfunc.min_threshold_;
902 #ifdef CLASSIFIER_TEST
903 if( bgfunc.min_gini_ < current_min_gini
906 if(bgfunc.min_gini_ < current_min_gini)
// New best candidate: adopt its loss and child histograms.
909 current_min_gini = bgfunc.min_gini_;
910 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
911 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
912 childRegions[0].classCountsIsValid =
true;
913 childRegions[1].classCountsIsValid =
true;
// Shrinks the loop bound from inside the loop body.
916 num2try = SB::ext_param_.actual_mtry_;
// Write the chosen split into a threshold node.
925 Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
927 node.threshold() = min_thresholds_[bestSplitIndex];
928 node.column() = splitColumns[bestSplitIndex];
// Partition the region's samples: those below the threshold in the chosen
// column come first.
931 SortSamplesByDimensions<MultiArrayView<2, T, C> >
932 sorter(features, node.column(), node.threshold());
933 IndexIterator bestSplit =
934 std::partition(region.begin(), region.end(), sorter);
// Child 0 gets [begin, bestSplit), child 1 gets [bestSplit, end).
// NOTE(review): both children append the same constant pair (1, 1.0) to their
// rule, independent of the chosen column and threshold - looks like a
// placeholder; confirm.
936 childRegions[0].setRange( region.begin() , bestSplit );
937 childRegions[0].rule = region.rule;
938 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
939 childRegions[1].setRange( bestSplit , region.end() );
940 childRegions[1].rule = region.rule;
941 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
943 return i_ThresholdNode;
// Ready-made threshold split types: gini and entropy use the default tag of
// ThresholdSplit, the least-squares variant passes RegressionTag.
947 typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;
948 typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> > EntropySplit;
949 typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
988 ptrdiff_t min_index_;
989 double min_threshold_;
998 class_weights_(ext.class_weights_),
1001 bestCurrentCounts[0].resize(ext.class_count_);
1002 bestCurrentCounts[1].resize(ext.class_count_);
1008 class_weights_ = ext.class_weights_;
1010 bestCurrentCounts[0].resize(ext.class_count_);
1011 bestCurrentCounts[1].resize(ext.class_count_);
// Splits one feature column around the value found at the middle of the
// sorted index range. Results go into min_gini_, min_index_, min_threshold_
// and bestCurrentCounts. Several lines are not visible in this excerpt.
1014 template<
class DataSourceF_t,
1018 void operator()(DataSourceF_t
const & column,
1019 DataSource_t
const & labels,
1022 Array
const & region_response)
// Sort the sample indices in [begin, end) (comparator argument not visible).
1024 std::sort(begin, end,
1027 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1028 LineSearchLoss left(labels, ext_param_);
1029 LineSearchLoss right(labels, ext_param_);
// `right` starts with the whole region; the value init() returns is discarded.
1030 right.init(begin, end, region_response);
1032 min_gini_ = NumericTraits<double>::max();
// Pivot position: floor of half the range length.
1033 min_index_ =
floor(
double(end - begin)/2.0);
// Pivot threshold: the feature value of the sample at that position.
// NOTE(review): for an empty range this dereferences `begin` - confirm
// callers never pass one.
1034 min_threshold_ = column[*(begin + min_index_)];
// Partition around the pivot threshold using column 0 of `column`.
1036 sorter(column, 0, min_threshold_);
1037 I_Iter part = std::partition(begin, end, sorter);
1038 DimensionNotEqual<DataSourceF_t> comp(column, 0);
// Advance to just past the next change of feature value.
// NOTE(review): if adjacent_find returns `end`, the +1 steps past the range;
// any guard would be on lines not visible here - confirm.
1041 part= std::adjacent_find(part, end, comp)+1;
1050 min_threshold_ = column[*part];
// Move [begin, part) from right to left; the summed loss is the result.
1052 min_gini_ = right.decrement(begin, part)
1053 + left.increment(begin , part);
1055 bestCurrentCounts[0] = left.response();
1056 bestCurrentCounts[1] = right.response();
1058 min_index_ = part - begin;
// Builds a fresh loss accumulator for `labels` and initialises it with
// [begin, end) and region_response.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably the value of init() is what gets returned - confirm.
1061 template<
class DataSource_t,
class Iter,
class Array>
1062 double loss_of_region(DataSource_t
const & labels,
1065 Array
const & region_response)
const
1068 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1069 LineSearchLoss region_loss(labels, ext_param_);
1071 region_loss.init(begin, end, region_response);
1089 ptrdiff_t min_index_;
1090 double min_threshold_;
1101 class_weights_(ext.class_weights_),
1105 bestCurrentCounts[0].resize(ext.class_count_);
1106 bestCurrentCounts[1].resize(ext.class_count_);
1112 class_weights_(ext.class_weights_),
1116 bestCurrentCounts[0].resize(ext.class_count_);
1117 bestCurrentCounts[1].resize(ext.class_count_);
1123 class_weights_ = ext.class_weights_;
1125 bestCurrentCounts[0].resize(ext.class_count_);
1126 bestCurrentCounts[1].resize(ext.class_count_);
// Splits one feature column around the value of a randomly chosen sample.
// Results go into min_gini_, min_index_, min_threshold_ and bestCurrentCounts.
// Several lines are not visible in this excerpt.
1129 template<
class DataSourceF_t,
1133 void operator()(DataSourceF_t
const & column,
1134 DataSource_t
const & labels,
1137 Array
const & region_response)
// Sort the sample indices in [begin, end) (comparator argument not visible).
1139 std::sort(begin, end,
1142 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1143 LineSearchLoss left(labels, ext_param_);
1144 LineSearchLoss right(labels, ext_param_);
// `right` starts with the whole region; the value init() returns is discarded.
1145 right.init(begin, end, region_response);
1148 min_gini_ = NumericTraits<double>::max();
// NOTE(review): min_index_ is declared ptrdiff_t in this file, yet here it is
// assigned `begin + <random offset>` and then added to `begin` again on the
// next line. This looks like it should store only
// random.uniformInt(end - begin) - confirm.
1150 min_index_ = begin + random.
uniformInt(end -begin);
1151 min_threshold_ = column[*(begin + min_index_)];
// Partition around the pivot threshold using column 0 of `column`.
1153 sorter(column, 0, min_threshold_);
1154 I_Iter part = std::partition(begin, end, sorter);
1155 DimensionNotEqual<DataSourceF_t> comp(column, 0);
// Advance to just past the next change of feature value.
// NOTE(review): if adjacent_find returns `end`, the +1 steps past the range;
// any guard would be on lines not visible here - confirm.
1158 part= std::adjacent_find(part, end, comp)+1;
1167 min_threshold_ = column[*part];
// Move [begin, part) from right to left; the summed loss is the result.
1169 min_gini_ = right.decrement(begin, part)
1170 + left.increment(begin , part);
1172 bestCurrentCounts[0] = left.response();
1173 bestCurrentCounts[1] = right.response();
1175 min_index_ = part - begin;
// Builds a fresh loss accumulator for `labels` and initialises it with
// [begin, end) and region_response.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably the value of init() is what gets returned - confirm.
1178 template<
class DataSource_t,
class Iter,
class Array>
1179 double loss_of_region(DataSource_t
const & labels,
1182 Array
const & region_response)
const
1185 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1186 LineSearchLoss region_loss(labels, ext_param_);
1188 region_loss.init(begin, end, region_response);
1199 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX