37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
44 struct ClassificationTag
99 template<
class T,
class C>
104 static T & choose(T & t, C &)
176 double training_set_proportion_;
177 int training_set_size_;
178 int (*training_set_func_)(int);
180 training_set_calc_switch_;
182 bool sample_with_replacement_;
184 stratification_method_;
195 int (*mtry_func_)(int) ;
197 bool predict_weighted_;
199 int min_split_node_size_;
200 bool prepare_online_learning_;
203 int serialized_size()
const
212 #define COMPARE(field) result = result && (this->field == rhs.field);
213 COMPARE(training_set_proportion_);
214 COMPARE(training_set_size_);
215 COMPARE(training_set_calc_switch_);
216 COMPARE(sample_with_replacement_);
217 COMPARE(stratification_method_);
218 COMPARE(mtry_switch_);
220 COMPARE(tree_count_);
221 COMPARE(min_split_node_size_);
222 COMPARE(predict_weighted_);
229 return !(*
this == rhs_);
232 void unserialize(Iter
const & begin, Iter
const & end)
235 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
236 "RandomForestOptions::unserialize():"
237 "wrong number of parameters");
238 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
239 PULL(training_set_proportion_,
double);
240 PULL(training_set_size_,
int);
243 PULL(sample_with_replacement_, 0 != );
248 PULL(tree_count_,
int);
249 PULL(min_split_node_size_,
int);
250 PULL(predict_weighted_, 0 !=);
254 void serialize(Iter
const & begin, Iter
const & end)
const
257 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
258 "RandomForestOptions::serialize():"
259 "wrong number of parameters");
260 #define PUSH(item_) *iter = double(item_); ++iter;
261 PUSH(training_set_proportion_);
262 PUSH(training_set_size_);
263 if(training_set_func_ != 0)
271 PUSH(training_set_calc_switch_);
272 PUSH(sample_with_replacement_);
273 PUSH(stratification_method_);
285 PUSH(min_split_node_size_);
286 PUSH(predict_weighted_);
293 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
294 #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
295 PULL(training_set_proportion_,
double);
296 PULL(training_set_size_,
int);
298 PULL(tree_count_,
int);
299 PULL(min_split_node_size_,
int);
300 PULLBOOL(sample_with_replacement_,
bool);
301 PULLBOOL(prepare_online_learning_,
bool);
302 PULLBOOL(predict_weighted_,
bool);
317 #define PUSH(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_));
318 #define PUSHFUNC(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_!=0));
319 PUSH(training_set_proportion_,
double);
320 PUSH(training_set_size_,
int);
322 PUSH(tree_count_,
int);
323 PUSH(min_split_node_size_,
int);
324 PUSH(sample_with_replacement_,
bool);
325 PUSH(prepare_online_learning_,
bool);
326 PUSH(predict_weighted_,
bool);
332 PUSHFUNC(mtry_func_,
int);
333 PUSHFUNC(training_set_func_,
int);
346 training_set_proportion_(1.0),
347 training_set_size_(0),
348 training_set_func_(0),
349 training_set_calc_switch_(RF_PROPORTIONAL),
350 sample_with_replacement_(true),
351 stratification_method_(RF_NONE),
352 mtry_switch_(RF_SQRT),
355 predict_weighted_(false),
357 min_split_node_size_(1),
358 prepare_online_learning_(false)
374 vigra_precondition(in == RF_EQUAL ||
375 in == RF_PROPORTIONAL ||
378 "RandomForestOptions::use_stratification()"
379 "input must be RF_EQUAL, RF_PROPORTIONAL,"
380 "RF_EXTERNAL or RF_NONE");
381 stratification_method_ = in;
387 prepare_online_learning_=in;
397 sample_with_replacement_ = in;
411 training_set_proportion_ = in;
412 training_set_calc_switch_ = RF_PROPORTIONAL;
420 training_set_size_ = in;
421 training_set_calc_switch_ = RF_CONST;
433 training_set_func_ = in;
434 training_set_calc_switch_ = RF_FUNCTION;
442 predict_weighted_ =
true;
455 vigra_precondition(in == RF_LOG ||
458 "RandomForestOptions()::features_per_node():"
459 "input must be of type RF_LOG or RF_SQRT");
473 mtry_switch_ = RF_CONST;
485 mtry_switch_ = RF_FUNCTION;
509 min_split_node_size_ = in;
530 template<
class LabelType =
double>
559 void to_classlabel(
int index, T & out)
const
561 out = T(classes[index]);
564 int to_classIndex(T index)
const
566 return std::find(classes.
begin(), classes.
end(), index) - classes.
begin();
569 #define EQUALS(field) field(rhs.field)
572 EQUALS(column_count_),
573 EQUALS(class_count_),
575 EQUALS(actual_mtry_),
576 EQUALS(actual_msample_),
577 EQUALS(problem_type_),
579 EQUALS(class_weights_),
580 EQUALS(is_weighted_),
583 std::back_insert_iterator<ArrayVector<Label_t> >
585 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
588 #define EQUALS(field) field(rhs.field)
592 EQUALS(column_count_),
593 EQUALS(class_count_),
595 EQUALS(actual_mtry_),
596 EQUALS(actual_msample_),
597 EQUALS(problem_type_),
599 EQUALS(class_weights_),
600 EQUALS(is_weighted_),
603 std::back_insert_iterator<ArrayVector<Label_t> >
605 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
611 #define EQUALS(field) (this->field = rhs.field);
614 EQUALS(column_count_);
615 EQUALS(class_count_);
617 EQUALS(actual_mtry_);
618 EQUALS(actual_msample_);
619 EQUALS(problem_type_);
621 EQUALS(is_weighted_);
623 class_weights_.clear();
624 std::back_insert_iterator<ArrayVector<double> >
625 iter2(class_weights_);
626 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
628 std::back_insert_iterator<ArrayVector<Label_t> >
630 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
635 ProblemSpec<Label_t> & operator=(ProblemSpec<T>
const & rhs)
637 EQUALS(column_count_);
638 EQUALS(class_count_);
640 EQUALS(actual_mtry_);
641 EQUALS(actual_msample_);
642 EQUALS(problem_type_);
644 EQUALS(is_weighted_);
646 class_weights_.clear();
647 std::back_insert_iterator<ArrayVector<double> >
648 iter2(class_weights_);
649 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
651 std::back_insert_iterator<ArrayVector<Label_t> >
653 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
659 bool operator==(ProblemSpec<T>
const & rhs)
662 #define COMPARE(field) result = result && (this->field == rhs.field);
663 COMPARE(column_count_);
664 COMPARE(class_count_);
666 COMPARE(actual_mtry_);
667 COMPARE(actual_msample_);
668 COMPARE(problem_type_);
669 COMPARE(is_weighted_);
672 COMPARE(class_weights_);
680 return !(*
this == rhs);
684 size_t serialized_size()
const
686 return 9 + class_count_ *int(is_weighted_+1);
691 void unserialize(Iter
const & begin, Iter
const & end)
694 vigra_precondition(end - begin >= 9,
695 "ProblemSpec::unserialize():"
696 "wrong number of parameters");
697 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
698 PULL(column_count_,
int);
699 PULL(class_count_,
int);
701 vigra_precondition(end - begin >= 9 + class_count_,
702 "ProblemSpec::unserialize(): 1");
703 PULL(row_count_,
int);
704 PULL(actual_mtry_,
int);
705 PULL(actual_msample_,
int);
707 PULL(is_weighted_,
int);
709 PULL(precision_,
double);
712 vigra_precondition(end - begin == 9 + 2*class_count_,
713 "ProblemSpec::unserialize(): 2");
714 class_weights_.insert(class_weights_.
end(),
716 iter + class_count_);
717 iter += class_count_;
719 classes.insert(classes.
end(), iter, end);
725 void serialize(Iter
const & begin, Iter
const & end)
const
728 vigra_precondition(end - begin == serialized_size(),
729 "RandomForestOptions::serialize():"
730 "wrong number of parameters");
731 #define PUSH(item_) *iter = double(item_); ++iter;
736 PUSH(actual_msample_);
743 std::copy(class_weights_.
begin(),
744 class_weights_.
end(),
746 iter += class_count_;
748 std::copy(classes.
begin(),
754 void make_from_map(std::map<std::string, ArrayVector<double> > & in)
756 typedef MultiArrayShape<2>::type Shp;
757 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
758 PULL(column_count_,
int);
759 PULL(class_count_,
int);
760 PULL(row_count_,
int);
761 PULL(actual_mtry_,
int);
762 PULL(actual_msample_,
int);
764 PULL(is_weighted_,
int);
766 PULL(precision_,
double);
767 class_weights_ = in[
"class_weights_"];
770 void make_map(std::map<std::string, ArrayVector<double> > & in)
const
772 typedef MultiArrayShape<2>::type Shp;
773 #define PUSH(item_) in[#item_] = ArrayVector<double>(1, double(item_));
778 PUSH(actual_msample_);
783 in["class_weights_"] = class_weights_;
795 problem_type_(CHECKLATER),
812 template<
class C_Iter>
815 int size = end-begin;
816 for(
int k=0; k<size; ++k, ++begin)
817 classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
827 template<
class W_Iter>
830 class_weights_.insert(class_weights_.
end(), begin, end);
841 class_weights_.clear();
846 problem_type_ = CHECKLATER;
847 is_weighted_ =
false;
870 int min_split_node_size_;
874 : min_split_node_size_(opt.min_split_node_size_)
878 void set_external_parameters(
ProblemSpec<T>const &,
int = 0,
bool =
false)
881 template<
class Region>
882 bool operator()(Region& region)
884 return region.size() < min_split_node_size_;
887 template<
class WeightIter,
class T,
class C>
897 #endif //VIGRA_RF_COMMON_HXX