#include "TTree.h"
#include "TLeaf.h"
#include "TString.h"
#include "TClass.h"
#include "TH1D.h"
#include "TKey.h"
#include "TVector.h"
#include <stdlib.h>
#include <fstream>
#include <string>
#include "TPluginManager.h"
#include "TMVA/Reader.h"
#include "TMVA/Config.h"
#include "TMVA/Methods.h"
#define TMVA_Reader_TestIO__
#undef TMVA_Reader_TestIO__
ClassImp(TMVA::Reader)
#ifdef _WIN32
#pragma warning ( disable : 4355 )
#endif
// Constructor taking only an option string and a verbosity flag.
// A fresh DataSet is allocated and owned by the reader; no input variables
// are registered here.
// NOTE(review): this overload names its logger with the literal "Reader",
// whereas the other constructors in this file pass `this` - confirm intended.
TMVA::Reader::Reader( TString theOption, Bool_t verbose )
   : Configurable( theOption ),
     fDataSet( new DataSet ),
     fVerbose( verbose ),
     fLogger ( "Reader" )
{
   // the configuration is named after this object before any option handling
   this->SetConfigName( this->GetName() );

   // declare the reader options, then parse the option string
   this->DeclareOptions();
   this->ParseOptions();

   this->Init();
}
// Constructor taking the input variable expressions as a vector of TString.
// Each element of inputVars is registered with the data set, in order,
// after the options have been declared and parsed.
TMVA::Reader::Reader( vector<TString>& inputVars, TString theOption, Bool_t verbose )
   : Configurable( theOption ),
     fDataSet( new DataSet ),
     fVerbose( verbose ),
     fLogger ( this )
{
   this->SetConfigName( this->GetName() );
   this->DeclareOptions();
   this->ParseOptions();

   // register the variables in the order given by the caller
   for (vector<TString>::size_type idx = 0; idx < inputVars.size(); ++idx) {
      Data().AddVariable( inputVars[idx] );
   }

   this->Init();
}
// Constructor taking the input variable expressions as a vector of
// std::string. Each element is registered with the data set, in order,
// through its C-string representation.
TMVA::Reader::Reader( vector<string>& inputVars, TString theOption, Bool_t verbose )
   : Configurable( theOption ),
     fDataSet( new DataSet ),
     fVerbose( verbose ),
     fLogger ( this )
{
   this->SetConfigName( this->GetName() );
   this->DeclareOptions();
   this->ParseOptions();

   // register the variables in the order given by the caller
   for (vector<string>::size_type idx = 0; idx < inputVars.size(); ++idx) {
      Data().AddVariable( inputVars[idx].c_str() );
   }

   this->Init();
}
// Constructor taking all input variable names in a single std::string.
// The string is handed to DecodeVarNames, which splits it and registers
// the individual variables with the data set.
TMVA::Reader::Reader( const string varNames, TString theOption, Bool_t verbose )
   : Configurable( theOption ),
     fDataSet( new DataSet ),
     fVerbose( verbose ),
     fLogger ( this )
{
   this->SetConfigName( this->GetName() );
   this->DeclareOptions();
   this->ParseOptions();

   // split the name list and register every variable found
   DecodeVarNames( varNames );

   this->Init();
}
// Constructor taking all input variable names in a single TString.
// The string is handed to DecodeVarNames, which splits it and registers
// the individual variables with the data set.
TMVA::Reader::Reader( const TString varNames, TString theOption, Bool_t verbose )
   : Configurable( theOption ),
     fDataSet( new DataSet ),
     fVerbose( verbose ),
     fLogger ( this )
{
   this->SetConfigName( this->GetName() );
   this->DeclareOptions();
   this->ParseOptions();

   // split the name list and register every variable found
   DecodeVarNames( varNames );

   this->Init();
}
void TMVA::Reader::DeclareOptions()
{
Bool_t silent = kFALSE;
Bool_t color = kTRUE;
DeclareOptionRef( fVerbose, "V", "verbose flag" );
DeclareOptionRef( color, "Color", "Color flag (default on)" );
DeclareOptionRef( silent, "Silent", "Boolean silent flag (default off)" );
ParseOptions(kFALSE);
if (Verbose()) fLogger.SetMinType( kVERBOSE );
gConfig().SetUseColor( color );
gConfig().SetSilent ( silent );
if (fDataSet!=0) fDataSet->SetVerbose(Verbose());
}
// Destructor: releases the data set and every method stored in the
// tag -> method map.
TMVA::Reader::~Reader()
{
   // deleting a null pointer is a no-op, so no explicit check is required
   delete fDataSet;

   typedef std::map<TString, IMethod*> MethodMap_t;
   for (MethodMap_t::iterator entry = fMethodMap.begin(); entry != fMethodMap.end(); ++entry) {
      delete entry->second;
   }
}
// Initialisation hook called at the end of every constructor.
// No work is performed here.
void TMVA::Reader::Init()
{
   // nothing to do
}
// Registers an input variable whose value is read from a Float_t location.
//   expression : variable expression / name handed to the data set
//   datalink   : address the value is linked to; passed on with type tag 'F'
void TMVA::Reader::AddVariable( const TString& expression, Float_t* datalink )
{
   void* const link = static_cast<void*>( datalink );
   Data().AddVariable( expression, 'F', link );
}
// Registers an input variable whose value is read from an Int_t location.
//   expression : variable expression / name handed to the data set
//   datalink   : address the value is linked to; passed on with type tag 'I'
void TMVA::Reader::AddVariable( const TString& expression, Int_t* datalink )
{
   void* const link = static_cast<void*>( datalink );
   Data().AddVariable( expression, 'I', link );
}
// Books a method under the user-chosen tag `methodTag`, configured from the
// given weight file.
//   methodTag  : key under which the method is stored in the method map
//   weightfile : path of the weight file; method name and title are read
//                from it via GetMethodNameTitle
// Returns the booked method, or 0 if no method object could be created.
TMVA::IMethod* TMVA::Reader::BookMVA( const TString& methodTag, const TString& weightfile )
{
   // at least one input variable must have been registered beforehand
   if (Data().GetNVariables() <= 0) {
      fLogger << kFATAL
              << "<BookMVA>: before booking you must register references to your MVA input "
              << "variables via the call: \"reader->AddVariable( \"myFirstVar\", &muFirstVar );\" " << Endl;
   }

   // a tag may only be used once
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it != fMethodMap.end()) {
      fLogger << kFATAL << "<BookMVA> method tag \"" << methodTag << "\" already exists!" << Endl;
   }

   fLogger << kINFO << "Booking method tag \"" << methodTag << "\"" << Endl;

   TString methodName, methodTitle;
   GetMethodNameTitle(weightfile,methodName,methodTitle);

   // method names not known to Types are treated as plugins
   TMVA::Types::EMVA typeIndex = Types::Instance().GetMethodType(methodName);
   if( typeIndex == TMVA::Types::kMaxMethod ) typeIndex = TMVA::Types::kPlugins;

   MethodBase* method = (MethodBase*)this->BookMVA( typeIndex, weightfile );

   // the type-based overload returns 0 for method types it does not
   // implement; do not dereference or register a null method
   if (method == 0) return 0;

   method->SetMethodTitle(methodTitle);

   fLogger << kINFO << "Read method name  : \"" << method->GetMethodName() << "\"" << Endl;
   fLogger << kINFO << "   - method title : \"" << method->GetMethodTitle() << "\"" << Endl;
   fLogger << kINFO << "Method tag        : \"" << methodTag << "\"" << Endl;

   return fMethodMap[methodTag] = method;
}
// Creates the method object for `methodType`, configured from `weightfile`,
// and lets it read its state from file.
//   methodType : selects which concrete method class is instantiated;
//                kPlugins looks the class up through the ROOT plugin manager
//                using the method name found in the weight file
//   weightfile : path of the weight file passed to the method
// Returns the new method, or 0 if none could be created. The method is not
// inserted into the method map here.
TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType, TString weightfile )
{
   IMethod* method = 0;
   TPluginManager* pluginManager(0);
   TPluginHandler* pluginHandler(0);
   TString methodName, methodTitle;

   switch (methodType) {

   case (TMVA::Types::kCuts):
      method = new TMVA::MethodCuts( Data(), weightfile );
      break;

   case (TMVA::Types::kLikelihood):
      method = new TMVA::MethodLikelihood( Data(), weightfile );
      break;

   case (TMVA::Types::kPDERS):
      method = new TMVA::MethodPDERS( Data(), weightfile );
      break;

   case (TMVA::Types::kKNN):
      method = new TMVA::MethodKNN( Data(), weightfile );
      break;

   case (TMVA::Types::kHMatrix):
      method = new TMVA::MethodHMatrix( Data(), weightfile );
      break;

   case (TMVA::Types::kFisher):
      method = new TMVA::MethodFisher( Data(), weightfile );
      break;

   case (TMVA::Types::kFDA):
      method = new TMVA::MethodFDA( Data(), weightfile );
      break;

   case (TMVA::Types::kMLP):
      method = new TMVA::MethodMLP( Data(), weightfile );
      break;

   case (TMVA::Types::kCFMlpANN):
      method = new TMVA::MethodCFMlpANN( Data(), weightfile );
      break;

   case (TMVA::Types::kTMlpANN):
      method = new TMVA::MethodTMlpANN( Data(), weightfile );
      break;

   case (TMVA::Types::kSVM):
      method = new TMVA::MethodSVM( Data(), weightfile );
      break;

   case (TMVA::Types::kBDT):
      method = new TMVA::MethodBDT( Data(), weightfile );
      break;

   case (TMVA::Types::kRuleFit):
      method = new TMVA::MethodRuleFit( Data(), weightfile );
      break;

   case (TMVA::Types::kBayesClassifier):
      method = new TMVA::MethodBayesClassifier( Data(), weightfile );
      break;

   case (TMVA::Types::kPlugins):
      // the plugin is selected by the method name stored in the weight file
      GetMethodNameTitle(weightfile, methodName, methodTitle);
      fLogger << kINFO << "Searching for plugin for " << methodName << " " << Endl;
      pluginManager = gROOT->GetPluginManager();
      pluginHandler = pluginManager->FindHandler("TMVA@@MethodBase",methodName );
      if (pluginHandler) {
         if (pluginHandler->LoadPlugin() == 0) {
            method = (TMVA::MethodBase*) pluginHandler->ExecPlugin(2, &Data(), &weightfile);
            if(method==0) {
               fLogger << kFATAL << "Couldn't instantiate plugin for " << methodName << "." << Endl;
            } else {
               fLogger << kINFO << "Found plugin for " << methodName << " " << Endl;
            }
         } else {
            fLogger << kFATAL << "Couldn't load any plugin for " << methodName << "." << Endl;
         }
      } else {
         fLogger << kFATAL << "Couldn't find plugin handler for TMVA@@MethodBase and " << methodName << Endl;
      }
      break;

   default:
      fLogger << kFATAL << "Classifier: " << methodType << " not implemented" << Endl;
      return 0;
   }

   // every failing plugin path above leaves `method` at 0; never dereference it
   if (method == 0) return 0;

   MethodBase* booked = (MethodBase*)method;

   // restore the method's state from file
   booked->ReadStateFromFile();

   fLogger << kINFO << "Booked classifier " << booked->GetMethodName()
           << " with title: \"" << booked->GetMethodTitle() << "\"" << Endl;

#ifdef TMVA_Reader_TestIO__
   // I/O test: write the state just read back to "<weightfile>.control"
   std::ofstream tfile( weightfile+".control" );
   booked->WriteStateToStream(tfile);
   tfile.close();
#endif

   return method;
}
// Looks up a booked method by its tag.
// Returns the stored method, or 0 when the tag is not in the method map.
TMVA::IMethod* TMVA::Reader::FindMVA( const TString& methodTag )
{
   std::map<TString, IMethod*>::iterator found = fMethodMap.find( methodTag );
   return (found == fMethodMap.end()) ? 0 : found->second;
}
// Evaluates the method booked under `methodTag` for the values in inputVec.
// Element i of inputVec is written to variable i of the data set's event,
// then the tag-based EvaluateMVA overload is called with `aux`.
Double_t TMVA::Reader::EvaluateMVA( const std::vector<Float_t>& inputVec, const TString& methodTag, Double_t aux )
{
   UInt_t ivar = 0;
   while (ivar < inputVec.size()) {
      Data().GetEvent().SetVal( ivar, inputVec[ivar] );
      ++ivar;
   }

   return EvaluateMVA( methodTag, aux );
}
// Evaluates the method booked under `methodTag` for the values in inputVec.
// Element i of inputVec is converted to Float_t and written to variable i of
// the data set's event, then the tag-based EvaluateMVA overload is called
// with `aux`.
Double_t TMVA::Reader::EvaluateMVA( const std::vector<Double_t>& inputVec, const TString& methodTag, Double_t aux )
{
   UInt_t ivar = 0;
   while (ivar < inputVec.size()) {
      Data().GetEvent().SetVal( ivar, (Float_t)inputVec[ivar] );
      ++ivar;
   }

   return EvaluateMVA( methodTag, aux );
}
// Evaluates the method booked under `methodTag`.
// If the tag is unknown, the available tags are listed and a fatal message
// is issued; otherwise the call is forwarded to the MethodBase* overload.
Double_t TMVA::Reader::EvaluateMVA( const TString& methodTag, Double_t aux )
{
   typedef std::map<TString, IMethod*>::iterator MethodIter_t;

   IMethod* method = 0;
   MethodIter_t found = fMethodMap.find( methodTag );

   if (found != fMethodMap.end()) {
      method = found->second;
   }
   else {
      fLogger << kINFO << "<EvaluateMVA> unknown classifier in map; "
              << "you looked for \"" << methodTag << "\" within available methods: " << Endl;

      // list every tag that has been booked
      for (MethodIter_t entry = fMethodMap.begin(); entry != fMethodMap.end(); ++entry) {
         fLogger << " --> " << entry->first << Endl;
      }

      fLogger << "Check calling string" << kFATAL << Endl;
   }

   return this->EvaluateMVA( (MethodBase*)method, aux );
}
// Evaluates the given method on the current event of the data set.
//   method : method to evaluate
//   aux    : only used for methods of type kCuts, where it is passed to
//            SetTestSignalEfficiency
// Returns the value of method->GetMvaValue().
Double_t TMVA::Reader::EvaluateMVA( MethodBase* method, Double_t aux )
{
   // copy the current variable values into the raw event of the method's
   // variable transformation
   method->GetVarTransform().GetEventRaw().CopyVarValues( Data().GetEvent() );

   // the transformation is applied here for every method type except kLikelihood
   if (method->GetMethodType() != Types::kLikelihood) {
      method->GetVarTransform().ApplyTransformation( Types::kSignal );
   }

   const Bool_t isCuts = (method->GetMethodType() == TMVA::Types::kCuts);
   if (isCuts) {
      ((TMVA::MethodCuts*)method)->SetTestSignalEfficiency( aux );
   }

   return method->GetMvaValue();
}
// Returns GetProba( mvaVal, ap_sig ) of the method booked under `methodTag`.
//   methodTag : tag the method was booked with
//   ap_sig    : passed through to the method's GetProba
//   mvaVal    : MVA value to use; the sentinel -9999999 means "take the
//               method's current GetMvaValue()"
// For an unknown tag the available tags are listed, a fatal message is
// issued, and -1 is returned if execution continues.
Double_t TMVA::Reader::GetProba( const TString& methodTag, Double_t ap_sig, Double_t mvaVal )
{
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // same diagnostic layout as EvaluateMVA: heading first, then the tags
      fLogger << kINFO << "<GetProba> unknown classifier in map; "
              << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it != fMethodMap.end(); ++it) fLogger << " --> " << it->first << Endl;
      fLogger << kFATAL << "Check calling string" << Endl;

      // only reached if the fatal message does not end execution;
      // there is no method to query, so do not dereference a null pointer
      return -1;
   }

   MethodBase* kl = (MethodBase*)it->second;
   if (kl == 0) return -1;

   if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();

   return kl->GetProba( mvaVal, ap_sig );
}
// Returns GetRarity( mvaVal ) of the method booked under `methodTag`.
//   methodTag : tag the method was booked with
//   mvaVal    : MVA value to use; the sentinel -9999999 means "take the
//               method's current GetMvaValue()"
// For an unknown tag the available tags are listed, a fatal message is
// issued, and -1 is returned if execution continues.
Double_t TMVA::Reader::GetRarity( const TString& methodTag, Double_t mvaVal )
{
   std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
   if (it == fMethodMap.end()) {
      // same diagnostic layout as EvaluateMVA: heading first, then the tags
      fLogger << kINFO << "<GetRarity> unknown classifier in map; "
              << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
      for (it = fMethodMap.begin(); it != fMethodMap.end(); ++it) fLogger << " --> " << it->first << Endl;
      fLogger << kFATAL << "Check calling string" << Endl;

      // only reached if the fatal message does not end execution;
      // there is no method to query, so do not dereference a null pointer
      return -1;
   }

   MethodBase* kl = (MethodBase*)it->second;
   if (kl == 0) return -1;

   if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();

   return kl->GetRarity( mvaVal );
}
void TMVA::Reader::DecodeVarNames( const string varNames )
{
size_t ipos = 0, f = 0;
while (f != varNames.length()) {
f = varNames.find( ':', ipos );
if (f > varNames.length()) f = varNames.length();
string subs = varNames.substr( ipos, f-ipos ); ipos = f+1;
Data().AddVariable( subs.c_str() );
}
}
void TMVA::Reader::DecodeVarNames( const TString varNames )
{
TString format;
Int_t n = varNames.Length();
TString format_obj;
for (int i=0; i< n+1 ; i++) {
format.Append(varNames(i));
if ( (varNames(i)==':') || (i==n)) {
format.Chop();
format_obj = format;
format_obj.ReplaceAll("@","");
Data().AddVariable( format_obj );
format.Resize(0);
}
}
}
void TMVA::Reader::GetMethodNameTitle(const TString& weightfile, TString& methodName, TString& methodTitle) {
ifstream fin( weightfile );
if (!fin.good()) {
fLogger << kFATAL << "<BookMVA> fatal error: "
<< "unable to open input weight file: " << weightfile << Endl;
}
char buf[512];
fin.getline(buf,512);
while (!TString(buf).BeginsWith("Method")) fin.getline(buf,512);
TString lstr(buf);
Int_t idx1 = lstr.First(':')+2; Int_t idx2 = lstr.Index(' ',idx1)-idx1; if (idx2<0) idx2=lstr.Length();
fin.close();
TString fullname = lstr(idx1,idx2);
idx1 = fullname.First(':');
Int_t idxtit = (idx1<0 ? fullname.Length() : idx1);
methodName = fullname(0, idxtit);
Bool_t notit;
if (idx1<0) {
methodTitle=methodName;
notit=kTRUE;
}
else {
methodTitle=fullname(idxtit+2,fullname.Length()-1);
notit=kFALSE;
}
}
Last change: Sat Nov 1 10:21:59 2008
Last generated: 2008-11-01 10:21
This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.