35 #define VIGRA_RF_ALGORTIHM_HXX
38 #include "splices.hxx"
57 template<
class OrigMultiArray,
60 void choose(OrigMultiArray
const & in,
69 for(Iter iter = b; iter != e; ++iter, ++ii)
99 template<
class Feature_t,
class Response_t>
101 Response_t
const & response)
// List of integer feature entries (presumably column indices into the
// feature matrix -- confirm against the code that fills `selected`).
124 typedef std::vector<int> FeatureList_t;
// List of error values, stored as doubles.
125 typedef std::vector<double> ErrorList_t;
// Iterator into a FeatureList_t; used as the movable "pivot" position.
126 typedef FeatureList_t::iterator Pivot_t;
152 template<
class FeatureT,
155 class ErrorRateCallBack>
156 bool init(FeatureT
const & all_features,
157 ResponseT
const & response,
160 ErrorRateCallBack errorcallback)
162 bool ret_ = init(all_features, response, errorcallback);
165 vigra_precondition(std::distance(b, e) ==
selected.size(),
166 "Number of features in ranking != number of features matrix");
171 template<
class FeatureT,
174 bool init(FeatureT
const & all_features,
175 ResponseT
const & response,
180 return init(all_features, response, b, e, ecallback);
184 template<
class FeatureT,
186 bool init(FeatureT
const & all_features,
187 ResponseT
const & response)
189 return init(all_features, response, RFErrorCallback());
// init overload taking the feature matrix, the response and an error
// callback. Only part of this definition is visible here; the comments
// below describe the visible statements only.
201 template<
class FeatureT,
203 class ErrorRateCallBack>
204 bool init(FeatureT
const & all_features,
205 ResponseT
const & response,
206 ErrorRateCallBack errorcallback)
// One slot per feature column (shape(1)), initially 0.
213 selected.resize(all_features.shape(1), 0);
214 for(
unsigned int ii = 0; ii <
selected.size(); ++ii)
// One error slot per feature column, pre-filled with -1
// (presumably meaning "not computed yet" -- confirm).
216 errors.resize(all_features.shape(1), -1);
// The last slot holds the error obtained with the complete feature matrix.
217 errors.back() = errorcallback(all_features, response);
// Tally how often each distinct response value occurs in column 0:
// res_map assigns an integer slot to each value, cts holds the counts.
221 std::map<typename ResponseT::value_type, int> res_map;
222 std::vector<int> cts;
224 for(
int ii = 0; ii < response.shape(0); ++ii)
226 if(res_map.find(response(ii, 0)) == res_map.end())
228 res_map[response(ii, 0)] = counter;
232 cts[res_map[response(ii,0)]] +=1;
// Ratio built from the largest tally and the number of response rows
// (the middle of this expression is not visible here -- confirm).
234 no_features = double(*(std::max_element(cts.begin(),
236 /
double(response.shape(0));
// Forward-selection fragment (name taken from the precondition message).
// Each round treats the entries from `pivot` to the end of `selected` as
// candidates: every candidate is swapped into the pivot slot, scored with
// `errorcallback`, and swapped back. The candidate with the smallest error
// is then moved into the pivot slot and its error is recorded in `errors`.
// Parts of this definition are not visible here.
291 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
293 ResponseT
const & response,
295 ErrorRateCallBack errorcallback)
// Work directly on the caller's result object.
297 VariableSelectionResult::FeatureList_t & selected = result.
selected;
298 VariableSelectionResult::ErrorList_t & errors = result.
errors;
299 VariableSelectionResult::Pivot_t & pivot = result.pivot;
300 int featureCount = features.shape(1);
302 if(!result.init(features, response, errorcallback))
// A previously used result must describe the same number of features.
306 vigra_precondition(selected.size() == featureCount,
307 "forward_selection(): Number of features in Feature "
308 "matrix and number of features in previously used "
309 "result struct mismatch!");
313 int not_selected_size = std::distance(pivot, selected.end());
314 while(not_selected_size > 1)
// Errors are doubles (see `double error` below and ErrorList_t); an integer
// container would truncate them before min_element compares them and before
// the winner is written to `errors`.
316 std::vector<double> current_errors;
317 VariableSelectionResult::Pivot_t next = pivot;
318 for(
int ii = 0; ii < not_selected_size; ++ii, ++next)
// Temporarily place the candidate in the pivot slot ...
320 std::swap(*pivot, *next);
322 detail::choose( features,
326 double error = errorcallback(cur_feats, response);
327 current_errors.push_back(error);
// ... and restore the original order afterwards.
328 std::swap(*pivot, *next);
// Offset of the best (lowest-error) candidate.
330 int pos = std::distance(current_errors.begin(),
331 std::min_element(current_errors.begin(),
332 current_errors.end()));
334 std::advance(next, pos);
335 std::swap(*pivot, *next);
336 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
338 not_selected_size = std::distance(pivot, selected.end());
341 template<
class FeatureT,
class ResponseT>
343 ResponseT
const & response,
344 VariableSelectionResult & result)
// Backward-elimination fragment (name taken from the precondition message).
// Each round swaps every entry in front of `pivot` into the pivot slot,
// scores the resulting feature subset with `errorcallback`, and swaps it
// back. One entry is then moved into the pivot slot and its error is
// recorded in `errors`. Parts of this definition are not visible here.
389 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
391 ResponseT
const & response,
393 ErrorRateCallBack errorcallback)
395 int featureCount = features.shape(1);
// Work directly on the caller's result object.
396 VariableSelectionResult::FeatureList_t & selected = result.
selected;
397 VariableSelectionResult::ErrorList_t & errors = result.
errors;
398 VariableSelectionResult::Pivot_t & pivot = result.pivot;
401 if(!result.init(features, response, errorcallback))
// A previously used result must describe the same number of features.
405 vigra_precondition(selected.size() == featureCount,
406 "backward_elimination(): Number of features in Feature "
407 "matrix and number of features in previously used "
408 "result struct mismatch!");
// Start with the pivot on the last entry of the list.
410 pivot = selected.end() - 1;
412 int selected_size = std::distance(selected.begin(), pivot);
413 while(selected_size > 1)
415 VariableSelectionResult::Pivot_t next = selected.begin();
// Errors are doubles (see `double error` below and ErrorList_t); an integer
// container would truncate them before they are compared and before the
// chosen value is written to `errors`.
416 std::vector<double> current_errors;
417 for(
int ii = 0; ii < selected_size; ++ii, ++next)
// Temporarily place the candidate in the pivot slot ...
419 std::swap(*pivot, *next);
421 detail::choose( features,
425 double error = errorcallback(cur_feats, response);
426 current_errors.push_back(error);
// ... and restore the original order afterwards.
427 std::swap(*pivot, *next);
// NOTE(review): this picks the candidate with the LARGEST error. Whether
// that is intended depends on the feature range passed to detail::choose,
// which is not visible here -- confirm against the full definition.
429 int pos = std::distance(current_errors.begin(),
430 std::max_element(current_errors.begin(),
431 current_errors.end()));
432 next = selected.begin();
433 std::advance(next, pos);
434 std::swap(*pivot, *next);
436 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
437 selected_size = std::distance(selected.begin(), pivot);
442 template<
class FeatureT,
class ResponseT>
444 ResponseT
const & response,
445 VariableSelectionResult & result)
// Selection fragment that walks `result.pivot` through `selected` in its
// existing order and stores one callback error per position. No reordering
// of `selected` is visible in this fragment; parts of the definition are
// not visible here.
482 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
484 ResponseT
const & response,
486 ErrorRateCallBack errorcallback)
// Work directly on the caller's result object; `iter` aliases result.pivot.
488 VariableSelectionResult::FeatureList_t & selected = result.
selected;
489 VariableSelectionResult::ErrorList_t & errors = result.
errors;
490 VariableSelectionResult::Pivot_t & iter = result.pivot;
491 int featureCount = features.shape(1);
493 if(!result.init(features, response, errorcallback))
// NOTE(review): the message below names forward_selection(); confirm that
// this matches the enclosing function, whose name is not visible here.
497 vigra_precondition(selected.size() == featureCount,
498 "forward_selection(): Number of features in Feature "
499 "matrix and number of features in previously used "
500 "result struct mismatch!");
504 for(; iter != selected.end(); ++iter)
509 detail::choose( features,
513 double error = errorcallback(cur_feats, response);
// The error is stored at the index of the current position in `selected`.
514 errors[std::distance(selected.begin(), iter)] = error;
519 template<
class FeatureT,
class ResponseT>
521 ResponseT
const & response,
522 VariableSelectionResult & result)
529 enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
544 ClusterNode():NodeBase(){}
545 ClusterNode(
int nCol,
546 BT::T_Container_type & topology,
547 BT::P_Container_type & split_param)
548 : BT(nCol + 5, 5,topology, split_param)
558 ClusterNode( BT::T_Container_type
const & topology,
559 BT::P_Container_type
const & split_param,
561 :
NodeBase(5 , 5,topology, split_param, n)
567 ClusterNode( BT & node_)
572 BT::parameter_size_ += 0;
578 void set_index(
int in)
// Constructs an entry by copying each argument into the matching member:
// p -> parent, l -> level, a -> addr, in -> infm.
604 HC_Entry(
int p,
int l,
int a,
bool in)
605 : parent(p), level(l), addr(a), infm(in)
// Combines two distance values by returning the smaller of the two.
634 double dist_func(
double a,
double b)
636 return std::min(a, b);
642 template<
class Functor>
646 std::vector<int> stack;
647 stack.push_back(begin_addr);
648 while(!stack.empty())
650 ClusterNode node(topology_, parameters_, stack.
back());
654 if(node.columns_size() != 1)
656 stack.push_back(node.child(0));
657 stack.push_back(node.child(1));
665 template<
class Functor>
669 std::queue<HC_Entry> queue;
674 queue.push(
HC_Entry(parent,level,begin_addr, infm));
675 while(!queue.empty())
677 level = queue.front().level;
678 parent = queue.front().parent;
679 addr = queue.front().addr;
680 infm = queue.front().infm;
681 ClusterNode node(topology_, parameters_, queue.
front().addr);
685 parnt = ClusterNode(topology_, parameters_, parent);
688 bool istrue = tester(node, level, parnt, infm);
689 if(node.columns_size() != 1)
691 queue.push(
HC_Entry(addr, level +1,node.child(0),istrue));
692 queue.push(
HC_Entry(addr, level +1,node.child(1),istrue));
698 void save(std::string file, std::string prefix)
701 vigra::writeHDF5(file.c_str(), (prefix +
"topology").c_str(),
705 vigra::writeHDF5(file.c_str(), (prefix +
"parameters").c_str(),
708 parameters_.
data()));
709 vigra::writeHDF5(file.c_str(), (prefix +
"begin_addr").c_str(),
717 template<
class T,
class C>
721 std::vector<std::pair<int, int> > addr;
722 typedef std::pair<int, int> Entry;
724 for(
int ii = 0; ii < distance.
shape(0); ++ii)
726 addr.push_back(std::make_pair(topology_.
size(), ii));
727 ClusterNode leaf(1, topology_, parameters_);
728 leaf.set_index(index);
730 leaf.columns_begin()[0] = ii;
733 while(addr.size() != 1)
738 double min_dist = dist((addr.begin()+ii_min)->second,
739 (addr.begin()+jj_min)->second);
740 for(
unsigned int ii = 0; ii < addr.size(); ++ii)
742 for(
unsigned int jj = ii+1; jj < addr.size(); ++jj)
744 if( dist((addr.begin()+ii_min)->second,
745 (addr.begin()+jj_min)->second)
746 > dist((addr.begin()+ii)->second,
747 (addr.begin()+jj)->second))
749 min_dist = dist((addr.begin()+ii)->second,
750 (addr.begin()+jj)->second);
762 ClusterNode firstChild(topology_,
764 (addr.begin() +ii_min)->first);
765 ClusterNode secondChild(topology_,
767 (addr.begin() +jj_min)->first);
768 col_size = firstChild.columns_size() + secondChild.columns_size();
770 int cur_addr = topology_.
size();
771 begin_addr = cur_addr;
773 ClusterNode parent(col_size,
776 ClusterNode firstChild(topology_,
778 (addr.begin() +ii_min)->first);
779 ClusterNode secondChild(topology_,
781 (addr.begin() +jj_min)->first);
782 parent.parameters_begin()[0] = min_dist;
783 parent.set_index(index);
785 std::merge(firstChild.columns_begin(), firstChild.columns_end(),
786 secondChild.columns_begin(),secondChild.columns_end(),
787 parent.columns_begin());
792 if(*parent.columns_begin() == *firstChild.columns_begin())
794 parent.child(0) = (addr.begin()+ii_min)->first;
795 parent.child(1) = (addr.begin()+jj_min)->first;
796 (addr.begin()+ii_min)->first = cur_addr;
798 to_keep = (addr.begin()+ii_min)->second;
799 to_desc = (addr.begin()+jj_min)->second;
800 addr.erase(addr.begin()+jj_min);
804 parent.child(1) = (addr.begin()+ii_min)->first;
805 parent.child(0) = (addr.begin()+jj_min)->first;
806 (addr.begin()+jj_min)->first = cur_addr;
808 to_keep = (addr.begin()+jj_min)->second;
809 to_desc = (addr.begin()+ii_min)->second;
810 addr.erase(addr.begin()+ii_min);
814 for(
unsigned int jj = 0 ; jj < addr.size(); ++jj)
818 double bla = dist_func(
819 dist(to_desc, (addr.begin()+jj)->second),
820 dist((addr.begin()+ii_keep)->second,
821 (addr.begin()+jj)->second));
823 dist((addr.begin()+ii_keep)->second,
824 (addr.begin()+jj)->second) = bla;
825 dist((addr.begin()+jj)->second,
826 (addr.begin()+ii_keep)->second) = bla;
847 bool operator()(Node& node)
860 template<
class Iter,
class DT>
865 Matrix<double> tmp_mem_;
868 Matrix<double> feats_;
875 template<
class Feat_T,
class Label_T>
878 Feat_T
const & feats,
879 Label_T
const & labls,
884 :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
887 feats_(_spl(a,b).size(), feats.shape(1)),
888 labels_(_spl(a,b).size(),1),
894 copy_splice(_spl(a,b),
895 _spl(feats.shape(1)),
898 copy_splice(_spl(a,b),
899 _spl(labls.shape(1)),
905 bool operator()(Node& node)
909 int class_count = perm_imp.
shape(1) - 1;
911 for(
int kk = 0; kk < nPerm; ++kk)
914 for(
int ii = 0; ii <
rowCount(feats_); ++ii)
917 for(
int jj = 0; jj < node.columns_size(); ++jj)
919 if(node.columns_begin()[jj] != feats_.shape(1))
920 tmp_mem_(ii, node.columns_begin()[jj])
921 = tmp_mem_(index, node.columns_begin()[jj]);
925 for(
int ii = 0; ii <
rowCount(tmp_mem_); ++ii)
932 ++perm_imp(index,labels_(ii, 0));
934 ++perm_imp(index, class_count);
938 double node_status = perm_imp(index, class_count);
939 node_status /= nPerm;
940 node_status -= orig_imp(0, class_count);
942 node_status /= oob_size;
943 node.status() += node_status;
963 void save(std::string file, std::string prefix)
965 vigra::writeHDF5(file.c_str(), (prefix +
"_variables").c_str(),
970 bool operator()(Node& node)
972 for(
int ii = 0; ii < node.columns_size(); ++ii)
973 variables(index, ii) = node.columns_begin()[ii];
// Visitor call for a node and its parent: the visible statement caps the
// current node's status at the parent's status (takes the minimum of both).
987 bool operator()(Nde & cur,
int level, Nde parent,
bool infm)
990 cur.status() = std::min(parent.status(), cur.status());
1017 std::ofstream graphviz;
1022 std::string
const gz)
1023 :features_(features), labels_(labels),
1024 graphviz(gz.c_str(), std::ios::out)
1026 graphviz <<
"digraph G\n{\n node [shape=\"record\"]";
1030 graphviz <<
"\n}\n";
1035 bool operator()(Nde & cur,
int level, Nde parent,
bool infm)
1037 graphviz <<
"node" << cur.index() <<
" [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() <<
"\\n";
1038 graphviz <<
" status: " << cur.status() <<
"\\n";
1039 for(
int kk = 0; kk < cur.columns_size(); ++kk)
1041 graphviz << cur.columns_begin()[kk] <<
" ";
1045 graphviz <<
"\"] [color = \"" <<cur.status() <<
" 1.000 1.000\"];\n";
1047 graphviz <<
"\"node" << parent.index() <<
"\" -> \"node" << cur.index() <<
"\";\n";
1067 int repetition_count_;
1073 void save(std::string filename, std::string prefix)
1075 std::string prefix1 =
"cluster_importance_" + prefix;
1076 writeHDF5(filename.c_str(),
1079 prefix1 =
"vars_" + prefix;
1080 writeHDF5(filename.c_str(),
1087 : repetition_count_(rep_cnt), clustering(clst)
1093 template<
class RF,
class PR>
1096 Int32 const class_count = rf.ext_param_.class_count_;
1097 Int32 const column_count = rf.ext_param_.column_count_+1;
1118 template<
class RF,
class PR,
class SM,
class ST>
1122 Int32 column_count = rf.ext_param_.column_count_ +1;
1123 Int32 class_count = rf.ext_param_.class_count_;
1127 typename PR::Feature_t & features
1128 =
const_cast<typename PR::Feature_t &
>(pr.features());
1135 if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
1139 for(
int ii = 0; ii < pr.features().shape(0); ++ii)
1140 indices.push_back(ii);
1141 std::random_shuffle(indices.begin(), indices.end());
1142 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1144 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
1146 oob_indices.push_back(indices[ii]);
1147 ++cts[pr.response()(indices[ii], 0)];
1153 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1154 if(!sm.is_used()[ii])
1155 oob_indices.push_back(ii);
1165 oob_right(Shp_t(1, class_count + 1));
1168 for(iter = oob_indices.
begin();
1169 iter != oob_indices.
end();
1173 .predictLabel(
rowVector(features, *iter))
1174 == pr.response()(*iter, 0))
1177 ++oob_right[pr.response()(*iter,0)];
1179 ++oob_right[class_count];
1184 perm_oob_right (Shp_t(2* column_count-1, class_count + 1));
1187 pc(oob_indices.
begin(), oob_indices.
end(),
1196 perm_oob_right /= repetition_count_;
1197 for(
int ii = 0; ii <
rowCount(perm_oob_right); ++ii)
1198 rowVector(perm_oob_right, ii) -= oob_right;
1200 perm_oob_right *= -1;
1201 perm_oob_right /= oob_indices.
size();
1210 template<
class RF,
class PR,
class SM,
class ST>
1218 template<
class RF,
class PR>
1258 template<
class FeatureT,
class ResponseT>
1260 ResponseT
const & response,
1267 if(features.shape(0) > 40000)
1274 RF.
learn(features, response,
1303 template<
class FeatureT,
class ResponseT>
1305 ResponseT
const & response,
1306 HClustering & linkage)