35 #ifndef VIGRA_RF3_VISITORS_HXX 36 #define VIGRA_RF3_VISITORS_HXX 40 #include "../multi_array.hxx" 41 #include "../multi_shape.hxx" 89 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
98 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
105 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
115 template <
typename TREE,
179 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
186 double const EPS = 1e-20;
190 is_in_bag_.resize(weights.size(),
true);
191 for (
size_t i = 0; i < weights.size(); ++i)
195 is_in_bag_[i] =
false;
201 throw std::runtime_error(
"OOBError::visit_before_tree(): The tree has no out-of-bags.");
207 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
211 const FEATURES & features,
212 const LABELS & labels
215 vigra_precondition(rf.num_trees() > 0,
"OOBError::visit_after_training(): Number of trees must be greater than zero after training.");
216 vigra_precondition(visitors.size() == rf.num_trees(),
"OOBError::visit_after_training(): Number of visitors must be equal to number of trees.");
217 size_t const num_instances = features.shape()[0];
218 auto const num_features = features.shape()[1];
219 for (
auto vptr : visitors)
220 vigra_precondition(vptr->is_in_bag_.size() == num_instances,
"OOBError::visit_after_training(): Some visitors have the wrong number of data points.");
223 typedef typename std::remove_const<LABELS>::type Labels;
226 for (
size_t i = 0; i < (size_t)num_instances; ++i)
229 std::vector<size_t> tree_indices;
230 for (
size_t k = 0; k < visitors.size(); ++k)
231 if (!visitors[k]->is_in_bag_[i])
232 tree_indices.push_back(k);
235 auto const sub_features = features.subarray(Shape2(i, 0), Shape2(i+1, num_features));
236 rf.predict(sub_features, pred, 1, tree_indices);
237 if (pred(0) != labels(i))
240 oob_err_ /= num_instances;
249 std::vector<bool> is_in_bag_;
263 repetition_count_(repetition_count)
269 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
279 auto const num_features = features.shape()[1];
280 variable_importance_.reshape(Shape2(num_features, tree.num_classes()+2), 0.0);
283 double const EPS = 1e-20;
285 is_in_bag_.resize(weights.size(),
true);
286 for (
size_t i = 0; i < weights.size(); ++i)
290 is_in_bag_[i] =
false;
295 throw std::runtime_error(
"VariableImportance::visit_before_tree(): The tree has no out-of-bags.");
301 template <
typename TREE,
317 typename SCORER::Functor functor;
318 auto const region_impurity = functor.region_score(labels, weights, begin, end);
319 auto const split_impurity = scorer.best_score_;
320 variable_importance_(scorer.best_dim_, tree.num_classes()+1) += region_impurity - split_impurity;
326 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
328 const FEATURES & features,
329 const LABELS & labels,
333 typedef typename std::remove_const<FEATURES>::type Features;
334 typedef typename std::remove_const<LABELS>::type Labels;
336 typedef typename Features::value_type FeatureType;
338 auto const num_features = features.shape()[1];
345 copy_out_of_bags(features, labels, feats, labs);
346 auto const num_oobs = feats.shape()[0];
351 rf.predict(feats, pred, 1);
352 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
354 if (labs(i) == pred(i))
356 oob_right(labs(i)) += 1.0;
357 oob_right(rf.num_classes()) += 1.0;
363 for (
size_t j = 0; j < (size_t)num_features; ++j)
366 backup = feats.template bind<1>(j);
369 for (
size_t k = 0; k < repetition_count_; ++k)
372 for (
int ii = num_oobs-1; ii >= 1; --ii)
373 std::swap(feats(ii, j), feats(randint(ii+1), j));
376 rf.predict(feats, pred, 1);
377 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
379 if (labs(i) == pred(i))
381 perm_oob_right(0, labs(i)) += 1.0;
382 perm_oob_right(0, rf.num_classes()) += 1.0;
388 perm_oob_right /= repetition_count_;
389 perm_oob_right.bind<0>(0) -= oob_right;
390 perm_oob_right *= -1;
391 perm_oob_right /= num_oobs;
392 variable_importance_.
subarray(Shape2(j, 0), Shape2(j+1, rf.num_classes()+1)) += perm_oob_right;
395 feats.template bind<1>(j) = backup;
402 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
406 const FEATURES & features,
409 vigra_precondition(rf.num_trees() > 0,
"VariableImportance::visit_after_training(): Number of trees must be greater than zero after training.");
410 vigra_precondition(visitors.size() == rf.num_trees(),
"VariableImportance::visit_after_training(): Number of visitors must be equal to number of trees.");
413 auto const num_features = features.shape()[1];
414 variable_importance_.reshape(Shape2(num_features, rf.num_classes()+2), 0.0);
415 for (
auto vptr : visitors)
417 vigra_precondition(vptr->variable_importance_.shape() == variable_importance_.shape(),
418 "VariableImportance::visit_after_training(): Shape mismatch.");
419 variable_importance_ += vptr->variable_importance_;
423 variable_importance_ /= rf.num_trees();
464 template <
typename F0,
typename L0,
typename F1,
typename L1>
465 void copy_out_of_bags(
466 F0
const & features_in,
467 L0
const & labels_in,
471 auto const num_instances = features_in.shape()[0];
472 auto const num_features = features_in.shape()[1];
476 for (
auto x : is_in_bag_)
481 features_out.reshape(Shape2(num_oobs, num_features));
482 labels_out.reshape(
Shape1(num_oobs));
484 for (
size_t i = 0; i < (size_t)num_instances; ++i)
488 auto const src = features_in.template bind<0>(i);
489 auto out = features_out.template bind<0>(current);
491 labels_out(current) = labels_in(i);
497 std::vector<bool> is_in_bag_;
518 template <
typename VISITOR,
typename NEXT = RFStopVisiting,
bool CPY = false>
523 typedef VISITOR Visitor;
526 typename std::conditional<CPY, Visitor, Visitor &>::type visitor_;
543 visitor_(other.visitor_),
549 visitor_(other.visitor_),
555 if (visitor_.is_active())
556 visitor_.visit_before_training();
557 next_.visit_before_training();
560 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
561 void visit_after_training(VISITORS & v, RF & rf,
const FEATURES & features,
const LABELS & labels)
563 typedef typename VISITORS::value_type VisitorNodeType;
564 typedef typename VisitorNodeType::Visitor VisitorType;
565 typedef typename VisitorNodeType::Next NextType;
571 if (visitor_.is_active())
573 std::vector<VisitorType*> visitors;
575 visitors.push_back(&x.visitor_);
576 visitor_.visit_after_training(visitors, rf, features, labels);
580 std::vector<NextType> nexts;
582 nexts.push_back(x.next_);
585 next_.visit_after_training(nexts, rf, features, labels);
588 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
589 void visit_before_tree(TREE & tree, FEATURES & features, LABELS & labels, WEIGHTS & weights)
591 if (visitor_.is_active())
592 visitor_.visit_before_tree(tree, features, labels, weights);
593 next_.visit_before_tree(tree, features, labels, weights);
596 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
602 if (visitor_.is_active())
603 visitor_.visit_after_tree(rf, features, labels, weights);
604 next_.visit_after_tree(rf, features, labels, weights);
607 template <
typename TREE,
622 if (visitor_.is_active())
623 visitor_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
624 next_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
634 template <
typename VISITOR>
643 typedef RFStopVisiting
type;
655 create_visitor(A & a)
662 template<
typename A,
typename B>
664 create_visitor(A & a, B & b)
673 template<
typename A,
typename B,
typename C>
675 create_visitor(A & a, B & b, C & c)
686 template<
typename A,
typename B,
typename C,
typename D>
689 create_visitor(A & a, B & b, C & c, D & d)
702 template<
typename A,
typename B,
typename C,
typename D,
typename E>
705 create_visitor(A & a, B & b, C & c, D & d, E & e)
720 template<
typename A,
typename B,
typename C,
typename D,
typename E,
724 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f)
741 template<
typename A,
typename B,
typename C,
typename D,
typename E,
742 typename F,
typename G>
746 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g)
765 template<
typename A,
typename B,
typename C,
typename D,
typename E,
766 typename F,
typename G,
typename H>
770 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h)
791 template<
typename A,
typename B,
typename C,
typename D,
typename E,
792 typename F,
typename G,
typename H,
typename I>
796 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i)
819 template<
typename A,
typename B,
typename C,
typename D,
typename E,
820 typename F,
typename G,
typename H,
typename I,
typename J>
825 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i,
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
Do something before a tree has been learned.
Definition: random_forest_visitors.hxx:99
void visit_before_tree(TREE &tree, FEATURES &features, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:270
void visit_after_split(TREE &, FEATURES &, LABELS &, WEIGHTS &, SCORER &, ITER, ITER, ITER)
Do something after the split was made.
Definition: random_forest_visitors.hxx:121
void visit_after_tree(RF &, FEATURES &, LABELS &, WEIGHTS &)
Do something after a tree has been learned.
Definition: random_forest_visitors.hxx:106
void deactivate()
Deactivate the visitor.
Definition: random_forest_visitors.hxx:150
MultiArrayView subarray(difference_type p, difference_type q) const
Definition: multi_array.hxx:1528
Base class from which all random forest visitors derive.
Definition: random_forest_visitors.hxx:68
bool is_active() const
Return whether the visitor is active or not.
Definition: random_forest_visitors.hxx:134
size_t repetition_count_
Definition: random_forest_visitors.hxx:457
The default visitor node (= "do nothing").
Definition: random_forest_visitors.hxx:509
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &labels)
Definition: random_forest_visitors.hxx:208
Compute the variable importance.
Definition: random_forest_visitors.hxx:257
Compute the out-of-bag error.
Definition: random_forest_visitors.hxx:172
double oob_err_
Definition: random_forest_visitors.hxx:246
void visit_after_tree(RF &rf, const FEATURES &features, const LABELS &labels, WEIGHTS &)
Definition: random_forest_visitors.hxx:327
Definition: random_forest_visitors.hxx:635
void activate()
Activate the visitor.
Definition: random_forest_visitors.hxx:142
Class for fixed size vectors. This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void visit_after_split(TREE &tree, FEATURES &, LABELS &labels, WEIGHTS &weights, SCORER &scorer, ITER begin, ITER, ITER end)
Definition: random_forest_visitors.hxx:307
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:180
Container elements of the statically linked visitor list. Use the create_visitor() functions to create...
Definition: random_forest_visitors.hxx:519
MultiArray< 2, double > variable_importance_
Definition: random_forest_visitors.hxx:452
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
void visit_before_training()
Do something before training starts.
Definition: random_forest_visitors.hxx:80
void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
Do something after all trees have been learned.
Definition: random_forest_visitors.hxx:90
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &)
Definition: random_forest_visitors.hxx:403