RFClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Random Forest Classifier.
6  *
7  *
8  *
9  * \author K. N. Hansen, O.Krause, J. Kremer
10  * \date 2011-2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_TREES_RFCLASSIFIER_H
36 #define SHARK_MODELS_TREES_RFCLASSIFIER_H
37 
39 #include <shark/Models/MeanModel.h>
40 
41 namespace shark {
42 
43 
44 ///
45 /// \brief Random Forest Classifier.
46 ///
47 /// \par
48 /// The Random Forest Classifier predicts a class label
49 /// using the Random Forest algorithm as described in<br/>
50 /// Random Forests. Leo Breiman. Machine Learning, 1(45), pages 5-32. Springer, 2001.<br/>
51 ///
52 /// \par
53 /// It is a ensemble learner that uses multiple decision trees built
54 /// using the CART methodology.
55 ///
56 class RFClassifier : public MeanModel<CARTClassifier<RealVector> >
57 {
58 public:
59  /// \brief From INameable: return the class name.
60  std::string name() const
61  { return "RFClassifier"; }
62 
63  // compute the oob error for the forest
65  std::size_t n_trees = numberOfModels();
66  m_OOBerror = 0;
67  for(std::size_t j=0;j!=n_trees;++j){
68  m_OOBerror += m_models[j].OOBerror();
69  }
70  m_OOBerror /= n_trees;
71  }
72 
73  // compute the feature importances for the forest
76  std::size_t n_trees = numberOfModels();
77 
78  for(std::size_t i=0;i!=m_inputDimension;++i){
79  m_featureImportances[i] = 0;
80  for(std::size_t j=0;j!=n_trees;++j){
81  m_featureImportances[i] += m_models[j].featureImportances()[i];
82  }
83  m_featureImportances[i] /= n_trees;
84  }
85  }
86 
87  double const OOBerror() const {
88  return m_OOBerror;
89  }
90 
91  // returns the feature importances
92  RealVector const& featureImportances() const {
93  return m_featureImportances;
94  }
95 
96  //Count how often attributes are used
97  UIntVector countAttributes() const {
98  std::size_t n = m_models.size();
99  if(!n) return UIntVector();
100  UIntVector r = m_models[0].countAttributes();
101  for(std::size_t i=1; i< n; i++ ) {
102  noalias(r) += m_models[i].countAttributes();
103  }
104  return r;
105  }
106 
107  /// Set the dimension of the labels
108  void setLabelDimension(std::size_t in){
109  m_labelDimension = in;
110  }
111 
112  // Set the input dimension
113  void setInputDimension(std::size_t in){
114  m_inputDimension = in;
115  }
116 
117 protected:
118  // Dimension of label in the regression case, number of classes in the classification case.
119  std::size_t m_labelDimension;
120 
121  // Input dimension
122  std::size_t m_inputDimension;
123 
124  // oob error for the forest
125  double m_OOBerror;
126 
127  // feature importances for the forest
129 
130 };
131 
132 
133 }
134 #endif