#include <iostream>
#include <algorithm>
#include <vector>
#include "TMath.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
#include "TMVA/BDTEventWrapper.h"
#include "TMVA/CCPruner.h"
using std::vector;
ClassImp(TMVA::DecisionTree)
////////////////////////////////////////////////////////////////////////////////
/// Default constructor. Produces an empty tree: no separation criterion, no
/// quality index, no variables. The private random generator is created with
/// a fixed seed of 0.
TMVA::DecisionTree::DecisionTree( void )
   : BinaryTree(),
     fNvars          (0),
     fNCuts          (-1),
     fSepType        (NULL),
     fMinSize        (0),
     fPruneMethod    (kCostComplexityPruning),
     fNodePurityLimit(0.5),
     fRandomisedTree (kFALSE),
     fUseNvars       (0),
     fMyTrandom      (new TRandom2(0)),   // construct directly instead of NULL + assignment
     fQualityIndex   (NULL)
{
   fLogger.SetSource( "DecisionTree" );
}
////////////////////////////////////////////////////////////////////////////////
/// Constructor that adopts an existing node (and its daughters) as the root
/// of this tree and fixes up the parent-tree pointers in all nodes.
/// Note: fMyTrandom stays NULL here, as in the default-constructed state.
TMVA::DecisionTree::DecisionTree( DecisionTreeNode* n )
   : BinaryTree(),
     fNvars          (0),
     fNCuts          (-1),
     fSepType        (NULL),
     fMinSize        (0),
     fPruneMethod    (kCostComplexityPruning),
     fNodePurityLimit(0.5),
     fRandomisedTree (kFALSE),
     fUseNvars       (0),
     fMyTrandom      (NULL),
     fQualityIndex   (NULL)
{
   // FIX: fLogger.SetSource("DecisionTree") was called twice in this
   // constructor (before and after adopting the root); once is enough.
   fLogger.SetSource( "DecisionTree" );
   this->SetRoot( n );
   this->SetParentTreeInNodes();
}
////////////////////////////////////////////////////////////////////////////////
/// Main constructor used for training.
/// @param sepType        separation criterion used to find the node splits
/// @param minSize        minimum number of events required in a node
/// @param nCuts          number of grid points scanned per variable (fast training)
/// @param qtype          separation criterion used as quality index
/// @param randomisedTree if true, only a random subset of variables is used per node
/// @param useNvars       size of that random subset (0 = choose automatically)
/// @param iSeed          seed for the private random generator
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Int_t minSize,
                                  Int_t nCuts, TMVA::SeparationBase *qtype,
                                  Bool_t randomisedTree, Int_t useNvars, Int_t iSeed )
   : BinaryTree(),
     fNvars          (0),
     fNCuts          (nCuts),
     fSepType        (sepType),
     fMinSize        (minSize),
     fPruneMethod    (kCostComplexityPruning),
     fNodePurityLimit(0.5),
     fRandomisedTree (randomisedTree),
     fUseNvars       (useNvars),
     fMyTrandom      (new TRandom2(iSeed)),   // construct directly instead of NULL + assignment
     fQualityIndex   (qtype)
{
   fLogger.SetSource( "DecisionTree" );
}
////////////////////////////////////////////////////////////////////////////////
/// Copy constructor: performs a deep copy of the node structure starting at
/// the root, and copies the training configuration of `d`.
TMVA::DecisionTree::DecisionTree( const DecisionTree &d )
   : BinaryTree(),
     fNvars          (d.fNvars),
     fNCuts          (d.fNCuts),
     fSepType        (d.fSepType),
     fMinSize        (d.fMinSize),
     fPruneMethod    (d.fPruneMethod),
     fNodePurityLimit(d.fNodePurityLimit),   // BUG FIX: was hard-coded to 0.5 instead of copying d's value
     fRandomisedTree (d.fRandomisedTree),
     fUseNvars       (d.fUseNvars),
     fMyTrandom      (NULL),
     fQualityIndex   (d.fQualityIndex)
{
   this->SetRoot( new DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
   this->SetParentTreeInNodes();
   fNNodes = d.fNNodes;
   // BUG FIX: the copy must carry the accumulated variable importance along
   // with fNvars, otherwise GetVariableImportance() on a copy indexes an
   // empty vector.
   fVariableImportance = d.fVariableImportance;
   // BUG FIX: a copied randomised tree would dereference a NULL fMyTrandom
   // in TrainNodeFast(); give the copy its own generator.
   fMyTrandom = new TRandom2(0);
   fLogger.SetSource( "DecisionTree" );
}
////////////////////////////////////////////////////////////////////////////////
/// Destructor: releases the random generator owned by this tree.
TMVA::DecisionTree::~DecisionTree( void )
{
   delete fMyTrandom;   // deleting a null pointer is a well-defined no-op
}
////////////////////////////////////////////////////////////////////////////////
/// Recursively set the "parent tree" pointer in `n` and all its descendants,
/// and keep the recorded total tree depth up to date. Called without an
/// argument it starts at the root.
void TMVA::DecisionTree::SetParentTreeInNodes( DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" <<Endl;
         return;
      }
   }
   DecisionTreeNode *lDau = this->GetLeftDaughter(n);
   DecisionTreeNode *rDau = this->GetRightDaughter(n);
   // a binary-tree node must have either zero or two daughters
   if ((lDau == NULL) != (rDau == NULL)) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   if (lDau != NULL) this->SetParentTreeInNodes( lDau );
   if (rDau != NULL) this->SetParentTreeInNodes( rDau );
   n->SetParentTree(this);
   if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
}
////////////////////////////////////////////////////////////////////////////////
/// Recursively build the decision tree from the given training sample.
/// On the first call (node == NULL) the root node is created; the sample is
/// then split with the best cut found by TrainNodeFast/TrainNodeFull and the
/// two sub-samples are passed to recursive BuildTree calls until a node is
/// either pure enough (no positive separation gain) or too small (< 2*fMinSize).
/// @param eventSample  training events reaching this node
/// @param node         node to fill/split (NULL = create the root)
/// @return             total number of nodes in the tree so far (fNNodes)
Int_t TMVA::DecisionTree::BuildTree( vector<TMVA::Event*> & eventSample,
                                     TMVA::DecisionTreeNode *node)
{
   // first call: create and register the root node
   if (node==NULL) {
      node = new TMVA::DecisionTreeNode();
      fNNodes = 1;
      this->SetRoot(node);
      this->GetRoot()->SetPos('s');
      this->GetRoot()->SetDepth(0);
      this->GetRoot()->SetParentTree(this);
   }
   UInt_t nevents = eventSample.size();
   if (nevents > 0 ) {
      // number of input variables is taken from the first event
      fNvars = eventSample[0]->GetNVars();
      fVariableImportance.resize(fNvars);
   }
   else fLogger << kFATAL << ":<BuildTree> eventsample Size == 0 " << Endl;
   // weighted (s, b) and unweighted (suw, buw) signal/background counts
   Double_t s=0, b=0;
   Double_t suw=0, buw=0;
   for (UInt_t i=0; i<eventSample.size(); i++){
      if (eventSample[i]->IsSignal()){
         s += eventSample[i]->GetWeight();
         suw += 1;
      }
      else {
         b += eventSample[i]->GetWeight();
         buw += 1;
      }
   }
   // negative total weight can happen with negatively weighted MC events;
   // warn and dump the background contributions for diagnosis
   if (s+b < 0){
      fLogger << kWARNING << " One of the Decision Tree nodes has negative total number of signal or background events. "
              << "(Nsig="<<s<<" Nbkg="<<b<<" Probaby you use a Monte Carlo with negative weights. That should in principle "
              << "be fine as long as on average you end up with something positive. For this you have to make sure that the "
              << "minimul number of (unweighted) events demanded for a tree node (currently you use: nEventsMin="<<fMinSize
              << ", you can set this via the BDT option string when booking the classifier) is large enough to allow for "
              << "reasonable averaging!!!" << Endl
              << " If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events "
              << "with negative weight in the training." << Endl;
      double nBkg=0.;
      for (UInt_t i=0; i<eventSample.size(); i++){
         if (! eventSample[i]->IsSignal()){
            nBkg += eventSample[i]->GetWeight();
            // NOTE(review): unqualified cout/endl -- relies on a using
            // declaration not visible in this file; confirm it compiles
            cout << "Event "<< i<< " has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
                 << " boostWeight: " << eventSample[i]->GetBoostWeight() << endl;
         }
      }
      cout << " that gives in total: " << nBkg<<endl;
   }
   // store the node statistics
   node->SetNSigEvents(s);
   node->SetNBkgEvents(b);
   node->SetNSigEvents_unweighted(suw);
   node->SetNBkgEvents_unweighted(buw);
   // NEvents of non-root nodes are set by the parent before the recursive call
   if (node == this->GetRoot()) {
      node->SetNEvents(s+b);
      node->SetNEvents_unweighted(suw+buw);
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
   // only try to split if the node could yield two daughters of >= fMinSize
   if ( eventSample.size() >= 2*fMinSize){
      Double_t separationGain;
      // fNCuts > 0: scan a fixed grid of cuts; otherwise scan all event values
      if(fNCuts > 0) separationGain = this->TrainNodeFast(eventSample, node);
      else separationGain = this->TrainNodeFull(eventSample, node);
      if (separationGain < std::numeric_limits<double>::epsilon()) {
         // no useful split found: make this node a leaf, typed by its purity
         if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
         else node->SetNodeType(-1);
         if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
      }
      else {
         // split the sample according to the cut just set on this node
         vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
         vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
         Double_t nRight=0, nLeft=0;
         for (UInt_t ie=0; ie< nevents ; ie++){
            if (node->GoesRight(*eventSample[ie])){
               rightSample.push_back(eventSample[ie]);
               nRight += eventSample[ie]->GetWeight();
            }
            else {
               leftSample.push_back(eventSample[ie]);
               nLeft += eventSample[ie]->GetWeight();
            }
         }
         // sanity check: a positive separation gain implies both sides are populated
         if (leftSample.size() == 0 || rightSample.size() == 0) {
            fLogger << kFATAL << "<TrainNode> all events went to the same branch" << Endl
                    << "--- Hence new node == old node ... check" << Endl
                    << "--- left:" << leftSample.size()
                    << " right:" << rightSample.size() << Endl
                    << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
                    << Endl;
         }
         // create the two daughters and recurse into them
         TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
         fNNodes++;
         rightNode->SetNEvents(nRight);
         rightNode->SetNEvents_unweighted(rightSample.size());
         TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
         fNNodes++;
         leftNode->SetNEvents(nLeft);
         leftNode->SetNEvents_unweighted(leftSample.size());
         node->SetNodeType(0);   // 0 marks an intermediate (split) node
         node->SetLeft(leftNode);
         node->SetRight(rightNode);
         this->BuildTree(rightSample, rightNode);
         this->BuildTree(leftSample, leftNode );
      }
   }
   else{
      // too few events to split: terminal leaf, typed by purity
      if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
      else node->SetNodeType(-1);
      if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
   }
   return fNNodes;
}
////////////////////////////////////////////////////////////////////////////////
/// Pass every event of the sample down the tree, updating the statistics of
/// each node it traverses (see FillEvent).
void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample )
{
   const UInt_t nEvents = eventSample.size();
   for (UInt_t iev = 0; iev < nEvents; iev++) {
      this->FillEvent( *(eventSample[iev]), NULL );
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Descend the event from `node` (NULL = root) down to a leaf, incrementing
/// the weighted and unweighted event counters and refreshing the separation
/// index of every node on the path.
void TMVA::DecisionTree::FillEvent( TMVA::Event & event,
                                    TMVA::DecisionTreeNode *node )
{
   if (node == NULL) {
      node = (TMVA::DecisionTreeNode*)this->GetRoot();
   }
   node->IncrementNEvents( event.GetWeight() );
   node->IncrementNEvents_unweighted( );
   if (event.IsSignal()){
      node->IncrementNSigEvents( event.GetWeight() );
      node->IncrementNSigEvents_unweighted( );
   }
   else {
      node->IncrementNBkgEvents( event.GetWeight() );
      // BUG FIX: background events must increment the *background* unweighted
      // counter; the previous code incremented the signal counter here.
      node->IncrementNBkgEvents_unweighted( );
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
                                                         node->GetNBkgEvents()));
   // keep descending while on an intermediate node (type 0)
   if (node->GetNodeType() == 0){
      if (node->GoesRight(event))
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetRight())) ;
      else
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetLeft())) ;
   }
}
void TMVA::DecisionTree::ClearTree()
{
if (this->GetRoot()!=NULL)
((DecisionTreeNode*)(this->GetRoot()))->ClearNodeAndAllDaughters();
}
////////////////////////////////////////////////////////////////////////////////
/// Remove splits whose two daughters classify identically: after cleaning
/// the daughters recursively, a node whose left and right leaves carry the
/// same sign of node type is collapsed into a single leaf.
void TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
{
   if (node == NULL) node = (DecisionTreeNode *)this->GetRoot();
   // only intermediate nodes (type 0) carry a split that could be removed
   if (node->GetNodeType() != 0) return;
   DecisionTreeNode *lDau = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *rDau = (DecisionTreeNode*)node->GetRight();
   this->CleanTree(lDau);
   this->CleanTree(rDau);
   // same classification on both sides -> the split is useless
   if (lDau->GetNodeType() * rDau->GetNodeType() > 0) {
      this->PruneNode(node);
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Prune the tree with the configured algorithm (expected-error or
/// cost-complexity pruning), then refresh the node count.
void TMVA::DecisionTree::PruneTree()
{
   switch (fPruneMethod) {
   case kExpectedErrorPruning:
      this->PruneTreeEEP( (DecisionTreeNode *)this->GetRoot() );
      break;
   case kCostComplexityPruning:
      this->PruneTreeCC();
      break;
   default:
      fLogger << kFATAL << "Selected pruning method not yet implemented "
              << Endl;
      break;
   }
   this->CountNodes();
}
////////////////////////////////////////////////////////////////////////////////
/// Expected-error pruning: bottom-up, replace a subtree by a leaf whenever
/// its (strength-scaled) subtree error is not better than the node error.
void TMVA::DecisionTree::PruneTreeEEP( DecisionTreeNode *node )
{
   // leaves (type != 0) need no pruning
   if (node->GetNodeType() != 0) return;
   this->PruneTreeEEP( (DecisionTreeNode*)node->GetLeft() );
   this->PruneTreeEEP( (DecisionTreeNode*)node->GetRight() );
   if (this->GetSubTreeError(node)*fPruneStrength >= this->GetNodeError(node)) {
      this->PruneNode(node);
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Cost-complexity pruning: let the CCPruner compute the optimal prune
/// sequence for the current fPruneStrength, then prune those nodes.
void TMVA::DecisionTree::PruneTreeCC()
{
   // FIX: use a stack-allocated pruner (RAII) instead of new/delete, which
   // leaked if Optimize() threw and required a manual delete otherwise.
   CCPruner pruneTool(this, NULL, fSepType);
   pruneTool.SetPruneStrength(fPruneStrength);
   pruneTool.Optimize();
   std::vector<DecisionTreeNode*> nodes = pruneTool.GetOptimalPruneSequence();
   for (UInt_t i = 0; i < nodes.size(); i++) {
      this->PruneNode(nodes[i]);
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Count the leaf nodes in the subtree below `n` (NULL = whole tree).
/// @return number of leaves, or 0 if the tree has no root.
UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
         return 0;
      }
   }
   DecisionTreeNode *lDau = this->GetLeftDaughter(n);
   DecisionTreeNode *rDau = this->GetRightDaughter(n);
   // a node without daughters is itself a leaf
   if (lDau == NULL && rDau == NULL) return 1;
   UInt_t nLeaves = 0;
   if (lDau != NULL) nLeaves += this->CountLeafNodes( lDau );
   if (rDau != NULL) nLeaves += this->CountLeafNodes( rDau );
   return nLeaves;
}
////////////////////////////////////////////////////////////////////////////////
/// Walk the full subtree below `n` (NULL = whole tree), checking on the way
/// that every node has either zero or two daughters.
void TMVA::DecisionTree::DescendTree( DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
         return;
      }
   }
   DecisionTreeNode *lDau = this->GetLeftDaughter(n);
   DecisionTreeNode *rDau = this->GetRightDaughter(n);
   if (lDau == NULL && rDau == NULL) {
      // leaf: nothing to descend into
   }
   else if ((lDau == NULL) != (rDau == NULL)) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   else {
      this->DescendTree( lDau );
      this->DescendTree( rDau );
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Turn an intermediate node into a leaf: detach and delete both daughter
/// subtrees, invalidate the split bookkeeping, and classify the new leaf by
/// its purity.
void TMVA::DecisionTree::PruneNode( DecisionTreeNode *node )
{
   DecisionTreeNode *lDau = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *rDau = (DecisionTreeNode*)node->GetRight();
   node->SetRight(NULL);
   node->SetLeft(NULL);
   // invalidate the former split information
   node->SetSelector(-1);
   node->SetSeparationIndex(-1);
   node->SetSeparationGain(-1);
   // the new leaf is signal-like (1) or background-like (-1) by purity
   node->SetNodeType( (node->GetPurity() > fNodePurityLimit) ? 1 : -1 );
   this->DeleteNode(lDau);
   this->DeleteNode(rDau);
   this->CountNodes();
}
////////////////////////////////////////////////////////////////////////////////
/// Misclassification-error estimate of a node: 1 - (f - df), capped at 1,
/// where f is the majority-class fraction and df its binomial uncertainty.
Double_t TMVA::DecisionTree::GetNodeError( DecisionTreeNode *node )
{
   const Double_t nEvts = node->GetNEvents();
   // f = fraction of events of the class the leaf would be assigned to
   const Double_t f = (node->GetPurity() > fNodePurityLimit)
                    ? node->GetPurity()
                    : (1 - node->GetPurity());
   const Double_t df = TMath::Sqrt( f*(1-f)/nEvts );
   return std::min( 1., (1 - (f-df)) );
}
////////////////////////////////////////////////////////////////////////////////
/// Error estimate of the subtree below `node`: for a leaf it is the node
/// error itself; for an intermediate node the event-weighted average of the
/// daughters' subtree errors.
Double_t TMVA::DecisionTree::GetSubTreeError( DecisionTreeNode *node )
{
   if (node->GetNodeType() != 0) {
      return this->GetNodeError(node);
   }
   DecisionTreeNode *lDau = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *rDau = (DecisionTreeNode*)node->GetRight();
   return ( lDau->GetNEvents() * this->GetSubTreeError(lDau) +
            rDau->GetNEvents() * this->GetSubTreeError(rDau) ) /
          node->GetNEvents();
}
////////////////////////////////////////////////////////////////////////////////
/// Left daughter of `n`, downcast to DecisionTreeNode (may be NULL).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetLeftDaughter( DecisionTreeNode *n )
{
   return static_cast<DecisionTreeNode*>( n->GetLeft() );
}
////////////////////////////////////////////////////////////////////////////////
/// Right daughter of `n`, downcast to DecisionTreeNode (may be NULL).
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetRightDaughter( DecisionTreeNode *n )
{
   return static_cast<DecisionTreeNode*>( n->GetRight() );
}
////////////////////////////////////////////////////////////////////////////////
/// Retrieve the node addressed by a bit sequence: starting at the root, bit i
/// of `sequence` (i = 0 .. depth-1) selects the right daughter when set, the
/// left daughter otherwise.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
{
   DecisionTreeNode* current = (DecisionTreeNode*) this->GetRoot();
   for (UInt_t i = 0; i < depth; i++) {
      // FIX: "1 << i" shifts an int and is undefined/wrong for i >= 31 even
      // though sequence is ULong_t; use an unsigned long literal instead.
      ULong_t tmp = 1UL << i;
      if ( tmp & sequence ) current = this->GetRightDaughter(current);
      else                  current = this->GetLeftDaughter(current);
   }
   return current;
}
////////////////////////////////////////////////////////////////////////////////
/// Determine, per input variable, the minimum and maximum value found in the
/// event sample. Results are written into the caller-provided xmin/xmax
/// vectors (assumed to hold at least fNvars entries -- TODO confirm callers).
void TMVA::DecisionTree::FindMinAndMax( vector<TMVA::Event*> & eventSample,
                                        vector<Double_t> & xmin,
                                        vector<Double_t> & xmax )
{
   UInt_t num_events = eventSample.size();
   // FIX: guard the empty sample -- the code below dereferences
   // eventSample[0], which is undefined behaviour for an empty vector.
   if (num_events == 0) return;
   // seed the ranges with the first event ...
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetVal(ivar);
   }
   // ... then widen them with every other event
   for (UInt_t i=1;i<num_events;i++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         if (xmin[ivar]>eventSample[i]->GetVal(ivar))
            xmin[ivar]=eventSample[i]->GetVal(ivar);
         if (xmax[ivar]<eventSample[i]->GetVal(ivar))
            xmax[ivar]=eventSample[i]->GetVal(ivar);
      }
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Fill `cut_points` with num_gridpoints cut values placed at the centres of
/// equal-width bins spanning [xmin, xmax].
void TMVA::DecisionTree::SetCutPoints( vector<Double_t> & cut_points,
                                       Double_t xmin,
                                       Double_t xmax,
                                       Int_t num_gridpoints )
{
   const Double_t step = (xmax - xmin) / num_gridpoints;
   Double_t cut = xmin + 0.5*step;   // first bin centre
   for (Int_t ipoint = 0; ipoint < num_gridpoints; ipoint++) {
      cut_points[ipoint] = cut;
      cut += step;
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Find the best cut for this node by scanning, for each (selected) input
/// variable, a fixed grid of fNCuts cut values between the variable's min and
/// max in the sample. The best cut (largest separation gain that leaves at
/// least fMinSize unweighted events on each side) is stored on the node.
/// @return the best separation gain found, or 0 if no valid cut exists
Double_t TMVA::DecisionTree::TrainNodeFast( vector<TMVA::Event*> & eventSample,
                                            TMVA::DecisionTreeNode *node )
{
   // FIX: variable ranges now live on the stack (were heap-allocated vectors
   // with manual delete at every exit)
   vector<Double_t> xmin(fNvars);
   vector<Double_t> xmax(fNvars);
   Double_t separationGain = -1, sepTmp;
   Double_t cutValue = -999;
   Int_t mxVar = -1, cutIndex = 0;
   Bool_t cutType = kTRUE;
   Double_t nTotS, nTotB;
   Int_t nTotS_unWeighted, nTotB_unWeighted;
   UInt_t nevents = eventSample.size();

   // per-variable min/max over the sample
   for (int ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetVal(ivar);
   }
   for (UInt_t iev=1;iev<nevents;iev++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar);
         if (xmin[ivar]>eventData) xmin[ivar]=eventData;
         if (xmax[ivar]<eventData) xmax[ivar]=eventData;
      }
   }

   // per-variable, per-cut counters of events passing the cut
   // (FIX: dropped the write-only "significance" and "cutTypes" arrays)
   vector< vector<Double_t> > nSelS (fNvars);
   vector< vector<Double_t> > nSelB (fNvars);
   vector< vector<Int_t> > nSelS_unWeighted (fNvars);
   vector< vector<Int_t> > nSelB_unWeighted (fNvars);
   vector< vector<Double_t> > cutValues(fNvars);
   vector<Bool_t> useVariable(fNvars);
   for (int ivar=0; ivar < fNvars; ivar++) useVariable[ivar]=kFALSE;

   if (fRandomisedTree) {
      // pick fUseNvars distinct variables at random
      if (fUseNvars==0) {
         if (fNvars < 12) fUseNvars = TMath::Max(2,Int_t( Float_t(fNvars) / 2.5 ));
         else if (fNvars < 40) fUseNvars = Int_t( Float_t(fNvars) / 5 );
         else fUseNvars = Int_t( Float_t(fNvars) / 10 );
      }
      Int_t nSelectedVars = 0;
      while ( nSelectedVars < fUseNvars ){
         Double_t bla = fMyTrandom->Rndm()*fNvars;
         useVariable[Int_t (bla)] = kTRUE;
         // BUG FIX: the counter must be reset before each recount; without
         // the reset it accumulated over the while iterations and the loop
         // stopped after fewer than fUseNvars distinct variables.
         nSelectedVars = 0;
         for (int ivar=0; ivar < fNvars; ivar++) {
            if (useVariable[ivar] == kTRUE) nSelectedVars++;
         }
      }
   } else {
      for (int ivar=0; ivar < fNvars; ivar++) useVariable[ivar] = kTRUE;
   }

   // lay out the cut grid: fNCuts bin centres between xmin and xmax
   for (int ivar=0; ivar < fNvars; ivar++){
      if ( useVariable[ivar] ) {
         cutValues[ivar].resize(fNCuts);
         nSelS[ivar].resize(fNCuts);
         nSelB[ivar].resize(fNCuts);
         nSelS_unWeighted[ivar].resize(fNCuts);
         nSelB_unWeighted[ivar].resize(fNCuts);
         Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(fNCuts);
         for (Int_t icut=0; icut<fNCuts; icut++){
            cutValues[ivar][icut]=xmin[ivar]+(Float_t(icut)+0.5)*istepSize;
         }
      }
   }

   // single pass over the events: total sums and per-cut passing sums
   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;
   for (UInt_t iev=0; iev<nevents; iev++){
      Int_t eventType = eventSample[iev]->Type();
      Double_t eventWeight = eventSample[iev]->GetWeight();
      if (eventType==1){
         nTotS+=eventWeight;
         nTotS_unWeighted++;
      }
      else {
         nTotB+=eventWeight;
         nTotB_unWeighted++;
      }
      for (int ivar=0; ivar < fNvars; ivar++){
         if ( useVariable[ivar] ) {
            Double_t eventData = eventSample[iev]->GetVal(ivar);
            for (Int_t icut=0; icut<fNCuts; icut++){
               if (eventData > cutValues[ivar][icut]){
                  if (eventType==1) {
                     nSelS[ivar][icut]+=eventWeight;
                     nSelS_unWeighted[ivar][icut]++;
                  }
                  else {
                     nSelB[ivar][icut]+=eventWeight;
                     nSelB_unWeighted[ivar][icut]++;
                  }
               }
            }
         }
      }
   }

   // pick the cut with the largest separation gain that keeps at least
   // fMinSize unweighted events on each side
   for (int ivar=0; ivar < fNvars; ivar++) {
      if ( useVariable[ivar] ){
         for (Int_t icut=0; icut<fNCuts; icut++){
            if ( (nSelS_unWeighted[ivar][icut] + nSelB_unWeighted[ivar][icut]) >= fMinSize &&
                 (( nTotS_unWeighted+nTotB_unWeighted)-
                  (nSelS_unWeighted[ivar][icut] + nSelB_unWeighted[ivar][icut])) >= fMinSize) {
               sepTmp = fSepType->GetSeparationGain(nSelS[ivar][icut], nSelB[ivar][icut], nTotS, nTotB);
               if (separationGain < sepTmp) {
                  separationGain = sepTmp;
                  mxVar = ivar;
                  cutIndex = icut;
               }
            }
         }
      }
   }

   if (mxVar >= 0) {
      // cutType: kTRUE if the cut selects predominantly signal
      // NOTE(review): divides by nTotS/nTotB -- assumes both classes are
      // present in the sample; confirm callers guarantee this
      if (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) cutType=kTRUE;
      else cutType=kFALSE;
      cutValue = cutValues[mxVar][cutIndex];
      node->SetSelector((UInt_t)mxVar);
      node->SetCutValue(cutValue);
      node->SetCutType(cutType);
      node->SetSeparationGain(separationGain);
      fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;
   }
   else {
      separationGain = 0;   // no valid cut found
   }
   return separationGain;
}
////////////////////////////////////////////////////////////////////////////////
/// Find the best cut for this node by sorting the events along each variable
/// and scanning every inter-event position as a cut candidate (exhaustive
/// search, used when fNCuts <= 0). The best cut is stored on the node.
/// @return the best separation gain found, or 0 if no valid cut exists
Double_t TMVA::DecisionTree::TrainNodeFull( vector<TMVA::Event*> & eventSample,
                                            TMVA::DecisionTreeNode *node )
{
   Double_t nTotS = 0.0, nTotB = 0.0;
   Int_t nTotS_unWeighted = 0, nTotB_unWeighted = 0;
   // wrap the events so they can be sorted per variable without copying
   vector<TMVA::BDTEventWrapper> bdtEventSample;
   // per-variable best cut value, gain and cut orientation
   vector<Double_t> lCutValue( fNvars, 0.0 );
   vector<Double_t> lSepGain( fNvars, -1.0 );
   vector<Bool_t> lCutType( fNvars, kFALSE );
   for( vector<TMVA::Event*>::const_iterator it = eventSample.begin(); it != eventSample.end(); ++it ) {
      if( (*it)->Type() == 1 ) {
         nTotS += (*it)->GetWeight();
         ++nTotS_unWeighted;
      }
      else {
         nTotB += (*it)->GetWeight();
         ++nTotB_unWeighted;
      }
      bdtEventSample.push_back(TMVA::BDTEventWrapper(*it));
   }
   for( Int_t ivar = 0; ivar < fNvars; ivar++ ) {
      // sort the events along variable ivar, then accumulate running
      // signal/background weight sums so any prefix gives the "passing" side
      TMVA::BDTEventWrapper::SetVarIndex(ivar);
      std::sort( bdtEventSample.begin(),bdtEventSample.end() );
      Double_t bkgWeightCtr = 0.0, sigWeightCtr = 0.0;
      vector<TMVA::BDTEventWrapper>::iterator it = bdtEventSample.begin(), it_end = bdtEventSample.end();
      for( ; it != it_end; ++it ) {
         if( (**it)->Type() == 1 )
            sigWeightCtr += (**it)->GetWeight();
         else
            bkgWeightCtr += (**it)->GetWeight();
         // store the cumulative sums on the wrapper for the scan below
         it->SetCumulativeWeight(false,bkgWeightCtr);
         it->SetCumulativeWeight(true,sigWeightCtr);
      }
      // fPMin: minimal relative gap between neighbouring values for a cut
      const Double_t fPMin = 1.0e-6;
      Bool_t cutType = kFALSE;
      Long64_t index = 0;
      Double_t separationGain = -1.0, sepTmp = 0.0, cutValue = 0.0, dVal = 0.0, norm = 0.0;
      // scan every position between two consecutive (sorted) events
      for( it = bdtEventSample.begin(); it != it_end; ++it ) {
         if( index == 0 ) { ++index; continue; }
         if( *(*it) == NULL ) {
            fLogger << kFATAL << "In TrainNodeFull(): have a null event! Where index="
                    << index << ", and parent node=" << node->GetParent() << Endl;
            break;
         }
         dVal = bdtEventSample[index].GetVal() - bdtEventSample[index-1].GetVal();
         norm = TMath::Abs(bdtEventSample[index].GetVal() + bdtEventSample[index-1].GetVal());
         // require fMinSize events on each side and a non-degenerate value gap
         if( index >= fMinSize &&
             (nTotS_unWeighted + nTotB_unWeighted) - index >= fMinSize &&
             TMath::Abs(dVal/(0.5*norm + 1)) > fPMin ) {
            sepTmp = fSepType->GetSeparationGain( it->GetCumulativeWeight(true), it->GetCumulativeWeight(false), sigWeightCtr, bkgWeightCtr );
            if( sepTmp > separationGain ) {
               separationGain = sepTmp;
               // place the cut halfway between the two neighbouring values
               cutValue = it->GetVal() - 0.5*dVal;
               Double_t nSelS = it->GetCumulativeWeight(true);
               Double_t nSelB = it->GetCumulativeWeight(false);
               // cutType: kTRUE if the selected (low) side is signal-enriched
               if( nSelS/sigWeightCtr > nSelB/bkgWeightCtr ) cutType = kTRUE;
               else cutType = kFALSE;
            }
         }
         ++index;
      }
      lCutType[ivar] = cutType;
      lCutValue[ivar] = cutValue;
      lSepGain[ivar] = separationGain;
   }
   // choose the best variable overall
   Double_t separationGain = -1.0;
   Int_t iVarIndex = -1;
   for( Int_t ivar = 0; ivar < fNvars; ivar++ ) {
      if( lSepGain[ivar] > separationGain ) {
         iVarIndex = ivar;
         separationGain = lSepGain[ivar];
      }
   }
   if(iVarIndex >= 0) {
      node->SetSelector(iVarIndex);
      node->SetCutValue(lCutValue[iVarIndex]);
      node->SetSeparationGain(lSepGain[iVarIndex]);
      node->SetCutType(lCutType[iVarIndex]);
      fVariableImportance[iVarIndex] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
   }
   else {
      separationGain = 0.0;   // no valid cut found
   }
   return separationGain;
}
////////////////////////////////////////////////////////////////////////////////
/// Classify an event: descend from the root until a leaf is reached.
/// @param e            event to classify
/// @param UseYesNoLeaf if true return the hard leaf decision (+1/-1),
///                     otherwise the leaf purity
Double_t TMVA::DecisionTree::CheckEvent( const TMVA::Event & e, Bool_t UseYesNoLeaf )
{
   TMVA::DecisionTreeNode *current = (TMVA::DecisionTreeNode*)this->GetRoot();
   // intermediate nodes carry type 0; follow the cut decisions down
   while (current->GetNodeType() == 0) {
      current = (TMVA::DecisionTreeNode*)( current->GoesRight(e) ? current->GetRight()
                                                                 : current->GetLeft() );
   }
   return UseYesNoLeaf ? Double_t( current->GetNodeType() ) : current->GetPurity();
}
////////////////////////////////////////////////////////////////////////////////
/// Weighted signal purity of an event sample: sum(signal weights) / total.
/// @return the purity, or -1 for an empty (zero total weight) sample
Double_t TMVA::DecisionTree::SamplePurity( vector<TMVA::Event*> eventSample )
{
   Double_t sumsig = 0, sumbkg = 0, sumtot = 0;
   for (UInt_t ievt = 0; ievt < eventSample.size(); ievt++) {
      const Double_t w = eventSample[ievt]->GetWeight();
      if      (eventSample[ievt]->Type() == 0) sumbkg += w;
      else if (eventSample[ievt]->Type() == 1) sumsig += w;
      sumtot += w;
   }
   // sanity check: every event should be typed signal (1) or background (0)
   if (sumtot != (sumsig+sumbkg)) {
      fLogger << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
              << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }
   if (sumtot > 0) return sumsig/(sumsig + sumbkg);
   else return -1;
}
////////////////////////////////////////////////////////////////////////////////
/// Relative importance of each input variable: the per-variable separation
/// gain accumulated during training, normalised to unit sum (all zeros when
/// the total is effectively zero).
vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
{
   vector<Double_t> relativeImportance(fNvars);
   Double_t total = 0;
   for (int ivar = 0; ivar < fNvars; ivar++) total += fVariableImportance[ivar];
   const Bool_t hasTotal = (total > std::numeric_limits<double>::epsilon());
   for (int ivar = 0; ivar < fNvars; ivar++) {
      relativeImportance[ivar] = hasTotal ? fVariableImportance[ivar] / total : 0;
   }
   return relativeImportance;
}
////////////////////////////////////////////////////////////////////////////////
/// Relative importance of a single input variable.
/// @return the normalised importance, or -1 when ivar is out of range
Double_t TMVA::DecisionTree::GetVariableImportance( Int_t ivar )
{
   vector<Double_t> relativeImportance = this->GetVariableImportance();
   if (ivar < 0 || ivar >= fNvars) {
      fLogger << kFATAL << "<GetVariableImportance>" << Endl
              << "--- ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return relativeImportance[ivar];
}
// Last change: Sat Nov 1 10:21:33 2008
// Last generated: 2008-11-01 10:21
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.