35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
39 # include "vigra/hdf5impex.hxx"
41 #include <vigra/windows.h>
44 #include <vigra/timing.hxx>
139 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
145 Feature_t & features,
158 template<
class RF,
class PR,
class SM,
class ST>
168 template<
class RF,
class PR>
178 template<
class RF,
class PR>
194 template<
class TR,
class IntT,
class TopT,
class Feat>
202 template<
class TR,
class IntT,
class TopT,
class Feat>
241 template <
class Visitor,
class Next = StopVisiting>
251 next_(next), visitor_(visitor)
256 next_(stop_), visitor_(visitor)
// Chain-forwarding hook for the "after split" event: the wrapped visitor_
// receives the event only when visitor_.is_active() is true, and next_ is
// then called with the same leading arguments.
// NOTE(review): part of the parameter list and the tails of both calls are
// not visible in this excerpt -- confirm the full argument list in the file.
259 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
260 void visit_after_split( Tree & tree,
265 Feature_t & features,
268 if(visitor_.is_active())
269 visitor_.visit_after_split(tree, split,
270 parent, leftChild, rightChild,
272 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
// Chain-forwarding hook for the "after tree" event.
// rf, pr, sm, st and index are passed through unchanged to the visitors.
// As written here the `if` guards only the visitor_ call; the call on next_
// follows it, so the event keeps travelling down the chain even when this
// node's visitor is inactive.
276 template<
class RF,
class PR,
class SM,
class ST>
277 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
279 if(visitor_.is_active())
280 visitor_.visit_after_tree(rf, pr, sm, st, index);
281 next_.visit_after_tree(rf, pr, sm, st, index);
// Chain-forwarding hook for the "at beginning" event: calls the wrapped
// visitor_ when it is active, then hands rf and pr on to next_.
284 template<
class RF,
class PR>
285 void visit_at_beginning(RF & rf, PR & pr)
287 if(visitor_.is_active())
288 visitor_.visit_at_beginning(rf, pr);
289 next_.visit_at_beginning(rf, pr);
// Chain-forwarding hook for the "at end" event: calls the wrapped visitor_
// when it is active, then hands rf and pr on to next_.
291 template<
class RF,
class PR>
292 void visit_at_end(RF & rf, PR & pr)
294 if(visitor_.is_active())
295 visitor_.visit_at_end(rf, pr);
296 next_.visit_at_end(rf, pr);
// Chain-forwarding hook for the "external node" event: calls the wrapped
// visitor_ when it is active, then passes tr, index, node_t and features
// on to next_ unchanged.
299 template<
class TR,
class IntT,
class TopT,
class Feat>
300 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
302 if(visitor_.is_active())
303 visitor_.visit_external_node(tr, index, node_t,features);
304 next_.visit_external_node(tr, index, node_t,features);
// Chain-forwarding hook for the "internal node" event: calls the wrapped
// visitor_ when it is active, then passes tr, index, node_t and features
// on to next_ unchanged.
306 template<
class TR,
class IntT,
class TopT,
class Feat>
307 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
309 if(visitor_.is_active())
310 visitor_.visit_internal_node(tr, index, node_t,features);
311 next_.visit_internal_node(tr, index, node_t,features);
// Result lookup for the chain (the enclosing signature is not visible in
// this excerpt): the wrapped visitor's return_val() is used only when the
// visitor is both active and reports has_value(); otherwise the request
// falls through to next_.return_val().
316 if(visitor_.is_active() && visitor_.has_value())
317 return visitor_.return_val();
318 return next_.return_val();
342 template<
class A,
class B>
343 detail::VisitorNode<A, detail::VisitorNode<B> >
356 template<
class A,
class B,
class C>
357 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
372 template<
class A,
class B,
class C,
class D>
373 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
374 detail::VisitorNode<D> > > >
391 template<
class A,
class B,
class C,
class D,
class E>
392 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
393 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
413 template<
class A,
class B,
class C,
class D,
class E,
415 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
416 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
438 template<
class A,
class B,
class C,
class D,
class E,
440 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
441 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
442 detail::VisitorNode<G> > > > > > >
444 D & d, E & e, F & f, G & g)
466 template<
class A,
class B,
class C,
class D,
class E,
467 class F,
class G,
class H>
468 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
469 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
470 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
497 template<
class A,
class B,
class C,
class D,
class E,
498 class F,
class G,
class H,
class I>
499 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
500 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
501 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
529 template<
class A,
class B,
class C,
class D,
class E,
530 class F,
class G,
class H,
class I,
class J>
531 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
532 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
533 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
534 detail::VisitorNode<J> > > > > > > > > >
575 bool adjust_thresholds;
// Per-split bookkeeping record used by the online-learning code below.
// The two totals visible here are initialised from leftChild.size_ and
// rightChild.size_ when a split is recorded.
// NOTE(review): further members (leftCounts, rightCounts, gap_left and
// gap_right are accessed elsewhere in this file) are not visible in this
// excerpt.
583 struct MarginalDistribution
586 Int32 leftTotalCounts;
588 Int32 rightTotalCounts;
// Online-learning bookkeeping kept for a single tree.
595 struct TreeOnlineInformation
// One MarginalDistribution is appended per recorded threshold split.
597 std::vector<MarginalDistribution> mag_distributions;
// One list of sample indices is appended per recorded index-list entry.
598 std::vector<IndexList> index_lists;
// Node address (tree.topology_.size() at the time the split is recorded)
// -> slot in mag_distributions.
600 std::map<int,int> interior_to_index;
// Node address -> slot in index_lists.
602 std::map<int,int> exterior_to_index;
606 std::vector<TreeOnlineInformation> trees_online_information;
610 template<
class RF,
class PR>
614 trees_online_information.resize(rf.options_.tree_count_);
621 trees_online_information[tree_id].mag_distributions.clear();
622 trees_online_information[tree_id].index_lists.clear();
623 trees_online_information[tree_id].interior_to_index.clear();
624 trees_online_information[tree_id].exterior_to_index.clear();
629 template<
class RF,
class PR,
class SM,
class ST>
// Online-learning hook run after a split: records, per tree, either the
// marginal distribution of a threshold split or the sample index list of
// the region, keyed by the node address `addr`.
// NOTE(review): the signature, the braces and any `else` keywords are not
// visible in this excerpt, so the exact nesting of the statements below
// must be confirmed against the full file.
635 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
641 Feature_t & features,
// addr is the current length of the topology array; presumably the address
// at which the node for this split will be stored -- TODO confirm.
645 int addr=tree.topology_.size();
646 if(split.createNode().typeID() == i_ThresholdNode)
648 if(adjust_thresholds)
// Register a new MarginalDistribution slot for this node address.
651 linear_index=trees_online_information[tree_id].mag_distributions.size();
652 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
653 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
// Snapshot the per-class and total sample counts of both children.
655 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
656 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
658 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
659 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
// gap_left  = largest split-column value among the left child's samples,
// gap_right = smallest split-column value among the right child's samples.
// NOTE(review): leftChild[0] / rightChild[0] are read unconditionally, so
// both children are assumed to be non-empty -- confirm.
661 double gap_left,gap_right;
663 gap_left=features(leftChild[0],split.bestSplitColumn());
664 for(i=1;i<leftChild.size();++i)
665 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
666 gap_left=features(leftChild[i],split.bestSplitColumn());
667 gap_right=features(rightChild[0],split.bestSplitColumn());
668 for(i=1;i<rightChild.size();++i)
669 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
670 gap_right=features(rightChild[i],split.bestSplitColumn());
671 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
672 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
// Register a new index list for this node address and fill it with a copy
// of the parent's sample range [parent.begin_, parent.end_).
// NOTE(review): presumably the branch taken for non-threshold (leaf) nodes;
// the branching keyword is not visible here.
678 linear_index=trees_online_information[tree_id].index_lists.size();
679 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
681 trees_online_information[tree_id].index_lists.push_back(
IndexList());
683 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
684 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
// Appends sample `index` to the index list registered for node address
// `node` in tree number `tree`.
// NOTE(review): exterior_to_index is a std::map<int,int>, so operator[]
// default-inserts 0 for an unregistered `node`; the sample would then be
// appended to index_lists[0] (out of range if index_lists is empty).
// Confirm that callers only pass registered node addresses.
687 void add_to_index_list(
int tree,
int node,
int index)
691 TreeOnlineInformation &ti=trees_online_information[tree];
692 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
// Re-keys an exterior-node entry: the index-list slot number stored under
// (src_tree, src_index) is copied to (dst_tree, dst_index) and the source
// key is then erased. Only the map entry moves; index_lists is untouched.
// NOTE(review): if source and destination are the same tree and address,
// the entry is assigned to itself and then erased, i.e. lost. If the trees
// differ, the copied slot number still counts positions in the source
// tree's index_lists -- confirm both cases against the callers.
694 void move_exterior_node(
int src_tree,
int src_index,
int dst_tree,
int dst_index)
698 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
699 trees_online_information[src_tree].exterior_to_index.erase(src_index);
// Online-learning hook for internal nodes: when adjust_thresholds is set,
// uses the incoming sample to re-centre the node's threshold inside the
// recorded gap and updates the per-side counts.
// NOTE(review): the signature, braces and several statements are not
// visible in this excerpt; the nesting below must be confirmed.
706 template<
class TR,
class IntT,
class TopT,
class Feat>
710 if(adjust_thresholds)
712 vigra_assert(node_t==i_ThresholdNode,
"We can only visit threshold nodes");
// Only row 0 of `features` is read, at the column this node splits on.
714 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
715 TreeOnlineInformation &ti=trees_online_information[tree_id];
// NOTE(review): interior_to_index is a std::map, so an unregistered
// `index` default-inserts 0 and selects mag_distributions[0].
716 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
// The sample lies strictly inside the empty interval between the two
// children's recorded extreme values.
717 if(value>m.gap_left && value<m.gap_right)
// Compares the relative frequency of current_label on the left side with
// that on the right side (the statements that act on the result are not
// visible here).
720 if(m.leftCounts[current_label]/
double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
// The threshold is reset to the midpoint of the gap.
730 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
// Account for the sample on the side of the threshold it falls on.
733 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
735 ++m.rightTotalCounts;
736 ++m.rightCounts[current_label];
// NOTE(review): rightCounts is incremented a second time here. If this
// statement belongs to the other branch of the comparison above (the lines
// in between are not visible), it looks like a copy-paste slip and should
// probably update leftCounts -- confirm against the full file.
741 ++m.rightCounts[current_label];
// Out-of-bag bookkeeping run after a tree has been learned (the signature
// is not visible in this excerpt; the name is presumed from the template
// parameters RF, PR, SM, ST).
789 template<
class RF,
class PR,
class SM,
class ST>
// Size the per-sample counters to the number of rows whenever their size
// differs from it.
// NOTE(review): resize(n, 0) on a std::vector-like container only
// zero-fills newly added elements -- confirm that existing entries do not
// need resetting.
793 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
795 oobCount.resize(rf.ext_param_.row_count_, 0);
796 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
// Walk over every training row; the visible comparison checks a predicted
// label for row l against the stored response(l, 0). The object on which
// predictLabel is called and the guarded statements are not visible here.
799 for(
int l = 0; l < rf.ext_param_.row_count_; ++l)
806 .predictLabel(
rowVector(pr.features(), l))
807 != pr.response()(l,0))
// End-of-training aggregation (signature not visible; name presumed):
// adds, for each row l, the fraction oobErrorCount[l] / oobCount[l] to
// oobError.
// NOTE(review): the quotient is computed in double, so a row with
// oobCount[l] == 0 yields inf or NaN unless a guard sits on a line that is
// not visible in this excerpt -- confirm.
818 template<
class RF,
class PR>
822 for(
int l=0; l < (int)rf.ext_param_.row_count_; ++l)
826 oobError += double(oobErrorCount[l]) / oobCount[l];
// Writes this visitor's result to an HDF5 file.
//   filen - name of the HDF5 file
//   pathn - group path inside the file, used as a prefix of the dataset name
863 void save(std::string filen, std::string pathn)
// Tests whether pathn already ends in '/' (the statement that reacts to
// this is not visible here).
// NOTE(review): *(pathn.end()-1) is undefined behaviour when pathn is
// empty; callers must pass a non-empty path.
865 if(*(pathn.end()-1) !=
'/')
867 const char* filename = filen.c_str();
// `temp` is stored as dataset <pathn>breiman_error; its construction is
// not visible in this excerpt.
870 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
// Prepares the out-of-bag buffers before training starts.
// Note that this block reads the forest's parameters both through the
// accessor rf.ext_param() and through the member rf.ext_param_.
875 template<
class RF,
class PR>
876 void visit_at_beginning(RF & rf, PR & pr)
878 class_count = rf.class_count();
// tmp_prob: one row of class_count entries; prob_oob: one such row per
// training sample; both are initialised with 0.
879 tmp_prob.
reshape(Shp(1, class_count), 0);
880 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
881 is_weighted = rf.options().predict_weighted_;
882 indices.resize(rf.ext_param().row_count_);
// (Re)create the per-sample OOB counter only when its size does not match
// the number of rows.
883 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
885 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
// Loop over all rows; the loop body is not visible in this excerpt.
887 for(
int ii = 0; ii < rf.ext_param().row_count_; ++ii)
// Out-of-bag probability accumulation after a tree has been learned
// (signature not visible; name presumed from the template parameters).
// Two variants are visible: one works on a shuffled, capped subset of the
// OOB rows, the other walks over every row.
// NOTE(review): braces and several statements are missing from this
// excerpt; confirm the nesting against the full file.
893 template<
class RF,
class PR,
class SM,
class ST>
// Subset variant, chosen when the per-tree sample count is more than 10000
// rows below the total number of feature rows. The constants 10000 and
// 40000 are not explained in the visible code.
902 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
// NOTE(review): std::random_shuffle was deprecated in C++14 and removed in
// C++17; this line will not build with a conforming C++17 library.
906 std::random_shuffle(indices.
begin(), indices.
end());
// Collect rows that the sampler did not use for this tree, accepting a row
// only while fewer than 40000 rows of its class have been collected.
907 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
909 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
911 oob_indices.push_back(indices[ii]);
912 ++cts[pr.response()(indices[ii], 0)];
// For every collected row: count it, drop it down the tree to a leaf, copy
// the leaf's class probabilities into tmp_prob and add them to the row's
// accumulated OOB probabilities.
915 for(
unsigned int ll = 0; ll < oob_indices.
size(); ++ll)
918 ++oobCount[oob_indices[ll]];
923 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),oob_indices[ll]));
925 rf.tree(index).parameters_,
928 for(
int ii = 0; ii < class_count; ++ii)
930 tmp_prob[ii] = node.prob_begin()[ii];
// Scales the probabilities by the value stored immediately before the
// probability array (presumably the leaf weight, applied when is_weighted
// is set; the guarding condition is not visible here).
934 for(
int ii = 0; ii < class_count; ++ii)
935 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
937 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
938 int label =
argMax(tmp_prob);
// Full variant: the same per-row processing for every row that the sampler
// did not use for this tree.
943 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
946 if(!sm.is_used()[ll])
954 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
956 rf.tree(index).parameters_,
959 for(
int ii = 0; ii < class_count; ++ii)
961 tmp_prob[ii] = node.prob_begin()[ii];
965 for(
int ii = 0; ii < class_count; ++ii)
966 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
969 int label =
argMax(tmp_prob);
979 template<
class RF,
class PR>
983 int totalOobCount =0;
984 int breimanstyle = 0;
985 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
// Writes all collected out-of-bag statistics to an HDF5 file.
//   filen - name of the HDF5 file
//   pathn - group path inside the file, used as a prefix of every dataset
1059 void save(std::string filen, std::string pathn)
// Tests whether pathn already ends in '/' (the reaction is not visible).
// NOTE(review): *(pathn.end()-1) is undefined behaviour when pathn is
// empty; callers must pass a non-empty path.
1061 if(*(pathn.end()-1) !=
'/')
1063 const char* filename = filen.c_str();
// Array-valued results are stored under their member names.
1065 writeHDF5(filename, (pathn +
"oob_per_tree").c_str(),
oob_per_tree);
1066 writeHDF5(filename, (pathn +
"oobroc_per_tree").c_str(),
oobroc_per_tree);
1067 writeHDF5(filename, (pathn +
"breiman_per_tree").c_str(),
breiman_per_tree);
// The four datasets below are all written from `temp`; presumably temp is
// reassigned between the calls on lines not visible here -- confirm.
1069 writeHDF5(filename, (pathn +
"per_tree_error").c_str(), temp);
1071 writeHDF5(filename, (pathn +
"per_tree_error_std").c_str(), temp);
1073 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
1075 writeHDF5(filename, (pathn +
"ulli_error").c_str(), temp);
// Prepares the out-of-bag buffers before training starts.
1080 template<
class RF,
class PR>
1081 void visit_at_beginning(RF & rf, PR & pr)
1083 class_count = rf.class_count();
// Two-class problems get extra setup; the guarded statements are not
// visible in this excerpt.
1084 if(class_count == 2)
// tmp_prob: one row of class_count entries; prob_oob: one such row per
// training sample; both are initialised with 0.
1088 tmp_prob.
reshape(Shp(1, class_count), 0);
1089 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1090 is_weighted = rf.options().predict_weighted_;
// (Re)create the per-sample counters only when the size of oobCount does
// not match the number of rows.
1094 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
1096 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1097 oobErrorCount.
reshape(Shp(rf.ext_param_.row_count_,1), 0);
// Per-tree out-of-bag evaluation (signature not visible; name presumed
// from the template parameters): classifies the rows the sampler left out,
// counts errors, and fills a per-threshold table for the current tree.
// NOTE(review): braces and several statements are missing from this
// excerpt; confirm the nesting against the full file.
1101 template<
class RF,
class PR,
class SM,
class ST>
// For every row not used to train this tree: drop it to a leaf, copy the
// leaf's class probabilities and take the most probable class.
1107 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1110 if(!sm.is_used()[ll])
1118 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
1120 rf.tree(index).parameters_,
1123 for(
int ii = 0; ii < class_count; ++ii)
1125 tmp_prob[ii] = node.prob_begin()[ii];
// Scales by the value stored immediately before the probability array
// (presumably only when is_weighted is set; the condition is not visible).
1129 for(
int ii = 0; ii < class_count; ++ii)
1130 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1133 int label =
argMax(tmp_prob);
// A wrong prediction is recorded against the row.
1135 if(label != pr.response()(ll, 0))
1140 ++oobErrorCount[ll];
1144 int breimanstyle = 0;
1145 int totalOobCount = 0;
1146 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
// For each slice gg along the third axis of current_roc, the decision
// threshold is gg / shape(2): a row counts as predicted-positive when its
// accumulated probability for class 1 exceeds that threshold, and the cell
// (true label, prediction, gg) is incremented.
1165 for(
int gg = 0; gg < current_roc.
shape(2); ++gg)
1167 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1171 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.
shape(2)))?
1173 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
// Each slice is normalised by the number of counted OOB rows.
// NOTE(review): nothing visible here protects against totalOobCount == 0
// or total_oob == 0 in the two divisions below -- confirm.
1176 current_roc.
bindOuter(gg)/= totalOobCount;
1180 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1186 template<
class RF,
class PR>
1191 int totalOobCount =0;
1192 int breimanstyle = 0;
1193 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1244 int repetition_count_;
// Writes the visitor's data to the HDF5 file `filename`. The dataset name
// is built from "variable_importance_" followed by the caller's `prefix`
// (prefix is taken by value, so the caller's string is not modified).
// NOTE(review): the remaining arguments of writeHDF5 are not visible in
// this excerpt.
1248 void save(std::string filename, std::string prefix)
1250 prefix =
"variable_importance_" + prefix;
1251 writeHDF5(filename.c_str(),
1262 : repetition_count_(rep_cnt)
// After-split hook of the variable-importance visitor (only part of the
// signature is visible). For threshold splits it adds the decrease in Gini
// impurity achieved by the split, split.region_gini_ - split.minGini(), to
// an accumulator whose left-hand side is not visible in this excerpt.
1269 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1274 Region & rightChild,
1275 Feature_t & features,
1280 Int32 const class_count = tree.ext_param_.class_count_;
1281 Int32 const column_count = tree.ext_param_.column_count_;
1290 if(split.createNode().typeID() == i_ThresholdNode)
1292 Node<i_ThresholdNode> node(split.createNode());
1294 += split.region_gini_ - split.minGini();
// Permutation-based importance measured on the out-of-bag rows of the tree
// just learned (signature not visible; name presumed from the template
// parameters). For each feature column the OOB values are permuted
// repetition_count_ times, and the drop in correct predictions relative to
// the unpermuted data is accumulated.
// NOTE(review): braces and several statements are missing from this
// excerpt; confirm the nesting against the full file.
1304 template<
class RF,
class PR,
class SM,
class ST>
1308 Int32 column_count = rf.ext_param_.column_count_;
1309 Int32 class_count = rf.ext_param_.class_count_;
// Local feature object that the permutation below writes into; presumably
// an owning copy so that pr's data stays untouched (the type name suggests
// this) -- TODO confirm.
1319 typename PR::FeatureWithMemory_t features = pr.features();
// Rows the sampler did not use for this tree are the OOB rows.
1325 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1326 if(!sm.is_used()[ii])
1327 oob_indices.push_back(ii);
1330 std::vector<double> backup_column;
1333 #ifdef CLASSIFIER_TEST
// Both counters have class_count + 1 entries: one per class plus a grand
// total stored at index class_count.
1344 oob_right(Shp_t(1, class_count + 1));
1346 perm_oob_right (Shp_t(1, class_count + 1));
// Baseline: count correct predictions on the unpermuted OOB rows.
1350 for(iter = oob_indices.
begin();
1351 iter != oob_indices.
end();
1355 .predictLabel(
rowVector(features, *iter))
1356 == pr.response()(*iter, 0))
1359 ++oob_right[pr.response()(*iter,0)];
1361 ++oob_right[class_count];
// One pass per feature column.
1365 for(
int ii = 0; ii < column_count; ++ii)
1367 perm_oob_right.
init(0.0);
// Remember the column's OOB values so they can be restored afterwards.
1369 backup_column.clear();
1370 for(iter = oob_indices.
begin();
1371 iter != oob_indices.
end();
1374 backup_column.push_back(features(*iter,ii));
1378 for(
int rr = 0; rr < repetition_count_; ++rr)
// In-place random permutation of the column over the OOB rows: entry jj is
// swapped with the entry selected by randint(jj+1).
1381 int n = oob_indices.
size();
1382 for(
int jj = 1; jj < n; ++jj)
1383 std::swap(features(oob_indices[jj], ii),
1384 features(oob_indices[randint(jj+1)], ii));
// Count correct predictions on the permuted data.
1387 for(iter = oob_indices.
begin();
1388 iter != oob_indices.
end();
1392 .predictLabel(
rowVector(features, *iter))
1393 == pr.response()(*iter, 0))
1396 ++perm_oob_right[pr.response()(*iter, 0)];
1398 ++perm_oob_right[class_count];
// Result: (baseline correct - mean permuted correct) / number of OOB rows.
// NOTE(review): divides by repetition_count_ and by oob_indices.size();
// nothing visible here protects against either being zero -- confirm.
1405 perm_oob_right /= repetition_count_;
1406 perm_oob_right -=oob_right;
1407 perm_oob_right *= -1;
1408 perm_oob_right /= oob_indices.
size();
// Added to the accumulator region that ends at (ii+1, class_count+1); the
// start of the region and the target array are not visible here.
1411 Shp_t(ii+1,class_count+1)) += perm_oob_right;
// Restore the original column values.
1413 for(
int jj = 0; jj < int(oob_indices.
size()); ++jj)
1414 features(oob_indices[jj], ii) = backup_column[jj];
1423 template<
class RF,
class PR,
class SM,
class ST>
1431 template<
class RF,
class PR>
// Console progress display, updated after each tree (signature not
// visible; name presumed from the template parameters). Output starts with
// a carriage return so that each update overwrites the previous one.
1444 template<
class RF,
class PR,
class SM,
class ST>
// Every tree except the last: percentage of trees finished plus
// "(k of N) done", flushed without a newline.
1446 if(index != rf.options().tree_count_-1) {
1447 std::cout <<
"\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 <<
"%]"
1448 <<
" (" << index+1 <<
" of " << rf.options().tree_count_ <<
") done" << std::flush;
// Presumably the branch for the last tree (the `else` is not visible):
// prints 100% and ends the line.
1451 std::cout <<
"\r[" << std::setw(10) << 100.0 <<
"%]" << std::endl;
// End-of-training report (signature not visible; name presumed): prints
// the number of trees together with the text produced by TOCS.
// NOTE(review): TOCS is presumably the elapsed-time macro from
// vigra/timing.hxx, paired with a timer start elsewhere -- confirm.
1455 template<
class RF,
class PR>
1457 std::string a = TOCS;
1458 std::cout <<
"all " << rf.options().tree_count_ <<
" trees have been learned in " << a << std::endl;
// Start-of-training report (signature not visible; name presumed):
// announces how many trees the forest is going to have.
1461 template<
class RF,
class PR>
1464 std::cout <<
"growing random forest, which will have " << rf.options().tree_count_ <<
" trees" << std::endl;
1512 void save(std::string file, std::string prefix)
// Allocates the working arrays of this visitor and fills the random
// comparison columns before training starts.
1529 template<
class RF,
class PR>
1530 void visit_at_beginning(RF
const & rf, PR & pr)
// n is the number of feature columns.
1533 int n = rf.ext_param_.column_count_;
// corr_l has n + 1 rows and 10 columns; noise_l has one row per feature
// row and 10 columns. The same count of 10 is hard-coded in the loops of
// the after-split hook.
1536 corr_l.
reshape(Shp(n +1, 10));
1539 noise_l.
reshape(Shp(pr.features().shape(0), 10));
// noise gets uniform random values, noise_l random 0/1 values.
// NOTE(review): the loop bound is noise.size() but noise_l is indexed with
// the same ii; this assumes noise_l has at least as many elements as noise
// (the sizing of noise is not visible here) -- confirm.
1541 for(
int ii = 0; ii <
noise.
size(); ++ii)
1543 noise[ii] = random.uniform53();
1544 noise_l[ii] = random.uniform53() > 0.5;
1546 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
// tmp_labels takes the shape of the response array.
1547 tmp_labels.
reshape(pr.response().shape());
1552 template<
class RF,
class PR>
1562 for(
int jj = 0; jj < rC-1; ++jj)
1567 for(
int jj = 0; jj < rC; ++jj)
1576 for(
int jj = 0; jj < rC; ++jj)
1583 for(
int jj = 0; jj < rC; ++jj)
1588 for(
int jj = 0; jj < rC; ++jj)
// After-split hook of the correlation visitor (only part of the signature
// is visible). For threshold splits it compares the chosen split with the
// splits obtainable from other columns and from the random comparison
// columns, using Gini decrease as the score.
// NOTE(review): braces, call heads and accumulation targets are missing
// from this excerpt; the comments below describe only the visible
// statements and the nesting must be confirmed against the full file.
1594 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1599 Region & rightChild,
1600 Feature_t & features,
1603 if(split.createNode().typeID() == i_ThresholdNode)
// Build a binary pseudo-label per sample of the parent region: 1 when the
// sample's value in the chosen split column lies below the chosen
// threshold, else 0; tmp_cc counts how many samples fall into each value.
1607 for(
int ii = 0; ii < parent.size(); ++ii)
1609 tmp_labels[parent[ii]]
1610 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1611 ++tmp_cc[tmp_labels[parent[ii]]];
1613 double region_gini = bgfunc.loss_of_region(tmp_labels,
1618 int n = split.bestSplitColumn();
// Score against the pseudo-labels: once per feature column, then in two
// further loops of 10 iterations each (10 matches the column count given
// to corr_l and noise_l in visit_at_beginning), then once for `labels`.
// In every case the score is region_gini minus the best Gini found by
// bgfunc.
1622 for(
int k = 0; k < features.shape(1); ++k)
1627 parent.
begin(), parent.end(),
1629 wgini = (region_gini - bgfunc.min_gini_);
1633 for(
int k = 0; k < 10; ++k)
1638 parent.
begin(), parent.end(),
1640 wgini = (region_gini - bgfunc.min_gini_);
1645 for(
int k = 0; k < 10; ++k)
1650 parent.
begin(), parent.end(),
1652 wgini = (region_gini - bgfunc.min_gini_);
1656 bgfunc(labels,0, tmp_labels, parent.
begin(), parent.end(),tmp_cc);
1657 wgini = (region_gini - bgfunc.min_gini_);
// From here on the real split's region Gini is the reference value.
1661 region_gini = split.region_gini_;
1663 Node<i_ThresholdNode> node(split.createNode());
1666 +=split.region_gini_ - split.minGini();
1668 for(
int k = 0; k < 10; ++k)
1673 parent.begin(), parent.end(),
1674 parent.classCounts());
// Columns 0 .. actual_mtry_-1 of split.splitColumns were already evaluated
// by the split search, so their stored min_gini_ values are reused.
1680 for(
int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1682 wgini = region_gini - split.min_gini_[k];
1685 split.splitColumns[k])
// The remaining columns are evaluated here with split.bgfunc.
1689 for(
int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1691 split.bgfunc(
columnVector(features, split.splitColumns[k]),
1693 parent.begin(), parent.end(),
1694 parent.classCounts());
1695 wgini = region_gini - split.bgfunc.min_gini_;
1697 split.splitColumns[k]) += wgini;
// Reorders the parent's sample range so that elements satisfying `sorter`
// (built from the chosen column and threshold) come first.
1705 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1706 std::partition(parent.begin(), parent.end(), sorter);
1717 #endif // RF_VISITORS_HXX