//===========================================================================
/*!
 * \file CARTClassifier.h
 *
 * \brief Cart Classifier
 *
 *
 * \author K. N. Hansen, J. Kremer
 * \date 2012
 *
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
34 
#ifndef SHARK_MODELS_TREES_CARTCLASSIFIER_H
#define SHARK_MODELS_TREES_CARTCLASSIFIER_H

// NOTE(review): three include lines were lost from this dump (gap in the
// original numbering). Restored from usage: AbstractModel is the base class,
// and computeOOBerror/computeFeatureImportances use 0-1 and squared losses.
#include <shark/Models/AbstractModel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
#include <shark/Data/Dataset.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>
44 namespace shark {
45 
46 
47 ///
48 /// \brief CART Classifier.
49 ///
50 /// \par
51 /// The CARTClassifier predicts a class label
52 /// using the CART algorithm.
53 ///
54 /// \par
55 /// It is a decision tree algorithm.
56 ///
57 template<class LabelType>
58 class CARTClassifier : public AbstractModel<RealVector,LabelType>
59 {
60 private:
62 public:
65 // Information about a single split. misclassProp, r and g are variables used in the cost complexity step
66  struct SplitInfo{
67  std::size_t nodeId;
68  std::size_t attributeIndex;
70  std::size_t leftNodeId;
71  std::size_t rightNodeId;
72  LabelType label;
73  double misclassProp;//TODO: remove this
74  std::size_t r;//TODO: remove this
75  double g;//TODO: remove this
76 
77  template<class Archive>
78  void serialize(Archive & ar, const unsigned int version){
79  ar & nodeId;
80  ar & attributeIndex;
81  ar & attributeValue;
82  ar & leftNodeId;
83  ar & rightNodeId;
84  ar & label;
85  ar & misclassProp;
86  ar & r;
87  ar & g;
88  }
89  };
90 
91  /// Vector of structs that contains the splitting information and the labels.
92  /// The class label is a normalized histogram in the classification case.
93  /// In the regression case, the label is the regression value.
94  typedef std::vector<SplitInfo> SplitMatrixType;
95 
96  /// Constructor
98  {}
99 
100  /// Constructor taking the splitMatrix as argument
101  CARTClassifier(SplitMatrixType const& splitMatrix)
102  {
103  setSplitMatrix(splitMatrix);
104  }
105 
106  /// Constructor taking the splitMatrix as argument as well as maximum number of attributes
107  CARTClassifier(SplitMatrixType const& splitMatrix, std::size_t d)
108  {
109  setSplitMatrix(splitMatrix);
110  m_inputDimension = d;
111  }
112 
113  /// \brief From INameable: return the class name.
114  std::string name() const
115  { return "CARTClassifier"; }
116 
117  boost::shared_ptr<State> createState()const{
118  return boost::shared_ptr<State>(new EmptyState());
119  }
120 
121  using base_type::eval;
122  /// \brief Evaluate the Tree on a batch of patterns
123  void eval(const BatchInputType& patterns, BatchOutputType& outputs)const{
124  std::size_t numPatterns = shark::size(patterns);
125  //evaluate the first pattern alone and create the batch output from that
126  LabelType const& firstResult = evalPattern(row(patterns,0));
127  outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
128  get(outputs,0) = firstResult;
129 
130  //evaluate the rest
131  for(std::size_t i = 0; i != numPatterns; ++i){
132  get(outputs,i) = evalPattern(row(patterns,i));
133  }
134  }
135 
136  void eval(const BatchInputType& patterns, BatchOutputType& outputs, State& state)const{
137  eval(patterns,outputs);
138  }
139  /// \brief Evaluate the Tree on a single pattern
140  void eval(RealVector const & pattern, LabelType& output){
141  output = evalPattern(pattern);
142  }
143 
144  /// Set the model split matrix.
145  void setSplitMatrix(SplitMatrixType const& splitMatrix){
146  m_splitMatrix = splitMatrix;
148  }
149 
150  /// \brief The model does not have any parameters.
151  std::size_t numberOfParameters()const{
152  return 0;
153  }
154 
155  /// \brief The model does not have any parameters.
156  RealVector parameterVector() const {
157  return RealVector();
158  }
159 
160  /// \brief The model does not have any parameters.
161  void setParameterVector(const RealVector& param) {
162  SHARK_ASSERT(param.size() == 0);
163  }
164 
165  /// from ISerializable, reads a model from an archive
166  void read(InArchive& archive){
167  archive >> m_splitMatrix;
168  }
169 
170  /// from ISerializable, writes a model to an archive
171  void write(OutArchive& archive) const {
172  archive << m_splitMatrix;
173  }
174 
175 
176  //Count how often attributes are used
177  UIntVector countAttributes() const {
179  UIntVector r(m_inputDimension, 0);
180  typename SplitMatrixType::const_iterator it;
181  for(it = m_splitMatrix.begin(); it != m_splitMatrix.end(); ++it) {
182  //std::cout << "NodeId: " <<it->leftNodeId << std::endl;
183  if(it->leftNodeId != 0) { // not a label
184  r(it->attributeIndex)++;
185  }
186  }
187  return r;
188  }
189 
190  ///Return input dimension
191  std::size_t inputSize() const {
192  return m_inputDimension;
193  }
194 
195  //Set input dimension
196  void setInputDimension(std::size_t d) {
197  m_inputDimension = d;
198  }
199 
200  /// Compute oob error, given an oob dataset (Classification)
202  // define loss
204 
205  // predict oob data
206  Data<RealVector> predOOB = (*this)(dataOOB.inputs());
207 
208  // count average number of oob misclassifications
209  m_OOBerror = lossOOB.eval(dataOOB.labels(), predOOB);
210  }
211 
212  /// Compute oob error, given an oob dataset (Regression)
213  void computeOOBerror(const RegressionDataset& dataOOB){
214  // define loss
216 
217  // predict oob data
218  Data<RealVector> predOOB = (*this)(dataOOB.inputs());
219 
220  // Compute mean squared error
221  m_OOBerror = lossOOB.eval(dataOOB.labels(), predOOB);
222  }
223 
224  /// Return OOB error
225  double OOBerror() const {
226  return m_OOBerror;
227  }
228 
229  /// Return feature importances
230  RealVector const& featureImportances() const {
231  return m_featureImportances;
232  }
233 
234  /// Compute feature importances, given an oob dataset (Classification)
237 
238  // define loss
240 
241  // compute oob error
242  computeOOBerror(dataOOB);
243 
244  // count average number of correct oob predictions
245  double accuracyOOB = 1. - m_OOBerror;
246 
247  // go through all dimensions, permute each dimension across all elements and train the tree on it
248  for(std::size_t i=0;i!=m_inputDimension;++i) {
249  // create permuted dataset by copying
250  ClassificationDataset pDataOOB(dataOOB);
251  pDataOOB.makeIndependent();
252 
253  // permute current dimension
254  RealVector v = getColumn(pDataOOB.inputs(), i);
255  std::random_shuffle(v.begin(), v.end());
256  setColumn(pDataOOB.inputs(), i, v);
257 
258  // evaluate the data set for which one feature dimension was permuted with this tree
259  Data<RealVector> pPredOOB = (*this)(pDataOOB.inputs());
260 
261  // count the number of correct predictions
262  double accuracyPermutedOOB = 1. - lossOOB.eval(pDataOOB.labels(),pPredOOB);
263 
264  // store importance
265  m_featureImportances[i] = std::fabs(accuracyOOB - accuracyPermutedOOB);
266  }
267  }
268 
269  /// Compute feature importances, given an oob dataset (Regression)
272 
273  // define loss
275 
276  // compute oob error
277  computeOOBerror(dataOOB);
278 
279  // mean squared error for oob sample
280  double mseOOB = m_OOBerror;
281 
282  // go through all dimensions, permute each dimension across all elements and train the tree on it
283  for(std::size_t i=0;i!=m_inputDimension;++i) {
284  // create permuted dataset by copying
285  RegressionDataset pDataOOB(dataOOB);
286  pDataOOB.makeIndependent();
287 
288  // permute current dimension
289  RealVector v = getColumn(pDataOOB.inputs(), i);
290  std::random_shuffle(v.begin(), v.end());
291  setColumn(pDataOOB.inputs(), i, v);
292 
293  // evaluate the data set for which one feature dimension was permuted with this tree
294  Data<RealVector> pPredOOB = (*this)(pDataOOB.inputs());
295 
296  // mean squared error of permuted oob sample
297  double msePermutedOOB = lossOOB.eval(pDataOOB.labels(),pPredOOB);
298 
299  // store importance
300  m_featureImportances[i] = std::fabs(msePermutedOOB - mseOOB);
301  }
302  }
303 
304 protected:
305  /// split matrix of the model
306  SplitMatrixType m_splitMatrix;
307 
308  /// \brief Finds the index of the node with a certain nodeID in an unoptimized split matrix.
309  std::size_t findNode(std::size_t nodeId)const{
310  std::size_t index = 0;
311  for(; nodeId != m_splitMatrix[index].nodeId; ++index);
312  return index;
313  }
314 
315  /// Optimize a split matrix, so constant lookup can be used.
316  /// The optimization is done by changing the index of the children
317  /// to use indices instead of node ID.
318  /// Furthermore, the node IDs are converted to index numbers.
319  void optimizeSplitMatrix(SplitMatrixType& splitMatrix)const{
320  for(std::size_t i = 0; i < splitMatrix.size(); i++){
321  splitMatrix[i].leftNodeId = findNode(splitMatrix[i].leftNodeId);
322  splitMatrix[i].rightNodeId = findNode(splitMatrix[i].rightNodeId);
323  }
324  for(std::size_t i = 0; i < splitMatrix.size(); i++){
325  splitMatrix[i].nodeId = i;
326  }
327  }
328 
329  /// Evaluate the CART tree on a single sample
330  template<class Vector>
331  LabelType const& evalPattern(Vector const& pattern)const{
332  std::size_t nodeId = 0;
333  while(m_splitMatrix[nodeId].leftNodeId != 0){
334  if(pattern[m_splitMatrix[nodeId].attributeIndex]<=m_splitMatrix[nodeId].attributeValue){
335  //Branch on left node
336  nodeId = m_splitMatrix[nodeId].leftNodeId;
337  }else{
338  //Branch on right node
339  nodeId = m_splitMatrix[nodeId].rightNodeId;
340  }
341  }
342  return m_splitMatrix[nodeId].label;
343  }
344 
345 
346  ///Number of attributes (set by trainer)
347  std::size_t m_inputDimension;
348 
349  // feature importances
351 
352  // oob error
353  double m_OOBerror;
354 };
355 
356 
357 }
358 #endif