#ifndef ROOT_TMVA_Factory
#include "TMVA/Factory.h"
#endif
#include "Riostream.h"
#include "TFile.h"
#include "TTree.h"
#include "TLeaf.h"
#include "TEventList.h"
#include "TH1.h"
#include "TH2.h"
#include "TText.h"
#include "TTreeFormula.h"
#include "TStyle.h"
#include "TMatrixF.h"
#include "TMatrixDSym.h"
#include "TPaletteAxis.h"
#include "TPrincipal.h"
#include "TPluginManager.h"
#include "TMVA/Config.h"
#include "TMVA/Tools.h"
#include "TMVA/Ranking.h"
#include "TMVA/DataSet.h"
#include "TMVA/Methods.h"
#include "TMVA/Methods.h"
const Bool_t DEBUG_TMVA_Factory = kFALSE;
const int MinNoTrainingEvents = 10;
const int MinNoTestEvents     = 1;
const long int basketsize     = 1280000;
ClassImp(TMVA::Factory)
TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
   : Configurable           ( theOption ),
     fDataSet               ( new DataSet ),
     fTargetFile            ( theTargetFile ),
     fVerbose               ( kFALSE ),
     fJobName               ( jobName ),
     fDataAssignType        ( kUndefined ),
     fSuspendDATVerification( kFALSE ),
     fTrainSigAssignTree    ( 0 ),
     fTrainBkgAssignTree    ( 0 ),
     fTestSigAssignTree     ( 0 ),
     fTestBkgAssignTree     ( 0 ),
     fATreeEvent            ( 0 )
{  
   
   
   
   
   
   
   SetConfigDescription( "Configuration options for factory running" );
   
   
   
   Bool_t silent = kFALSE;
   Bool_t color  = !gROOT->IsBatch();
   SetConfigName( GetName() );
   DeclareOptionRef( fVerbose, "V", "Verbose mode" );
   DeclareOptionRef( color, "Color", "Color flag (default on)" );
   DeclareOptionRef( silent, "Silent", "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
   ParseOptions( kFALSE );
   if (Verbose()) fLogger.SetMinType( kVERBOSE );
   gConfig().SetUseColor( color );
   gConfig().SetSilent( silent );
   Greetings();
   Data().SetBaseRootDir ( fTargetFile );
   Data().SetLocalRootDir( fTargetFile );
   Data().SetVerbose     ( Verbose() );
}
void TMVA::Factory::Greetings() 
{
   
   
   gTools().ROOTVersionMessage( fLogger ); 
   gTools().TMVAWelcomeMessage( fLogger, gTools().kLogoWelcomeMsg );
   gTools().TMVAVersionMessage( fLogger ); fLogger << Endl;
}
TMVA::Factory::~Factory( void )
{
   
   this->DeleteAllMethods();
   if (fTrainSigAssignTree) { delete fTrainSigAssignTree; fTrainSigAssignTree = 0; }
   if (fTrainBkgAssignTree) { delete fTrainBkgAssignTree; fTrainBkgAssignTree = 0; }
   if (fTestSigAssignTree)  { delete fTestSigAssignTree; fTestSigAssignTree = 0; }
   if (fTestBkgAssignTree)  { delete fTestBkgAssignTree; fTestBkgAssignTree = 0; }
   if ( fDataSet ) { delete fDataSet; fDataSet = 0; }
}
void TMVA::Factory::DeleteAllMethods( void )
{
   
   vector<IMethod*>::iterator itrMethod = fMethods.begin();
   for (; itrMethod != fMethods.end(); itrMethod++) {
      fLogger << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;    
      delete (*itrMethod);
   }
   fMethods.clear();
}
void TMVA::Factory::SetInputVariables( vector<TString>* theVariables ) 
{ 
   
   
   if (theVariables->size() == 0) {
      fLogger << kFATAL << "<SetInputVariables> Vector of input variables is empty" << Endl;
   }
   for (UInt_t i=0; i<theVariables->size(); i++) Data().AddVariable((*theVariables)[i]);
}
Bool_t TMVA::Factory::VerifyDataAssignType( DataAssignType thisType )
{
   
   
   if (fSuspendDATVerification) { fSuspendDATVerification = kFALSE; return kTRUE; }
   if      (thisType        == kUndefined) fLogger << kFATAL << "Big troubles in \"VerifyDataAssignType\"" << Endl;
   else if (fDataAssignType == kUndefined) fDataAssignType = thisType;
   else if (thisType        != fDataAssignType) {
      fLogger << kINFO 
              << "Given flags: thisType = " << thisType << ", DataAssignType = " << fDataAssignType << Endl;
      fLogger << kFATAL  
              << "You have added individual training or test events to the factory AND "
              << "also added complete trees. This is incompatible in TMVA -> please use " 
              << "the one OR the other (or contact the authors if such mixed access is required)." << Endl;
      return kFALSE;
   }
   
   return kTRUE;
}
void TMVA::Factory::CreateEventAssignTrees( TTree*& fAssignTree, const TString& name )
{
   
   fAssignTree = new TTree( name, name );
   fAssignTree->Branch( "type",   &fATreeType,   "ATreeType/I" );
   fAssignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/I" );
   std::vector<VariableInfo>& vars = Data().GetVariableInfos();
   if (!fATreeEvent) fATreeEvent = new Float_t[vars.size()];
   for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
      TString vname = vars[ivar].GetExpression();
      fAssignTree->Branch( vname, &(fATreeEvent[ivar]), vname + "/F" );
   }
}
void TMVA::Factory::AddSignalTrainingEvent( std::vector<Double_t>& event, Double_t weight )
{
   
   VerifyDataAssignType( kAssignEvents );
   
   if (!fTrainSigAssignTree) CreateEventAssignTrees( fTrainSigAssignTree, "LocalSigTrainTree" );
   
   
   if (event.size() != (UInt_t)fTrainSigAssignTree->GetNbranches()-2 || 
       event.size() != Data().GetVariableInfos().size()) {
      fLogger << kFATAL 
              << "Dimension mismatch in \"AddSignalTrainingEvent\", does the "
              << "input vector contain as many dimenions as variables defined ? => please check" << Endl;
   }
   
   
   fATreeType   = 1;
   fATreeWeight = weight;
   for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
   fTrainSigAssignTree->Fill();
}
void TMVA::Factory::AddBackgroundTrainingEvent( std::vector<Double_t>& event, Double_t weight )
{
   
   VerifyDataAssignType( kAssignEvents );
   
   if (!fTrainBkgAssignTree) CreateEventAssignTrees( fTrainBkgAssignTree, "LocalBkgTraiTree" );
   
   if (event.size() != (UInt_t)fTrainBkgAssignTree->GetNbranches()-2 || 
       event.size() != Data().GetVariableInfos().size()) {
      fLogger << kFATAL 
              << "Dimension mismatch in \"AddBackgroundTrainingEvent\", does the "
              << "input vector contain as many dimenions as variables defined ? => please check" << Endl;
   }
   
   
   fATreeType   = 0;
   fATreeWeight = weight;
   for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
   fTrainBkgAssignTree->Fill();
}
void TMVA::Factory::AddSignalTestEvent( std::vector<Double_t>& event, Double_t weight )
{
   
   VerifyDataAssignType( kAssignEvents );
   
   if (!fTestSigAssignTree) CreateEventAssignTrees( fTestSigAssignTree, "LocalSigTestTree" );
   
   if (event.size() != (UInt_t)fTestSigAssignTree->GetNbranches()-2 || 
       event.size() != Data().GetVariableInfos().size()) {
      fLogger << kFATAL 
              << "Dimension mismatch in \"AddSignalTestEvent\", does the "
              << "input vector contain as many dimenions as variables defined ? => please check" << Endl;
   }
   
   
   fATreeType   = 1;
   fATreeWeight = weight;
   for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
   fTestSigAssignTree->Fill();
}
void TMVA::Factory::AddBackgroundTestEvent( std::vector<Double_t>& event, Double_t weight )
{
   
   VerifyDataAssignType( kAssignEvents );
   
   if (!fTestBkgAssignTree) CreateEventAssignTrees( fTestBkgAssignTree, "LocalBkgTestTree" );
   
   if (event.size() != (UInt_t)fTestBkgAssignTree->GetNbranches()-2 || 
       event.size() != Data().GetVariableInfos().size()) {
      fLogger << kFATAL 
              << "Dimension mismatch in \"AddBackgroundTestEvent\", does the "
              << "input vector contain as many dimenions as variables defined ? => please check" << Endl;
   }
   
   
   fATreeType   = 0;
   fATreeWeight = weight;
   for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
   fTestBkgAssignTree->Fill();
}
Bool_t TMVA::Factory::SetInputTrees( TTree* signal, TTree* background, 
                                     Double_t signalWeight, Double_t backgroundWeight )
{
   
   if (!signal || !background) {
      fLogger << kFATAL << "Zero pointer for signal and/or background tree: " 
              << signal << " " << background << Endl;
      return kFALSE;
   }
   SetSignalTree    ( signal,     signalWeight );
   SetBackgroundTree( background, backgroundWeight );
   return kTRUE;
}
Bool_t TMVA::Factory::SetInputTrees( TTree* inputTree, TCut SigCut, TCut BgCut )
{
   
   
   
   if (!inputTree) {
      fLogger << kFATAL << "Zero pointer for input tree: " << inputTree << Endl;
      return kFALSE;
   }
   
   VerifyDataAssignType( kAssignTrees );
   TTree* signalTree = inputTree->CloneTree(0);
   TTree* backgTree  = inputTree->CloneTree(0);
   
   TIter next_branch1( signalTree->GetListOfBranches() );
   while (TBranch *branch = (TBranch*)next_branch1())
      branch->SetBasketSize(basketsize);
   TIter next_branch2( backgTree->GetListOfBranches() );
   while (TBranch *branch = (TBranch*)next_branch2())
      branch->SetBasketSize(basketsize);
   inputTree->Draw(">>signalList",SigCut,"goff");
   TEventList *signalList = (TEventList*)gDirectory->Get("signalList");
   inputTree->Draw(">>backgList",BgCut,"goff");    
   TEventList *backgList = (TEventList*)gDirectory->Get("backgList");
   if (backgList->GetN() == inputTree->GetEntries()) {
      TCut bgcut= !SigCut;
      inputTree->Draw(">>backgList",bgcut,"goff");    
      backgList = (TEventList*)gDirectory->Get("backgList");
   }
   signalList->Print();
   backgList ->Print();
  
   for (Int_t i=0;i<inputTree->GetEntries(); i++) {
      inputTree->GetEntry(i);
      if ((backgList->Contains(i)) && (signalList->Contains(i))) {
         fLogger << kWARNING << "Event " << i
                 << " is selected for signal and background sample! Skip it!" << Endl;
         continue;
      }
      if (signalList->Contains(i)) signalTree->Fill();
      if (backgList->Contains(i) ) backgTree->Fill();
   }
   signalTree->ResetBranchAddresses();
   backgTree ->ResetBranchAddresses();
   Data().AddSignalTree    ( signalTree, 1.0 );
   Data().AddBackgroundTree( backgTree,  1.0 );
   delete signalList;
   delete backgList;
   return kTRUE;    
}
Bool_t TMVA::Factory::SetInputTrees( TString datFileS, TString datFileB, 
                                     Double_t signalWeight, Double_t backgroundWeight )
{
   
   TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
   TTree* backgTree  = new TTree( "TreeB", "Tree (B)" );
   
   VerifyDataAssignType( kAssignTrees );
  
   signalTree->ReadFile( datFileS );
   backgTree->ReadFile( datFileB );
   fLogger << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file    : \""
           << datFileS << "\" \n- Background file: \""
           << datFileB << "\"" << Endl;
  
   ifstream in; 
   in.open(datFileS);
   if (!in.good()) {
      fLogger << kFATAL << "Could not open file: " << datFileS << Endl;
      return kFALSE;
   }
   in.close();
   in.open(datFileB);
   if (!in.good()) {
      fLogger << kFATAL << "Could not open file: " << datFileB << Endl;
      return kFALSE;
   }
   in.close();
    
   
   
   SetSignalTree    ( signalTree, signalWeight );
   SetBackgroundTree( backgTree,  backgroundWeight );
   return kTRUE;
}
void TMVA::Factory::SetInputTreesFromEventAssignTrees()
{
   
   SetWeightExpression( "weight" );
   fSuspendDATVerification = kTRUE; AddSignalTree    ( fTrainSigAssignTree, 1.0, Types::kTraining );
   fSuspendDATVerification = kTRUE; AddSignalTree    ( fTestSigAssignTree,  1.0, Types::kTesting  );
   fSuspendDATVerification = kTRUE; AddBackgroundTree( fTrainBkgAssignTree, 1.0, Types::kTraining );
   fSuspendDATVerification = kTRUE; AddBackgroundTree( fTestBkgAssignTree,  1.0, Types::kTesting  );
}
void TMVA::Factory::PrepareTrainingAndTestTree( const TCut& cut, 
                                                Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
                                                const TString& otherOpt )
{
   
   TString opt(Form("NsigTrain=%i:NbkgTrain=%i:NsigTest=%i:NbkgTest=%i:%s", 
                    NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()));
   PrepareTrainingAndTestTree( cut, cut, opt );
}
void TMVA::Factory::PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest )
{
   
   
   TString opt(Form("NsigTrain=%i:NbkgTrain=%i:NsigTest=%i:NbkgTest=%i:SplitMode=Random:EqualTrainSample:!V", 
                    Ntrain, Ntrain, Ntest, Ntest));
   PrepareTrainingAndTestTree( cut, cut, opt );
}
void TMVA::Factory::PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt )
{ 
   
   PrepareTrainingAndTestTree( cut, cut, splitOpt );
}
void TMVA::Factory::PrepareTrainingAndTestTree( const TCut& sigcut, const TCut& bkgcut, const TString& splitOpt )
{ 
   
   
   if (fDataAssignType == kAssignEvents) SetInputTreesFromEventAssignTrees();
   fLogger << kINFO << "Preparing trees for training and testing..." << Endl;
   Data().SetCuts(sigcut, bkgcut); 
   Data().PrepareForTrainingAndTesting( splitOpt );
}
void TMVA::Factory::SetSignalTree( TTree* signal, Double_t weight )
{
   
   
   VerifyDataAssignType( kAssignTrees );
   
   Data().ClearSignalTreeList();
   AddSignalTree( signal, weight );
}
void TMVA::Factory::AddSignalTree( TTree* signal, Double_t weight, Types::ETreeType treetype )
{
   
   
   VerifyDataAssignType( kAssignTrees );
   
   Data().AddSignalTree( signal, weight, treetype );
}
void TMVA::Factory::AddSignalTree( TTree* signal, Double_t weight, const TString& treetype )
{
   
   
   VerifyDataAssignType( kAssignTrees );
   
   Types::ETreeType tt = Types::kMaxTreeType;
   TString tmpTreeType = treetype; tmpTreeType.ToLower();
   if      (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
   else if (tmpTreeType.Contains( "train" ))                                   tt = Types::kTraining;
   else if (tmpTreeType.Contains( "test" ))                                    tt = Types::kTesting;
   else {
      fLogger << kFATAL << "<AddSignalTree> cannot interpret tree type: \"" << treetype 
              << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
   }
   Data().AddSignalTree( signal, weight, tt );
}
void TMVA::Factory::SetBackgroundTree( TTree* background, Double_t weight )
{
   
   
   VerifyDataAssignType( kAssignTrees );
   
   Data().ClearBackgroundTreeList();
   AddBackgroundTree( background, weight );
}
void TMVA::Factory::AddBackgroundTree( TTree* background, Double_t weight, Types::ETreeType treetype )
{
   
   
   VerifyDataAssignType( kAssignTrees );
   
   Data().AddBackgroundTree( background, weight, treetype );
}
void TMVA::Factory::AddBackgroundTree( TTree* background, Double_t weight, const TString & treetype )
{
   
   
   VerifyDataAssignType( kAssignTrees );
   
   
   Types::ETreeType tt = Types::kMaxTreeType;
   TString tmpTreeType = treetype; tmpTreeType.ToLower();
   if      (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
   else if (tmpTreeType.Contains( "train" ))                                   tt = Types::kTraining;
   else if (tmpTreeType.Contains( "test" ))                                    tt = Types::kTesting;
   else {
      fLogger << kFATAL << "<AddBackgroundTree> cannot interpret tree type: \"" << treetype 
              << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
   }
   Data().AddBackgroundTree( background, weight, tt );
}
Bool_t TMVA::Factory::BookMethod( TString theMethodName, TString methodTitle, TString theOption ) 
{
   
   
   if (theMethodName != "Variable") fLogger << kINFO << "Create method: " << theMethodName << Endl;
   return BookMethod( Types::Instance().GetMethodType( theMethodName ), methodTitle, theOption );
   return kFALSE;
}
Bool_t TMVA::Factory::BookMethod( Types::EMVA theMethod, TString methodTitle, TString theOption ) 
{
   
   
   
   
   if (GetMethod( methodTitle ) != 0) {
      fLogger << kFATAL << "Booking failed since method with title <"
              << methodTitle <<"> already exists"
              << Endl;
   }
   TPluginManager *pluginManager(0);
   TPluginHandler *pluginHandler(0);
   
   MethodBase *method = 0;
   switch (theMethod) {
   case Types::kCuts:       
      method = new MethodCuts           ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kFisher:     
      method = new MethodFisher         ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kKNN:     
      method = new MethodKNN            ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kMLP:        
      method = new MethodMLP            ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kTMlpANN:    
      method = new MethodTMlpANN        ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kCFMlpANN:   
      method = new MethodCFMlpANN       ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kLikelihood: 
      method = new MethodLikelihood     ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kVariable:   
      method = new MethodVariable       ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kHMatrix:    
      method = new MethodHMatrix        ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kPDERS:      
      method = new MethodPDERS          ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kBDT:        
      method = new MethodBDT            ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kSVM:        
      method = new MethodSVM            ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kRuleFit:    
      method = new MethodRuleFit        ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kBayesClassifier:    
      method = new MethodBayesClassifier( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kFDA:    
      method = new MethodFDA            ( fJobName, methodTitle, Data(), theOption ); break;
   case Types::kSeedDistance:    
      method = new MethodSeedDistance   ( fJobName, methodTitle, Data(), theOption ); break;
   case TMVA::Types::kPlugins:
      pluginManager = gROOT->GetPluginManager();
      pluginHandler = pluginManager->FindHandler("TMVA@@MethodBase", methodTitle);
      if (pluginHandler) {
         if (pluginHandler->LoadPlugin() == 0) {
            method = (TMVA::MethodBase*) pluginHandler->ExecPlugin(4, &fJobName, &methodTitle, &Data(), &theOption);
            if(method==0) {
               fLogger << kFATAL << "Couldn't instantiate plugin for " << methodTitle << "." << Endl;
            } else {
               fLogger << kINFO << "Found plugin for " << methodTitle << "  " << method << Endl;
            }
         } else {
            fLogger << kFATAL << "Couldn't load any plugin for " << methodTitle << "." << Endl; 
         }
      } else {
         fLogger << kFATAL << "Couldn't find plugin handler for TMVA@@MethodBase and " << methodTitle << Endl; 
      }
      break;
   default:
      fLogger << kFATAL << "Method: \"" << theMethod << "\" does not exist" << Endl;
   }
   fLogger << kINFO << "Booking method: " << method->GetMethodTitle() << Endl;
   fMethods.push_back( method );
   return kTRUE;
}
Bool_t TMVA::Factory::BookMethod( Types::EMVA theMethod, TString methodTitle, TString methodOption,
                                  Types::EMVA theCommittee, TString committeeOption ) 
{
   
   
   
   
   IMethod *method = 0;
   
   switch(theMethod) {
   case Types::kCommittee:    
      method = new MethodCommittee( fJobName, methodTitle, Data(), methodOption, theCommittee, committeeOption ); break;
   default:
      fLogger << kFATAL << "Method: " << theMethod << " does not exist" << Endl;
   }
   fMethods.push_back( method );
   return kTRUE;
}
TMVA::IMethod* TMVA::Factory::GetMethod( const TString &methodTitle ) const
{
   
   vector<IMethod*>::const_iterator itrMethod    = fMethods.begin();
   vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
   
   for (; itrMethod != itrMethodEnd; itrMethod++) {
      MethodBase* mva = (MethodBase*)(*itrMethod);    
      if ( (mva->GetMethodTitle())==methodTitle ) return mva;
   }
   return 0;
}
void TMVA::Factory::TrainAllMethods( void ) 
{  
   
   fLogger << kINFO << "Training all methods..." << Endl;
   
   vector<IMethod*>::iterator itrMethod    = fMethods.begin();
   vector<IMethod*>::iterator itrMethodEnd = fMethods.end();
   for (; itrMethod != itrMethodEnd; itrMethod++) {
      MethodBase* mva = (MethodBase*)*itrMethod;
      if (Data().GetTrainingTree()->GetEntries() >= MinNoTrainingEvents) {
         fLogger << kINFO << "Train method: " << mva->GetMethodTitle() << Endl;
         mva->TrainMethod();
      }
      else {
         fLogger << kWARNING << "Method " << mva->GetMethodName() 
                 << " not trained (training tree has less entries ["
                 << Data().GetTrainingTree()->GetEntries() 
                 << "] than required [" << MinNoTrainingEvents << "]" << Endl; 
      }
   }
   
   fLogger << Endl;
   fLogger << kINFO << "Begin ranking of input variables..." << Endl;
   for (itrMethod = fMethods.begin(); itrMethod != itrMethodEnd; itrMethod++) {
      if (Data().GetTrainingTree()->GetEntries() >= MinNoTrainingEvents) {
         
         const Ranking* ranking = (*itrMethod)->CreateRanking();
         if (ranking != 0) ranking->Print();
         else fLogger << kINFO << "No variable ranking supplied by classifier: " 
                      << ((MethodBase*)*itrMethod)->GetMethodTitle() << Endl;
      }
   }
   fLogger << Endl;
}
void TMVA::Factory::TestAllMethods( void )
{
   
   fLogger << kINFO << "Testing all classifiers..." << Endl;
   if (Data().GetTrainingTree() == NULL) {
      fLogger << kWARNING << "You perform testing without training before, hope you "
              << "provided a reasonable test tree and weight files " << Endl;
   } 
   
   vector<IMethod*>::iterator itrMethod    = fMethods.begin();
   vector<IMethod*>::iterator itrMethodEnd = fMethods.end();
   for (; itrMethod != itrMethodEnd; itrMethod++) {
      MethodBase* mva = (MethodBase*)*itrMethod;
      fLogger << kINFO << "Test method: " << mva->GetMethodTitle() << Endl;
      mva->AddClassifierToTestTree(0);
   }
   if (DEBUG_TMVA_Factory) Data().GetTestTree()->Print();
}
void TMVA::Factory::MakeClass( const TString& methodTitle ) const
{
   
   
   if (methodTitle != "") {
      IMethod* method = GetMethod( methodTitle );
      if (method) method->MakeClass();
      else {
         fLogger << kWARNING << "<MakeClass> Could not find classifier \"" << methodTitle 
                 << "\" in list" << Endl;
      }
   }
   else {
      
      vector<IMethod*>::const_iterator itrMethod    = fMethods.begin();
      vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
      for (; itrMethod != itrMethodEnd; itrMethod++) {
         MethodBase* method = (MethodBase*)*itrMethod;
         fLogger << kINFO << "Make response class for classifier: " << method->GetMethodTitle() << Endl;
         method->MakeClass();
      }
   }
}
void TMVA::Factory::PrintHelpMessage( const TString& methodTitle ) const
{
   
   
   if (methodTitle != "") {
      IMethod* method = GetMethod( methodTitle );
      if (method) method->PrintHelpMessage();
      else {
         fLogger << kWARNING << "<PrintHelpMessage> Could not find classifier \"" << methodTitle 
                 << "\" in list" << Endl;
      }
   }
   else {
      
      vector<IMethod*>::const_iterator itrMethod    = fMethods.begin();
      vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
      for (; itrMethod != itrMethodEnd; itrMethod++) {
         MethodBase* method = (MethodBase*)*itrMethod;
         fLogger << kINFO << "Print help message for classifier: " << method->GetMethodTitle() << Endl;
         method->PrintHelpMessage();
      }
   }
}
void TMVA::Factory::EvaluateAllVariables( TString options )
{
   
   fLogger << kINFO << "Evaluating all variables..." << Endl;
   if (Data().GetTrainingTree() == NULL) {
      fLogger << kWARNING << "You perform testing without training before, hope you "
              << "provided a reasonable test tree and weight files " << Endl;
   } 
   for (UInt_t i=0; i<Data().GetNVariables(); i++) {
      TString s = Data().GetInternalVarName(i);
      if (options.Contains("V")) s += ":V";
      this->BookMethod( "Variable", s );
   }
}
void TMVA::Factory::EvaluateAllMethods( void )
{
   
   fLogger << kINFO << "Evaluating all classifiers..." << Endl;
   if (Data().GetTrainingTree() == NULL) {
      fLogger << kWARNING << "You perform testing without training before, hope you "
              << "provided a reasonable test tree and weight files " << Endl;
   } 
   
   
   
   
   
   
   Int_t isel;                  
   Int_t nmeth_used[2] = {0,0}; 
   vector<vector<TString> >  mname(2);
   vector<vector<Double_t> > sig(2), sep(2);
   vector<vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
   vector<vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
   vector<vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
   
   vector<IMethod*> methodsNoCuts; 
   
   vector<IMethod*>::iterator itrMethod    = fMethods.begin();
   vector<IMethod*>::iterator itrMethodEnd = fMethods.end();
   for (; itrMethod != itrMethodEnd; itrMethod++) {
      MethodBase* theMethod = (MethodBase*)*itrMethod;
      if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
      fLogger << kINFO << "Evaluate classifier: " << theMethod->GetMethodTitle() << Endl;
      isel = (theMethod->GetMethodName().Contains("Variable")) ? 1 : 0;
      
      
      theMethod->Test(0);
      
      mname[isel].push_back( theMethod->GetMethodTitle() );
      sig[isel].push_back  ( theMethod->GetSignificance() );
      sep[isel].push_back  ( theMethod->GetSeparation() );
      
      
      Double_t err;
      eff01[isel].push_back( theMethod->GetEfficiency("Efficiency:0.01", Data().GetTestTree(), err) );
      eff01err[isel].push_back( err );
      eff10[isel].push_back( theMethod->GetEfficiency("Efficiency:0.10", Data().GetTestTree(), err) );
      eff10err[isel].push_back( err );
      eff30[isel].push_back( theMethod->GetEfficiency("Efficiency:0.30", Data().GetTestTree(), err) );
      eff30err[isel].push_back( err );
      effArea[isel].push_back( theMethod->GetEfficiency("", Data().GetTestTree(), err)  ); 
      trainEff01[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.01") ); 
      trainEff10[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.10") );
      trainEff30[isel].push_back( theMethod->GetTrainingEfficiency("Efficiency:0.30") );
      nmeth_used[isel]++;
      theMethod->WriteEvaluationHistosToFile();
   }
   
   for (Int_t k=0; k<2; k++) {
      vector< vector<Double_t> > vtemp;
      vtemp.push_back( effArea[k] );
      vtemp.push_back( eff10[k] ); 
      vtemp.push_back( eff01[k] );
      vtemp.push_back( eff30[k] );
      vtemp.push_back( eff10err[k] ); 
      vtemp.push_back( eff01err[k] );
      vtemp.push_back( eff30err[k] );
      vtemp.push_back( trainEff10[k] );
      vtemp.push_back( trainEff01[k] );
      vtemp.push_back( trainEff30[k] );
      vtemp.push_back( sig[k]   );
      vtemp.push_back( sep[k]   );
      vector<TString> vtemps = mname[k];
      gTools().UsefulSortDescending( vtemp, &vtemps );
      effArea[k]    = vtemp[0];
      eff10[k]      = vtemp[1];
      eff01[k]      = vtemp[2];
      eff30[k]      = vtemp[3];
      eff10err[k]   = vtemp[4];
      eff01err[k]   = vtemp[5];
      eff30err[k]   = vtemp[6];
      trainEff10[k] = vtemp[7];
      trainEff01[k] = vtemp[8];
      trainEff30[k] = vtemp[9];
      sig[k]        = vtemp[10];
      sep[k]        = vtemp[11];
      mname[k]      = vtemps;
   }
   
   
   
   
   
   
   const Int_t nvar = methodsNoCuts.size();
   if (nvar > 1) {
      
      Float_t  *fvec = new Float_t[nvar];
      Double_t *dvec = new Double_t[nvar];
      vector<Double_t> rvec;
      Int_t    type;
      
      TPrincipal* tpSig = new TPrincipal( nvar, "" );   
      TPrincipal* tpBkg = new TPrincipal( nvar, "" );   
      
      Int_t ivar = 0;
      vector<TString>* theVars = new vector<TString>;
      Data().GetTestTree()->ResetBranchAddresses();
      for (itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); itrMethod++, ivar++) {
         theVars->push_back( ((MethodBase*)*itrMethod)->GetTestvarName() );
         rvec.push_back( ((MethodBase*)*itrMethod)->GetSignalReferenceCut() );
         Data().GetTestTree()->SetBranchAddress( theVars->back(), &(fvec[ivar]) );
         theVars->back().ReplaceAll( "MVA_", "" );
      }
      Data().GetTestTree()->SetBranchAddress( "type", &type );
      
      TMatrixD* overlapS = new TMatrixD( nvar, nvar );
      TMatrixD* overlapB = new TMatrixD( nvar, nvar );
      (*overlapS) *= 0; 
      (*overlapB) *= 0; 
      
      for (Int_t ievt=0; ievt<Data().GetTestTree()->GetEntries(); ievt++) {
         Data().GetTestTree()->GetEntry(ievt);
         
         TMatrixD* theMat = 0;
         for (Int_t im=0; im<nvar; im++) dvec[im] = (Double_t)fvec[im];
         if (type == 1) { tpSig->AddRow( dvec ); theMat = overlapS; }
         else           { tpBkg->AddRow( dvec ); theMat = overlapB; }
         
         for (Int_t im=0; im<nvar; im++) {
            for (Int_t jm=im; jm<nvar; jm++) {
               if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) { 
                  (*theMat)(im,jm)++; 
                  if (im != jm) (*theMat)(jm,im)++;
               }
            }
         }      
      }
      
      (*overlapS) *= (1.0/Data().GetNEvtSigTest());  
      (*overlapB) *= (1.0/Data().GetNEvtBkgdTest()); 
   
      tpSig->MakePrincipals();
      tpBkg->MakePrincipals();
      const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
      const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
   
      const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
      const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
      
      if (corrMatS != 0 && corrMatB != 0) {
         fLogger << kINFO << Endl;
         fLogger << kINFO << "Inter-MVA correlation matrix (signal):" << Endl;
         gTools().FormattedOutput( *corrMatS, *theVars, fLogger );
         fLogger << kINFO << Endl;
         fLogger << kINFO << "Inter-MVA correlation matrix (background):" << Endl;
         gTools().FormattedOutput( *corrMatB, *theVars, fLogger );
         fLogger << kINFO << Endl;   
      }
      else fLogger << kWARNING << "<TestAllMethods> cannot compute correlation matrices" << Endl;
      
      fLogger << kINFO << "The following \"overlap\" matrices contain the fraction of events for which " << Endl;
      fLogger << kINFO << "the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
      fLogger << kINFO << "An event is signal-like, if its MVA output exceeds the following value:" << Endl;
      gTools().FormattedOutput( rvec, *theVars, "Method" , "Cut value", fLogger );
      fLogger << kINFO << "which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
      
      if (nvar != (Int_t)fMethods.size()) 
         fLogger << kINFO << "Note: no correlations and overlap with cut method are provided at present" << Endl;
      fLogger << kINFO << Endl;
      fLogger << kINFO << "Inter-MVA overlap matrix (signal):" << Endl;
      gTools().FormattedOutput( *overlapS, *theVars, fLogger );
      fLogger << kINFO << Endl;
      
      fLogger << kINFO << "Inter-MVA overlap matrix (background):" << Endl;
      gTools().FormattedOutput( *overlapB, *theVars, fLogger );
      
      delete tpSig;
      delete tpBkg;
      delete corrMatS;
      delete corrMatB;
      delete theVars;
      delete overlapS;
      delete overlapB;
      delete [] fvec;
      delete [] dvec;
   }
   
   
   
   
   fLogger << Endl;
   TString hLine = "-----------------------------------------------------------------------------";
   fLogger << kINFO << "Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
   fLogger << kINFO << hLine << Endl;
   fLogger << kINFO << "MVA              Signal efficiency at bkg eff. (error):  |  Sepa-    Signifi-"   << Endl;
   fLogger << kINFO << "Methods:         @B=0.01    @B=0.10    @B=0.30    Area   |  ration:  cance:  "   << Endl;
   fLogger << kINFO << hLine << Endl;
   for (Int_t k=0; k<2; k++) {
      if (k == 1 && nmeth_used[k] > 0) {
         fLogger << kINFO << hLine << Endl;
         fLogger << kINFO << "Input Variables: " << Endl << hLine << Endl;
      }
      for (Int_t i=0; i<nmeth_used[k]; i++) {
         if (k == 1) mname[k][i].ReplaceAll( "Variable_", "" );
         fLogger << kINFO << Form("%-15s: %1.3f(%02i)  %1.3f(%02i)  %1.3f(%02i)  %1.3f  |  %1.3f    %1.3f",
                                  (const char*)mname[k][i], 
                                  eff01[k][i], Int_t(1000*eff01err[k][i]), 
                                  eff10[k][i], Int_t(1000*eff10err[k][i]), 
                                  eff30[k][i], Int_t(1000*eff30err[k][i]), 
                                  effArea[k][i], 
                                  sep[k][i], sig[k][i]) << Endl;
      }
   }
   fLogger << kINFO << hLine << Endl;
   fLogger << kINFO << Endl;
   fLogger << kINFO << "Testing efficiency compared to training efficiency (overtraining check)" << Endl;
   fLogger << kINFO << hLine << Endl;
   fLogger << kINFO << "MVA           Signal efficiency: from test sample (from traing sample) "   << Endl;
   fLogger << kINFO << "Methods:         @B=0.01             @B=0.10            @B=0.30   "   << Endl;
   fLogger << kINFO << hLine << Endl;
   for (Int_t k=0; k<2; k++) {
      if (k == 1 && nmeth_used[k] > 0) {
         fLogger << kINFO << hLine << Endl;
         fLogger << kINFO << "Input Variables: " << Endl << hLine << Endl;
      }
      for (Int_t i=0; i<nmeth_used[k]; i++) {
         if (k == 1) mname[k][i].ReplaceAll( "Variable_", "" );
         fLogger << kINFO << Form("%-15s: %1.3f (%1.3f)       %1.3f (%1.3f)      %1.3f (%1.3f)",
                                  (const char*)mname[k][i], 
                                  eff01[k][i],trainEff01[k][i], 
                                  eff10[k][i],trainEff10[k][i],
                                  eff30[k][i],trainEff30[k][i]) << Endl;
      }
   }
   fLogger << kINFO << hLine << Endl;
   fLogger << kINFO << Endl;  
   fLogger << kINFO << "Write Test Tree '"<< Data().GetTestTree()->GetName()<<"' to file" << Endl;
   Data().BaseRootDir()->cd();
   Data().GetTestTree()->SetDirectory( Data().BaseRootDir() );
   Data().GetTestTree()->Write("",TObject::kOverwrite);
}
Last change: Sat Nov  1 10:21:34 2008
Last generated: 2008-11-01 10:21
This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.