#include <algorithm>
#include "Riostream.h"
#include "TRandom.h"
#include "TRandom2.h"
#include "TMath.h"
#include "TObjString.h"
#include "TMVA/MethodBDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "TMVA/Ranking.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/CCPruner.h"
using std::vector;
ClassImp(TMVA::MethodBDT)
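//_______________________________________________________________________
//
// Analysis of Boosted Decision Trees (BDT)
//
// A forest of decision trees is trained on the input sample. Each tree
// performs a sequence of binary splits on single input variables; the
// forest response is the (optionally boost-weighted) average of the
// individual tree responses. Between consecutive trees the training
// events are re-weighted by the chosen boosting algorithm (AdaBoost or
// Bagging).
//_______________________________________________________________________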
TMVA::MethodBDT::MethodBDT( const TString& jobName, const TString& methodTitle, DataSet& theData,
const TString& theOption, TDirectory* theTargetDir )
: TMVA::MethodBase( jobName, methodTitle, theData, theOption, theTargetDir )
{
InitBDT();
SetConfigName( TString("Method") + GetMethodName() );
DeclareOptions();
ParseOptions();
ProcessOptions();
if (HasTrainingTree()) {
fLogger << kVERBOSE << "Training tree found: initialising the internal event sample" << Endl;
this->InitEventSample();
}
else {
fLogger << kWARNING << "No training Tree given: you will not be allowed to call ::Train etc." << Endl;
}
BaseDir()->cd();
fBoostWeightHist = new TH1F("BoostWeight","AdaBoost weight distribution",100,1,30);
fBoostWeightHist->SetXTitle("boost weight");
fBoostWeightVsTree = new TH1F("BoostWeightVsTree","AdaBoost weight vs tree",fNTrees,0,fNTrees);
fBoostWeightVsTree->SetXTitle("#tree");
fBoostWeightVsTree->SetYTitle("boost weight");
fErrFractHist = new TH1F("ErrFractHist","error fraction vs tree number",fNTrees,0,fNTrees);
fErrFractHist->SetXTitle("#tree");
fErrFractHist->SetYTitle("error fraction");
fNodesBeforePruningVsTree = new TH1I("NodesBeforePruning","nodes before pruning",fNTrees,0,fNTrees);
fNodesBeforePruningVsTree->SetXTitle("#tree");
fNodesBeforePruningVsTree->SetYTitle("#tree nodes");
fNodesAfterPruningVsTree = new TH1I("NodesAfterPruning","nodes after pruning",fNTrees,0,fNTrees);
fNodesAfterPruningVsTree->SetXTitle("#tree");
fNodesAfterPruningVsTree->SetYTitle("#tree nodes");
fMonitorNtuple= new TTree("MonitorNtuple","BDT variables");
fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
fMonitorNtuple->Branch("boostWeight",&fBoostWeight,"boostWeight/D");
fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}
TMVA::MethodBDT::MethodBDT( DataSet& theData,
const TString& theWeightFile,
TDirectory* theTargetDir )
: TMVA::MethodBase( theData, theWeightFile, theTargetDir )
{
InitBDT();
DeclareOptions();
}
void TMVA::MethodBDT::DeclareOptions()
{
DeclareOptionRef(fNTrees, "NTrees", "Number of trees in the forest");
DeclareOptionRef(fBoostType, "BoostType", "Boosting type for the trees in the forest");
AddPreDefVal(TString("AdaBoost"));
AddPreDefVal(TString("Bagging"));
DeclareOptionRef(fAdaBoostBeta=1.0, "AdaBoostBeta", "Exponent beta in the AdaBoost weight formula ((1-err)/err)^beta");
DeclareOptionRef(fRandomisedTrees,"UseRandomisedTrees","Choose a random subset of variables at each node split");
DeclareOptionRef(fUseNvars,"UseNvars","Number of variables used if the randomised-tree option is chosen");
DeclareOptionRef(fUseWeightedTrees=kTRUE, "UseWeightedTrees",
"Use weighted trees or simple average in classification from the forest");
DeclareOptionRef(fSepTypeS="GiniIndex", "SeparationType", "Separation criterion for node splitting");
DeclareOptionRef(fUseYesNoLeaf=kTRUE, "UseYesNoLeaf",
"Use Sig or Bkg categories, or the purity=S/(S+B) as classification of the leaf node");
DeclareOptionRef(fNodePurityLimit=0.5, "NodePurityLimit", "In boosting/pruning, nodes with purity > NodePurityLimit are signal; background otherwise.");
AddPreDefVal(TString("MisClassificationError"));
AddPreDefVal(TString("GiniIndex"));
AddPreDefVal(TString("CrossEntropy"));
AddPreDefVal(TString("SDivSqrtSPlusB"));
DeclareOptionRef(fNodeMinEvents, "nEventsMin", "Minimum number of events required in a leaf node (default: max(20, N_train/(Nvar^2)/10) ) ");
DeclareOptionRef(fNCuts, "nCuts", "Number of steps during node cut optimisation");
DeclareOptionRef(fPruneStrength, "PruneStrength", "Pruning strength (negative value: determine automatically)");
DeclareOptionRef(fPruneMethodS, "PruneMethod", "Method used for pruning (removal) of statistically insignificant branches");
DeclareOptionRef(fPruneBeforeBoost=kFALSE, "PruneBeforeBoost", "Flag to prune the tree before applying boosting algorithm");
AddPreDefVal(TString("NoPruning"));
AddPreDefVal(TString("ExpectedError"));
AddPreDefVal(TString("CostComplexity"));
DeclareOptionRef(fNoNegWeightsInTraining,"NoNegWeightsInTraining","Ignore negative event weights in the training process" );
}
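// A minimal booking sketch (not part of this class): the options declared
// above are set through the option string passed to the TMVA Factory when
// booking the method. The job name, output file and option values below
// are illustrative assumptions, not defaults mandated by this file:
//
//    TFile* outputFile = TFile::Open( "TMVA.root", "RECREATE" );
//    TMVA::Factory factory( "TMVAnalysis", outputFile, "" );
//    factory.BookMethod( TMVA::Types::kBDT, "BDT",
//       "NTrees=200:BoostType=AdaBoost:SeparationType=GiniIndex:"
//       "nCuts=20:PruneMethod=CostComplexity:PruneStrength=4.5" );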
void TMVA::MethodBDT::ProcessOptions()
{
MethodBase::ProcessOptions();
fSepTypeS.ToLower();
if (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
else if (fSepTypeS == "giniindex") fSepType = new GiniIndex();
else if (fSepTypeS == "crossentropy") fSepType = new CrossEntropy();
else if (fSepTypeS == "sdivsqrtsplusb") fSepType = new SdivSqrtSplusB();
else {
fLogger << kINFO << GetOptions() << Endl;
fLogger << kFATAL << "<ProcessOptions> unknown Separation Index option: " << fSepTypeS << Endl;
}
fPruneMethodS.ToLower();
if (fPruneMethodS == "expectederror" ) fPruneMethod = DecisionTree::kExpectedErrorPruning;
else if (fPruneMethodS == "costcomplexity" ) fPruneMethod = DecisionTree::kCostComplexityPruning;
else if (fPruneMethodS == "nopruning" ) fPruneMethod = DecisionTree::kNoPruning;
else {
fLogger << kINFO << GetOptions() << Endl;
fLogger << kFATAL << "<ProcessOptions> unknown PruneMethod option: " << fPruneMethodS << Endl;
}
if (fPruneStrength < 0) fAutomatic = kTRUE;
else fAutomatic = kFALSE;
if (this->Data().HasNegativeEventWeights()){
fLogger << kINFO << "You are using a Monte Carlo sample that contains events with negative weights. "
<< "In principle this is fine, as long as the weighted averages remain positive. "
<< "Make sure that the minimal number of (unweighted) events required in a tree node "
<< "(currently nEventsMin=" << fNodeMinEvents << ", settable via the BDT option string "
<< "when booking the classifier) is large enough to allow for reasonable averaging. "
<< "If this does not help, consider the option NoNegWeightsInTraining, "
<< "which ignores events with negative weights in the training." << Endl
<< "Note: you will get a WARNING message during the training if this ever becomes a problem." << Endl;
}
if (fRandomisedTrees){
fLogger << kINFO << " Randomised trees use *bagging* as *boost* method and no pruning" << Endl;
fPruneMethod = DecisionTree::kNoPruning;
fBoostType = "Bagging";
}
}
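// Note on the settings above: a negative PruneStrength (e.g.
// "PruneStrength=-1" in the option string) sets fAutomatic, i.e. the
// prune strength is determined automatically per tree; this requires the
// validation sub-sample that InitEventSample() splits off.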
void TMVA::MethodBDT::InitBDT( void )
{
SetMethodName( "BDT" );
SetMethodType( Types::kBDT );
SetTestvarName();
fNTrees = 200;
fBoostType = "AdaBoost";
fNodeMinEvents = TMath::Max( 20, int( this->Data().GetNEvtTrain() / this->GetNvar()/ this->GetNvar() / 10) );
fNCuts = 20;
fPruneMethodS = "CostComplexity";
fPruneMethod = DecisionTree::kCostComplexityPruning;
fPruneStrength = 5;
fDeltaPruneStrength=0.1;
fNoNegWeightsInTraining=kFALSE;
fRandomisedTrees= kFALSE;
fUseNvars = GetNvar();
SetSignalReferenceCut( 0 );
}
TMVA::MethodBDT::~MethodBDT( void )
{
for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
for (UInt_t i=0; i<fValidationSample.size(); i++) delete fValidationSample[i];
for (UInt_t i=0; i<fForest.size(); i++) delete fForest[i];
}
void TMVA::MethodBDT::InitEventSample( void )
{
if (!HasTrainingTree()) fLogger << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;
Int_t nevents = Data().GetNEvtTrain();
Int_t ievt=0;
Bool_t first=kTRUE;
for (; ievt<nevents; ievt++) {
ReadTrainingEvent(ievt);
Event* event = new Event( GetEvent() );
if ( ! (fNoNegWeightsInTraining && event->GetWeight() < 0 ) ) {
if (first){
first = kFALSE;
fLogger << kINFO << "Negative event weights are " << (fNoNegWeightsInTraining ? "ignored" : "kept") << " during the BDT training (option NoNegWeightsInTraining=" << fNoNegWeightsInTraining << ")" << Endl;
}
if ( ievt%2 == 0 || !fAutomatic ) fEventSample.push_back( event );
else fValidationSample.push_back( event );
}
}
fLogger << kINFO << "<InitEventSample> Internally using " << fEventSample.size()
<< " events for training and " << fValidationSample.size()
<< " for validation" << Endl;
}
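// The split above is deterministic: when automatic pruning is enabled
// (fAutomatic), even-numbered training events fill fEventSample and
// odd-numbered events fill fValidationSample; otherwise all accepted
// events are used for training.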
void TMVA::MethodBDT::Train( void )
{
if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
if (IsNormalised()) fLogger << kFATAL << "\"Normalise\" option cannot be used with BDT; "
<< "please remove the option from the configuration string, or "
<< "use \"!Normalise\""
<< Endl;
fLogger << kINFO << "Training "<< fNTrees << " Decision Trees ... patience please" << Endl;
Timer timer( fNTrees, GetName() );
Int_t nNodesBeforePruningCount = 0;
Int_t nNodesAfterPruningCount = 0;
Int_t nNodesBeforePruning = 0;
Int_t nNodesAfterPruning = 0;
SeparationBase *qualitySepType = new GiniIndex();
TH1D *alpha = new TH1D("alpha","PruneStrengths",fNTrees,0,fNTrees);
alpha->SetXTitle("#tree");
alpha->SetYTitle("PruneStrength");
for (int itree=0; itree<fNTrees; itree++) {
timer.DrawProgressBar( itree );
fForest.push_back( new DecisionTree( fSepType, fNodeMinEvents, fNCuts, qualitySepType,
fRandomisedTrees, fUseNvars, itree));
fForest.back()->SetNodePurityLimit(fNodePurityLimit);
nNodesBeforePruning = fForest.back()->BuildTree(fEventSample);
nNodesBeforePruningCount += nNodesBeforePruning;
fNodesBeforePruningVsTree->SetBinContent(itree+1,nNodesBeforePruning);
if(!fPruneBeforeBoost || fPruneMethod == DecisionTree::kNoPruning)
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
if(fPruneMethod != DecisionTree::kNoPruning) {
fForest.back()->SetPruneMethod(fPruneMethod);
if(!fAutomatic) {
fForest.back()->SetPruneStrength(fPruneStrength);
fForest.back()->PruneTree();
}
else {
if(fPruneMethod == DecisionTree::kCostComplexityPruning) {
CCPruner* pruneTool = new CCPruner(fForest.back(), &fValidationSample, fSepType);
pruneTool->Optimize();
std::vector<DecisionTreeNode*> nodes = pruneTool->GetOptimalPruneSequence();
fPruneStrength = pruneTool->GetOptimalPruneStrength();
for(UInt_t i = 0; i < nodes.size(); i++)
fForest.back()->PruneNode(nodes[i]);
delete pruneTool;
}
else {
fPruneStrength = this->PruneTree(fForest.back(), itree);
}
}
if (fUseYesNoLeaf){
fForest.back()->CleanTree();
}
nNodesAfterPruning = fForest.back()->GetNNodes();
nNodesAfterPruningCount += nNodesAfterPruning;
fNodesAfterPruningVsTree->SetBinContent(itree+1,nNodesAfterPruning);
if(fPruneBeforeBoost)
fBoostWeights.push_back( this->Boost(fEventSample, fForest.back(), itree) );
alpha->SetBinContent(itree+1,fPruneStrength);
}
fITree = itree;
fMonitorNtuple->Fill();
}
alpha->Write();
fLogger << kINFO << "<Train> elapsed time: " << timer.GetElapsedTime()
<< " " << Endl;
if (fPruneMethod == DecisionTree::kNoPruning) {
fLogger << kINFO << "<Train> average number of nodes (w/o pruning) : "
<< nNodesBeforePruningCount/fNTrees << Endl;
}
else {
fLogger << kINFO << "<Train> average number of nodes before/after pruning : "
<< nNodesBeforePruningCount/fNTrees << " / "
<< nNodesAfterPruningCount/fNTrees
<< Endl;
}
}
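// Ordering inside the training loop above: each tree is built on
// fEventSample; with PruneBeforeBoost=kFALSE (the default) the boost
// weights are computed from the unpruned tree and the tree is pruned
// afterwards, whereas with PruneBeforeBoost=kTRUE the tree is pruned
// first and the boost is computed from the pruned tree.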
Double_t TMVA::MethodBDT::PruneTree( DecisionTree *dt, Int_t itree)
{
Double_t alpha = 0;
Double_t delta = fDeltaPruneStrength;
DecisionTree* dcopy;
vector<Double_t> q;
multimap<Double_t,Double_t> quality;
Int_t nnodes = dt->GetNNodes();
Bool_t forceStop = kFALSE;
Int_t troubleCount = 0, previousNnodes = nnodes;
while (nnodes > 3 && !forceStop) {
dcopy = new DecisionTree(*dt);
dcopy->SetPruneStrength(alpha+=delta);
dcopy->PruneTree();
q.push_back(this->TestTreeQuality(dcopy));
quality.insert(pair<const Double_t,Double_t>(q.back(),alpha));
nnodes = dcopy->GetNNodes();
if (previousNnodes == nnodes) troubleCount++;
else {
troubleCount=0;
if (nnodes < previousNnodes / 2 ) fDeltaPruneStrength /= 2.;
}
previousNnodes = nnodes;
if (troubleCount > 20) {
if (itree == 0 && fPruneStrength <=0) {
fDeltaPruneStrength *= 5;
fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
<< " for Tree " << itree
<< " --> first try to increase the step size"
<< " currently Prunestrenght= " << alpha
<< " stepsize " << fDeltaPruneStrength << " " << Endl;
troubleCount = 0;
fPruneStrength = 1;
}
else if (itree == 0 && fPruneStrength <=2) {
fDeltaPruneStrength *= 5;
fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
<< " for Tree " << itree
<< " --> try to increase the step size even more.. "
<< " if that stitill didn't work, TRY IT BY HAND"
<< " currently Prunestrenght= " << alpha
<< " stepsize " << fDeltaPruneStrength << " " << Endl;
troubleCount = 0;
fPruneStrength = 3;
}
else{
forceStop=kTRUE;
fLogger << kINFO << "<PruneTree> trouble determining optimal prune strength"
<< " for Tree " << itree << " at tested prune strength: " << alpha
<< " --> abort forced, use same strength as for previous tree:"
<< fPruneStrength << Endl;
}
}
if (fgDebugLevel==1) fLogger << kINFO << "Pruned with alpha=" << alpha
<< ": quality " << q.back()
<< ", #nodes " << nnodes
<< Endl;
delete dcopy;
}
if (!forceStop) {
multimap<Double_t,Double_t>::reverse_iterator it = quality.rbegin();   // entry with the highest validation quality
fPruneStrength = it->second;
fDeltaPruneStrength *= Double_t(q.size())/20.;
}
char buffer[20];   // large enough for "quad" plus any Int_t
sprintf (buffer,"quad%d",itree);
TH1D *qual=new TH1D(buffer,"Quality of tree prune steps",q.size(),0.,alpha);
qual->SetXTitle("PruneStrength");
qual->SetYTitle("TreeQuality (Purity)");
for (UInt_t i=0; i< q.size(); i++) {
qual->SetBinContent(i+1,q[i]);
}
qual->Write();
dt->SetPruneStrength(fPruneStrength);
dt->PruneTree();
return fPruneStrength;
}
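// PruneTree above scans the prune strength: starting from alpha=0 it
// repeatedly increases alpha by fDeltaPruneStrength, prunes a copy of
// the tree at that strength, and records the resulting quality on the
// validation sample; the alpha with the best validation quality is then
// applied to the original tree. The step size is adapted on the fly, and
// the scan aborts if the node count stalls for more than 20 steps.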
Double_t TMVA::MethodBDT::TestTreeQuality( DecisionTree *dt )
{
Double_t ncorrect=0, nfalse=0;
for (UInt_t ievt=0; ievt<fValidationSample.size(); ievt++) {
Bool_t isSignalType = (dt->CheckEvent(*(fValidationSample[ievt])) > fNodePurityLimit);
if (isSignalType == (fValidationSample[ievt]->IsSignal()) ) {
ncorrect += fValidationSample[ievt]->GetWeight();
}
else{
nfalse += fValidationSample[ievt]->GetWeight();
}
}
return ncorrect / (ncorrect + nfalse);
}
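// In formulas: with validation events i of weight w_i,
//
//    quality = sum_{i correct} w_i / sum_i w_i ,
//
// i.e. the weighted fraction of correctly classified validation events,
// where an event counts as signal-like if its tree response exceeds
// fNodePurityLimit.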
Double_t TMVA::MethodBDT::Boost( vector<TMVA::Event*> eventSample, DecisionTree *dt, Int_t iTree )
{
if (fBoostType=="AdaBoost") return this->AdaBoost(eventSample, dt);
else if (fBoostType=="Bagging") return this->Bagging(eventSample, iTree);
else {
fLogger << kINFO << GetOptions() << Endl;
fLogger << kFATAL << "<Boost> unknown boost option called" << Endl;
}
return -1;
}
Double_t TMVA::MethodBDT::AdaBoost( vector<TMVA::Event*> eventSample, DecisionTree *dt )
{
Double_t err=0, sumw=0, sumwfalse=0;
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
Bool_t isSignalType = (dt->CheckEvent(*(*e),fUseYesNoLeaf) > fNodePurityLimit );
Double_t w = (*e)->GetWeight();
sumw += w;
if (!(isSignalType == (*e)->IsSignal())) {
sumwfalse+= w;
}
}
err = sumwfalse/sumw;
Double_t newSumw=0;
Int_t i=0;
Double_t boostWeight=1.;
if (err > 0.5) {
fLogger << kWARNING << "The error rate in the BDT boosting is > 0.5. "
<< "This should not happen (please check the BDT code); "
<< "it is set to 0.5 in order to continue." << Endl;
err = 0.5;
} else if (err < 0) {
fLogger << kWARNING << "The error rate in the BDT boosting is < 0. This can happen "
<< "due to improper treatment of negative event weights in a Monte Carlo "
<< "(if you have an idea how to handle this better, please contact Helge.Voss@cern.ch); "
<< "for the time being its absolute value is used in order to continue." << Endl;
err = TMath::Abs(err);
}
if (fAdaBoostBeta == 1) {
boostWeight = (1-err)/err;
}
else {
boostWeight = TMath::Power((1.0 - err)/err, fAdaBoostBeta);
}
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
if (!( (dt->CheckEvent(*(*e),fUseYesNoLeaf) > fNodePurityLimit ) == (*e)->IsSignal())) {
if ( (*e)->GetWeight() > 0 ){
(*e)->SetBoostWeight( (*e)->GetBoostWeight() * boostWeight);
} else {
(*e)->SetBoostWeight( (*e)->GetBoostWeight() / boostWeight);
}
}
newSumw+=(*e)->GetWeight();
i++;
}
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
(*e)->SetBoostWeight( (*e)->GetBoostWeight() * sumw / newSumw );
}
fBoostWeightHist->Fill(boostWeight);
fBoostWeightVsTree->SetBinContent(fForest.size(),boostWeight);
fErrFractHist->SetBinContent(fForest.size(),err);
fBoostWeight = boostWeight;
fErrorFraction = err;
return TMath::Log(boostWeight);
}
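// AdaBoost in formulas, as implemented above: with the weighted
// misclassification fraction
//
//    err = sum_{misclassified i} w_i / sum_i w_i ,
//
// the boost weight is
//
//    boostWeight = ((1-err)/err)^beta ,   beta = AdaBoostBeta ,
//
// misclassified events are re-weighted as w_i -> w_i * boostWeight
// (w_i -> w_i / boostWeight for negative weights), all weights are then
// renormalised so that their sum is unchanged, and the tree enters the
// forest vote with weight ln(boostWeight).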
Double_t TMVA::MethodBDT::Bagging( vector<TMVA::Event*> eventSample, Int_t iTree )
{
Double_t newSumw=0;
Double_t newWeight;
TRandom2 trandom( iTree );   // local generator, seeded with the tree index for reproducible per-tree resampling
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
newWeight = trandom.PoissonD(1);   // Poisson(1) weight mimics sampling with replacement
(*e)->SetBoostWeight(newWeight);
newSumw+=(*e)->GetBoostWeight();
}
for (vector<TMVA::Event*>::iterator e=eventSample.begin(); e!=eventSample.end();e++) {
(*e)->SetBoostWeight( (*e)->GetBoostWeight() * eventSample.size() / newSumw );
}
return 1.;
}
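// Bagging in formulas: each event receives a boost weight drawn from a
// Poisson distribution with mean 1, mimicking a resampling of the
// training sample with replacement; the weights are then rescaled so
// that their sum equals the number of events. Each tree enters the
// forest vote with equal weight (the method returns 1).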
void TMVA::MethodBDT::WriteWeightsToStream( ostream& o) const
{
o << "NTrees= " << fForest.size() <<endl;
for (UInt_t i=0; i< fForest.size(); i++) {
o << "Tree " << i << " boostWeight " << fBoostWeights[i] << endl;
(fForest[i])->Print(o);
}
}
void TMVA::MethodBDT::ReadWeightsFromStream( istream& istr )
{
TString var, dummy;
istr >> dummy >> fNTrees;
fLogger << kINFO << "Read " << fNTrees << " Decision trees" << Endl;
for (UInt_t i=0;i<fForest.size();i++) delete fForest[i];
fForest.clear();
fBoostWeights.clear();
Int_t iTree;
Double_t boostWeight;
for (int i=0;i<fNTrees;i++) {
istr >> dummy >> iTree >> dummy >> boostWeight;
if (iTree != i) {
if (!fForest.empty()) fForest.back()->Print( cout );
fLogger << kFATAL << "Error while reading weight file; mismatch Itree="
<< iTree << " i=" << i
<< " dummy " << dummy
<< " boostweight " << boostWeight
<< Endl;
}
fForest.push_back( new DecisionTree() );
fForest.back()->Read(istr);
fBoostWeights.push_back(boostWeight);
}
}
Double_t TMVA::MethodBDT::GetMvaValue()
{
Double_t myMVA = 0;
Double_t norm = 0;
for (UInt_t itree=0; itree<fForest.size(); itree++) {
if (fUseWeightedTrees) {
myMVA += fBoostWeights[itree] * fForest[itree]->CheckEvent(GetEvent(),fUseYesNoLeaf);
norm += fBoostWeights[itree];
}
else {
myMVA += fForest[itree]->CheckEvent(GetEvent(),fUseYesNoLeaf);
norm += 1;
}
}
return myMVA / norm;
}
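// The discriminant computed above is
//
//    y(x) = sum_t a_t * h_t(x) / sum_t a_t ,
//
// with a_t = fBoostWeights[t] if UseWeightedTrees (a_t = 1 otherwise)
// and h_t(x) the response of tree t: the leaf type (+1 signal, -1
// background) or the leaf purity, depending on UseYesNoLeaf.
//
// A minimal application sketch via TMVA::Reader; the variable name and
// the weight-file path are illustrative assumptions:
//
//    TMVA::Reader reader;
//    Float_t var1;                          // must match a training variable
//    reader.AddVariable( "var1", &var1 );
//    reader.BookMVA( "BDT", "weights/TMVAnalysis_BDT.weights.txt" );
//    var1 = 0.5;                            // fill values for the current event
//    Double_t mvaValue = reader.EvaluateMVA( "BDT" );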
void TMVA::MethodBDT::WriteMonitoringHistosToFile( void ) const
{
fLogger << kINFO << "Write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
fBoostWeightHist->Write();
fBoostWeightVsTree->Write();
fErrFractHist->Write();
fNodesBeforePruningVsTree->Write();
fNodesAfterPruningVsTree->Write();
fMonitorNtuple->Write();
}
vector< Double_t > TMVA::MethodBDT::GetVariableImportance()
{
fVariableImportance.assign(GetNvar(), 0.0);   // reset, in case of repeated calls
Double_t sum=0;
for (int itree = 0; itree < fNTrees; itree++) {
vector<Double_t> relativeImportance(fForest[itree]->GetVariableImportance());
for (UInt_t i=0; i< relativeImportance.size(); i++) {
fVariableImportance[i] += relativeImportance[i];
}
}
for (UInt_t i=0; i< fVariableImportance.size(); i++) sum += fVariableImportance[i];
for (UInt_t i=0; i< fVariableImportance.size(); i++) fVariableImportance[i] /= sum;
return fVariableImportance;
}
Double_t TMVA::MethodBDT::GetVariableImportance( UInt_t ivar )
{
vector<Double_t> relativeImportance = this->GetVariableImportance();
if (ivar < (UInt_t)relativeImportance.size()) return relativeImportance[ivar];
else fLogger << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
return -1;
}
const TMVA::Ranking* TMVA::MethodBDT::CreateRanking()
{
fRanking = new Ranking( GetName(), "Variable Importance" );
vector< Double_t> importance(this->GetVariableImportance());
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
fRanking->AddRank( *new Rank( GetInputExp(ivar), importance[ivar] ) );
}
return fRanking;
}
void TMVA::MethodBDT::GetHelpMessage() const
{
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "Boosted Decision Trees are a collection of individual decision" << Endl;
fLogger << "trees which form a multivariate classifier by (weighted) majority " << Endl;
fLogger << "vote of the individual trees. Consecutive decision trees are " << Endl;
fLogger << "trained using the original training data set with re-weighted " << Endl;
fLogger << "events. By default, the AdaBoost method is employed, which gives " << Endl;
fLogger << "events that were misclassified in the previous tree a larger " << Endl;
fLogger << "weight in the training of the following tree." << Endl;
fLogger << Endl;
fLogger << "Decision trees are a sequence of binary splits of the data sample" << Endl;
fLogger << "using a single descriminant variable at a time. A test event " << Endl;
fLogger << "ending up after the sequence of left-right splits in a final " << Endl;
fLogger << "(\"leaf\") node is classified as either signal or background" << Endl;
fLogger << "depending on the majority type of training events in that node." << Endl;
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "By the nature of the binary splits performed on the individual" << Endl;
fLogger << "variables, decision trees do not deal well with linear correlations" << Endl;
fLogger << "between variables (they need to approximate the linear split in" << Endl;
fLogger << "the two dimensional space by a sequence of splits on the two " << Endl;
fLogger << "variables individually). Hence decorrelation could be useful " << Endl;
fLogger << "to optimise the BDT performance." << Endl;
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "The two most important parameters in the configuration are the " << Endl;
fLogger << "minimal number of events requested by a leaf node (option " << Endl;
fLogger << "\"nEventsMin\"). If this number is too large, detailed features " << Endl;
fLogger << "in the parameter space cannot be modeled. If it is too small, " << Endl;
fLogger << "the risk to overtain rises." << Endl;
fLogger << " (Imagine the decision tree is split until the leaf node contains" << Endl;
fLogger << " only a single event. In such a case, no training event is " << Endl;
fLogger << " misclassified, while the situation will look very different" << Endl;
fLogger << " for the test sample.)" << Endl;
fLogger << Endl;
fLogger << "The default minumal number is currently set to " << Endl;
fLogger << " max(20, (N_training_events / N_variables^2 / 10) " << Endl;
fLogger << "and can be changed by the user." << Endl;
fLogger << Endl;
fLogger << "The other crucial paramter, the pruning strength (\"PruneStrength\")," << Endl;
fLogger << "is also related to overtraining. It is a regularistion parameter " << Endl;
fLogger << "that is used when determining after the training which splits " << Endl;
fLogger << "are considered statistically insignificant and are removed. The" << Endl;
fLogger << "user is advised to carefully watch the BDT screen output for" << Endl;
fLogger << "the comparison between efficiencies obtained on the training and" << Endl;
fLogger << "the independent test sample. They should be equal within statistical" << Endl;
fLogger << "errors." << Endl;
}
void TMVA::MethodBDT::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
fout << " std::vector<BDT_DecisionTreeNode*> fForest; // i.e. root nodes of decision trees" << endl;
fout << " std::vector<double> fBoostWeights; // the weights applied in the individual boosts" << endl;
fout << "};" << endl << endl;
fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " double myMVA = 0;" << endl;
fout << " double norm = 0;" << endl;
fout << " for (unsigned int itree=0; itree<fForest.size(); itree++){" << endl;
fout << " BDT_DecisionTreeNode *current = fForest[itree];" << endl;
fout << " while (current->GetNodeType() == 0) { //intermediate node" << endl;
fout << " if (current->GoesRight(inputValues)) current=(BDT_DecisionTreeNode*)current->GetRight();" << endl;
fout << " else current=(BDT_DecisionTreeNode*)current->GetLeft();" << endl;
fout << " }" << endl;
if (fUseWeightedTrees) {
if (fUseYesNoLeaf) fout << " myMVA += fBoostWeights[itree] * current->GetNodeType();" << endl;
else fout << " myMVA += fBoostWeights[itree] * current->GetPurity();" << endl;
fout << " norm += fBoostWeights[itree];" << endl;
}
else {
if (fUseYesNoLeaf) fout << " myMVA += current->GetNodeType();" << endl;
else fout << " myMVA += current->GetPurity();" << endl;
fout << " norm += 1.;" << endl;
}
fout << " }" << endl;
fout << " return myMVA /= norm;" << endl;
fout << "};" << endl << endl;
fout << "void " << className << "::Initialize()" << endl;
fout << "{" << endl;
for (int itree=0; itree<fNTrees; itree++) {
fout << " // itree = " << itree << endl;
fout << " fBoostWeights.push_back(" << fBoostWeights[itree] << ");" << endl;
fout << " fForest.push_back( " << endl;
this->MakeClassInstantiateNode((DecisionTreeNode*)fForest[itree]->GetRoot(), fout, className);
fout <<" );" << endl;
}
fout << " return;" << endl;
fout << "};" << endl;
fout << " " << endl;
fout << "// Clean up" << endl;
fout << "inline void " << className << "::Clear() " << endl;
fout << "{" << endl;
fout << " for (unsigned int itree=0; itree<fForest.size(); itree++) { " << endl;
fout << " delete fForest[itree]; " << endl;
fout << " }" << endl;
fout << "}" << endl;
}
void TMVA::MethodBDT::MakeClassSpecificHeader( std::ostream& fout, const TString& ) const
{
fout << "#ifndef NN" << endl;
fout << "#define NN new BDT_DecisionTreeNode" << endl;
fout << "#endif" << endl;
fout << " " << endl;
fout << "#ifndef BDT_DecisionTreeNode__def" << endl;
fout << "#define BDT_DecisionTreeNode__def" << endl;
fout << " " << endl;
fout << "class BDT_DecisionTreeNode {" << endl;
fout << " " << endl;
fout << "public:" << endl;
fout << " " << endl;
fout << " // constructor of an essentially \"empty\" node floating in space" << endl;
fout << " BDT_DecisionTreeNode ( BDT_DecisionTreeNode* left," << endl;
fout << " BDT_DecisionTreeNode* right," << endl;
fout << " double cutValue, bool cutType, int selector," << endl;
fout << " int nodeType, double purity ) :" << endl;
fout << " fLeft ( left )," << endl;
fout << " fRight ( right )," << endl;
fout << " fCutValue( cutValue )," << endl;
fout << " fCutType ( cutType )," << endl;
fout << " fSelector( selector )," << endl;
fout << " fNodeType( nodeType )," << endl;
fout << " fPurity ( purity ) {}" << endl << endl;
fout << " virtual ~BDT_DecisionTreeNode();" << endl << endl;
fout << " // test event if it decends the tree at this node to the right" << endl;
fout << " virtual bool GoesRight( const std::vector<double>& inputValues ) const;" << endl;
fout << " BDT_DecisionTreeNode* GetRight( void ) {return fRight; };" << endl << endl;
fout << " // test event if it decends the tree at this node to the left " << endl;
fout << " virtual bool GoesLeft ( const std::vector<double>& inputValues ) const;" << endl;
fout << " BDT_DecisionTreeNode* GetLeft( void ) { return fLeft; }; " << endl << endl;
fout << " // return S/(S+B) (purity) at this node (from training)" << endl << endl;
fout << " double GetPurity( void ) const { return fPurity; } " << endl;
fout << " // return the node type" << endl;
fout << " int GetNodeType( void ) const { return fNodeType; }" << endl << endl;
fout << "private:" << endl << endl;
fout << " BDT_DecisionTreeNode* fLeft; // pointer to the left daughter node" << endl;
fout << " BDT_DecisionTreeNode* fRight; // pointer to the right daughter node" << endl;
fout << " double fCutValue; // cut value appplied on this node to discriminate bkg against sig" << endl;
fout << " bool fCutType; // true: if event variable > cutValue ==> signal , false otherwise" << endl;
fout << " int fSelector; // index of variable used in node selection (decision tree) " << endl;
fout << " int fNodeType; // Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal " << endl;
fout << " double fPurity; // Purity of node from training"<< endl;
fout << "}; " << endl;
fout << " " << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "BDT_DecisionTreeNode::~BDT_DecisionTreeNode()" << endl;
fout << "{" << endl;
fout << " if (fLeft != NULL) delete fLeft;" << endl;
fout << " if (fRight != NULL) delete fRight;" << endl;
fout << "}; " << endl;
fout << " " << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "bool BDT_DecisionTreeNode::GoesRight( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " // test event if it decends the tree at this node to the right" << endl;
fout << " bool result = (inputValues[fSelector] > fCutValue );" << endl;
fout << " if (fCutType == true) return result; //the cuts are selecting Signal ;" << endl;
fout << " else return !result;" << endl;
fout << "}" << endl;
fout << " " << endl;
fout << "//_______________________________________________________________________" << endl;
fout << "bool BDT_DecisionTreeNode::GoesLeft( const std::vector<double>& inputValues ) const" << endl;
fout << "{" << endl;
fout << " // test event if it decends the tree at this node to the left" << endl;
fout << " if (!this->GoesRight(inputValues)) return true;" << endl;
fout << " else return false;" << endl;
fout << "}" << endl;
fout << " " << endl;
fout << "#endif" << endl;
fout << " " << endl;
}
void TMVA::MethodBDT::MakeClassInstantiateNode( DecisionTreeNode *n, std::ostream& fout, const TString& className ) const
{
if (n == NULL) {
fLogger << kFATAL << "MakeClassInstantiateNode: started with undefined node" <<Endl;
return ;
}
fout << "NN("<<endl;
if (n->GetLeft() != NULL){
this->MakeClassInstantiateNode( (DecisionTreeNode*)n->GetLeft() , fout, className);
}
else {
fout << "0";
}
fout << ", " <<endl;
if (n->GetRight() != NULL){
this->MakeClassInstantiateNode( (DecisionTreeNode*)n->GetRight(), fout, className );
}
else {
fout << "0";
}
fout << ", " << endl
<< setprecision(6)
<< n->GetCutValue() << ", "
<< n->GetCutType() << ", "
<< n->GetSelector() << ", "
<< n->GetNodeType() << ", "
<< n->GetPurity() << ") ";
}