#ifndef ROOT_TMVA_CCPruner
#define ROOT_TMVA_CCPruner
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : CCPruner                                                              *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description: Cost Complexity Pruning                                           *
 *                                                                                *
 * Author: Doug Schouten (dschoute@sfu.ca)                                        *
 *                                                                                *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// CCPruner - a helper class to prune a decision tree using the Cost Complexity method                    //
// (see Classification and Regression Trees by Leo Breiman et al)                                         //
//                                                                                                        //
// Some definitions:                                                                                      //
//                                                                                                        //
// T_max - the initial, usually highly overtrained tree, that is to be pruned back                        // 
// R(T) - quality index (Gini, misclassification rate, or other) of a tree T                              //
// ~T - set of terminal nodes in T                                                                        //
// T' - the pruned subtree of T_max that has the best quality index R(T')                                 //
// alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha * |~T|)     //
//                                                                                                        //
// There are two running modes in CCPruner: (i) one may select a prune strength and prune back            //
// the tree T_max until the criterion                                                                     //
//             R(t) - R(T_t)                                                                              //
//  alpha <    -------------                                                                              //
//             |~T_t| - 1                                                                                 //
//                                                                                                        //
// is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points              //
// alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned    //
// subtree, defined to be the subtree with the best quality index for the validation sample.              //
// Here T_t denotes the branch (subtree) of T rooted at node t.                                           //
////////////////////////////////////////////////////////////////////////////////////////////////////////////


#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif

/* #ifndef ROOT_TMVA_DecisionTreeNode */
/* #include "TMVA/DecisionTreeNode.h" */
/* #endif */

#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

namespace TMVA {
  class DecisionTreeNode;
  class SeparationBase;

  // Helper that prunes a DecisionTree with the Cost Complexity method; see the
  // description at the top of this header for the two running modes.
  class CCPruner{
  public: 
    typedef std::vector<Event*> EventList;

    // t_max            : the (usually overtrained) tree T_max to be pruned
    // validationSample : event sample used to select the optimally-pruned subtree (may be NULL)
    // qualityIndex     : separation measure used to compute R(t) (may be NULL; whether a default
    //                    is then created and owned -- cf. fOwnQIndex -- is decided by the
    //                    out-of-line constructor, confirm there)
    // 'explicit' prevents an accidental implicit conversion DecisionTree* -> CCPruner.
    explicit CCPruner( DecisionTree* t_max, 
                       const EventList* validationSample = NULL,
                       SeparationBase* qualityIndex = NULL );
    ~CCPruner( );

    // set the pruning strength parameter alpha. Any non-positive alpha (the default is -1)
    // is stored as 0, which is intended to request the calculation of the optimal alpha.
    void SetPruneStrength( Float_t alpha = -1.0 );

    // run the pruning procedure (implemented out of line)
    void Optimize( );

    // return the list of pruning locations to define the optimal subtree T' of T_max
    std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const; 

    // return the quality index from the validation sample for the optimal subtree T',
    // or -1 when fOptimalK is not a valid index into fQualityIndexList
    inline Float_t GetOptimalQualityIndex( ) const {
      // bound-check both ends: operator[] performs no range checking
      return ( fOptimalK >= 0 &&
               static_cast<std::vector<Float_t>::size_type>(fOptimalK) < fQualityIndexList.size() ?
               fQualityIndexList[fOptimalK] : -1.0 );
    }

    // return the prune strength (=alpha) corresponding to the prune sequence,
    // or -1 when fOptimalK is not a valid index into fPruneStrengthList
    inline Float_t GetOptimalPruneStrength( ) const {
      return ( fOptimalK >= 0 &&
               static_cast<std::vector<Float_t>::size_type>(fOptimalK) < fPruneStrengthList.size() ?
               fPruneStrengthList[fOptimalK] : -1.0 );
    }
   
  private:
    // Copying is disabled (declared but never defined): this object may own
    // fQualityIndex (see fOwnQIndex), and a member-wise copy would share that pointer.
    CCPruner( const CCPruner& );
    CCPruner& operator=( const CCPruner& );

    Float_t              fAlpha; //! regularization parameter in CC pruning
    const EventList*     fValidationSample; //! the event sample to select the optimally-pruned tree
    SeparationBase*      fQualityIndex; //! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
    Bool_t               fOwnQIndex; //! flag indicates if fQualityIndex is owned by this

    DecisionTree*        fTree; //! (pruned) decision tree

    std::vector<TMVA::DecisionTreeNode*> fPruneSequence; //! map of weakest links (i.e., branches to prune) -> pruning index
    std::vector<Float_t> fPruneStrengthList;  //! map of alpha -> pruning index
    std::vector<Float_t> fQualityIndexList;   //! map of R(T) -> pruning index

    Int_t                fOptimalK;           //! index of the optimal tree in the pruned tree sequence
    Bool_t               fDebug;              //! debug flag
  };
}

// Record the requested prune strength. Only strictly positive values are kept;
// anything else (zero, negative, or NaN) is stored as 0.
inline void TMVA::CCPruner::SetPruneStrength( Float_t alpha ) {
  if ( alpha > 0 ) {
    fAlpha = alpha;
  }
  else {
    fAlpha = 0.0;
  }
}
    

#endif


// Last change: Sat Nov 1 10:21:30 2008
// Last generated: 2008-11-01 10:21
//
// This page has been automatically generated. If you have any comments or suggestions about the
// page layout send a mail to ROOT support, or contact the developers with any questions or problems
// regarding ROOT.