#include <algorithm>
#include "Riostream.h"
#include "TRandom.h"
#include "TObjString.h"
#include "TDirectory.h"
#include "TTree.h"
#include "TH2.h"
#include "TMVA/MethodCommittee.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "TMVA/Ranking.h"
#include "TMVA/Methods.h"
using std::vector;
ClassImp(TMVA::MethodCommittee)
// Constructor taking job name, committee title, data set and option strings.
// Remembers the member method type and its option string, applies the
// committee defaults, runs the option declaration/parsing/processing chain and
// finally books the monitoring histograms and the monitoring tree.
TMVA::MethodCommittee::MethodCommittee( const TString& jobName, const TString& committeeTitle, DataSet& theData,
                                        const TString& committeeOptions,
                                        Types::EMVA method, const TString& methodOptions,
                                        TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, committeeTitle, theData, committeeOptions, theTargetDir ),
     fMemberType( method ),
     fMemberOption( methodOptions )
{
   // defaults are set before the options are declared and parsed
   InitCommittee();

   const TString configName = TString("Method") + GetMethodName();
   SetConfigName( configName );

   DeclareOptions();
   ParseOptions();
   ProcessOptions();

   // monitoring histograms; fNMembers is read only after the option chain above has run
   fBoostFactorHist = new TH1F( "fBoostFactor", "Ada Boost weights", 100, 1, 100 );
   fErrFractHist    = new TH2F( "fErrFractHist", "error fraction vs tree number",
                                fNMembers, 0, fNMembers, 50, 0, 0.5 );

   // monitoring tree: each branch is bound to the address of a data member
   fMonitorNtuple = new TTree( "fMonitorNtuple", "Committee variables" );
   fMonitorNtuple->Branch( "iTree",         &fITree,         "iTree/I" );
   fMonitorNtuple->Branch( "boostFactor",   &fBoostFactor,   "boostFactor/D" );
   fMonitorNtuple->Branch( "errorFraction", &fErrorFraction, "errorFraction/D" );
}
// Constructor taking a data set and a weight-file name.
// Only applies the committee defaults and declares the options; unlike the
// other constructor it neither parses options nor books the monitoring
// histograms/tree.
TMVA::MethodCommittee::MethodCommittee( DataSet& theData,
                                        const TString& theWeightFile,
                                        TDirectory* theTargetDir )
   : TMVA::MethodBase( theData, theWeightFile, theTargetDir )
{
   InitCommittee();   // default member count, boost type, empty committee
   DeclareOptions();  // register the configurable options
}
// Registers the committee's configurable options:
//   NMembers            number of members in the committee
//   UseMemberDecision   use the members' binary signal/background decision
//   UseWeightedMembers  weight the members by their boost weight
//   BoostType           "AdaBoost" or "Bagging"
// The two boolean options are reset to their defaults here, just before they
// are registered.
void TMVA::MethodCommittee::DeclareOptions()
{
   DeclareOptionRef( fNMembers, "NMembers", "number of members in the committee" );

   fUseMemberDecision = kFALSE;
   DeclareOptionRef( fUseMemberDecision, "UseMemberDecision", "use binary information from IsSignal" );

   fUseWeightedMembers = kTRUE;
   DeclareOptionRef( fUseWeightedMembers, "UseWeightedMembers", "use weighted trees or simple average in classification from the forest" );

   // the predefined values are added right after the BoostType declaration
   DeclareOptionRef( fBoostType, "BoostType", "boosting type" );
   AddPreDefVal( TString("AdaBoost") );
   AddPreDefVal( TString("Bagging") );
}
// Option post-processing: the committee adds nothing of its own and simply
// forwards to the base-class implementation.
void TMVA::MethodCommittee::ProcessOptions()
{
   TMVA::MethodBase::ProcessOptions();
}
void TMVA::MethodCommittee::InitCommittee( void )
{
SetMethodName( "Committee" );
SetMethodType( TMVA::Types::kCommittee );
SetTestvarName();
fNMembers = 100;
fBoostType = "AdaBoost";
fCommittee.clear();
fBoostWeights.clear();
}
// Destructor: deletes every committee member and empties the member list.
TMVA::MethodCommittee::~MethodCommittee( void )
{
   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t idx = 0; idx != nMembers; ++idx) {
      delete fCommittee[idx];
   }
   fCommittee.clear();
}
// Writes the state of the method to the weight file returned by
// GetWeightFileName(). A kFATAL message is logged if the file cannot be
// opened.
void TMVA::MethodCommittee::WriteStateToFile() const
{
   TString fname(GetWeightFileName());
   fLogger << kINFO << "creating weight file: " << fname << Endl;

   // Automatic storage: the stream is flushed and closed when it goes out of
   // scope. A heap-allocated stream that is never deleted is never flushed,
   // so the tail of the weight file could be lost.
   std::ofstream fout( fname );
   if (!fout.good()) {
      fLogger << kFATAL << "<WriteStateToFile> "
              << "unable to open output weight file: " << fname << Endl;
   }

   WriteStateToStream( fout );
}
void TMVA::MethodCommittee::Train( void )
{
if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
fLogger << kINFO << "will train "<< fNMembers << " committee members ... patience please" << Endl;
Timer timer( fNMembers, GetName() );
for (UInt_t imember=0; imember<fNMembers; imember++){
timer.DrawProgressBar( imember );
TMVA::IMethod *method = 0;
switch(fMemberType) {
case TMVA::Types::kCuts:
method = new TMVA::MethodCuts ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kFisher:
method = new TMVA::MethodFisher ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kKNN:
method = new TMVA::MethodKNN ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kMLP:
method = new TMVA::MethodMLP ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kTMlpANN:
method = new TMVA::MethodTMlpANN ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kCFMlpANN:
method = new TMVA::MethodCFMlpANN ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kLikelihood:
method = new TMVA::MethodLikelihood ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kHMatrix:
method = new TMVA::MethodHMatrix ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kPDERS:
method = new TMVA::MethodPDERS ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kBDT:
method = new TMVA::MethodBDT ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kSVM:
method = new TMVA::MethodSVM ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kRuleFit:
method = new TMVA::MethodRuleFit ( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
case TMVA::Types::kBayesClassifier:
method = new TMVA::MethodBayesClassifier( GetJobName(), GetMethodTitle(), Data(), fMemberOption ); break;
default:
fLogger << kFATAL << "method: " << fMemberType << " does not exist" << Endl;
}
method->Train();
GetBoostWeights().push_back( this->Boost( method, imember ) );
GetCommittee().push_back( method );
fMonitorNtuple->Fill();
}
fLogger << kINFO << "elapsed time: " << timer.GetElapsedTime()
<< " " << Endl;
}
// Dispatches to the boosting algorithm selected by fBoostType and returns the
// value that algorithm returns. For an unrecognised boost type the option
// string is printed, a kFATAL message is logged and 1.0 is returned.
Double_t TMVA::MethodCommittee::Boost( TMVA::IMethod* method, UInt_t imember )
{
   if (fBoostType=="AdaBoost") return this->AdaBoost( method );
   if (fBoostType=="Bagging")  return this->Bagging( imember );

   // neither of the known boost types
   fLogger << kINFO << GetOptions() << Endl;
   fLogger << kFATAL << "<Boost> unknown boost option called" << Endl;
   return 1.0;
}
// AdaBoost step for one freshly trained committee member.
//
// 1. Loops over the training events and compares the member's IsSignalLike()
//    decision with the true event type; err is the boost-weight fraction of
//    the misclassified events.
// 2. Multiplies the boost weight of every misclassified event by
//    boostFactor = ((1-err)/err)^adaBoostBeta (1000 if err is not positive).
// 3. Rescales all boost weights so that their sum equals the sum before
//    boosting.
//
// The boost factor and error fraction are stored in the monitoring histograms
// and in fBoostFactor / fErrorFraction.
//
// Returns log(boostFactor), which the caller stores as the member's weight.
Double_t TMVA::MethodCommittee::AdaBoost( TMVA::IMethod* method )
{
   Double_t adaBoostBeta = 1.;   // exponent applied to (1-err)/err

   // should never be called without existing training tree
   if (!HasTrainingTree()) fLogger << kFATAL << "<AdaBoost> Data().TrainingTree() is zero pointer" << Endl;

   Event& event = GetEvent();

   Double_t err=0, sumw=0, sumwfalse=0;
   vector<Bool_t> correctSelected;   // per event: did the member classify it correctly?

   MethodBase* mbase = (MethodBase*)method;

   // step 1: classify all training events and accumulate the weight sums
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);

      sumw += event.GetBoostWeight();

      Bool_t isSignalType = mbase->IsSignalLike();
      if (isSignalType == event.IsSignal()) correctSelected.push_back( kTRUE );
      else {
         sumwfalse += event.GetBoostWeight();
         correctSelected.push_back( kFALSE );
      }
   }

   if (0 == sumw) {
      fLogger << kFATAL << "<AdaBoost> fatal error sum of event boostweights is zero" << Endl;
   }

   // weighted misclassification fraction
   err = sumwfalse/sumw;

   Double_t boostFactor = 1;
   if (err>0){
      if (adaBoostBeta == 1){
         boostFactor = (1-err)/err ;
      }
      else {
         boostFactor = TMath::Power((1-err)/err,adaBoostBeta) ;
      }
   }
   else {
      // no misclassified weight at all: use a fixed large factor
      boostFactor = 1000;
   }

   // step 2: boost the misclassified events
   Double_t newSumw=0;
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);

      if (!correctSelected[ievt]) event.SetBoostWeight( event.GetBoostWeight() * boostFactor);

      newSumw += event.GetBoostWeight();
   }

   // step 3: renormalise so that the sum of boost weights is unchanged.
   // Each event has to be loaded first, exactly as in the two loops above;
   // otherwise the rescaling would be applied over and over to whichever
   // event was read last.
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);
      event.SetBoostWeight( event.GetBoostWeight() * sumw / newSumw );
   }

   // monitoring
   fBoostFactorHist->Fill(boostFactor);
   fErrFractHist->Fill(GetCommittee().size(),err);

   fBoostFactor   = boostFactor;
   fErrorFraction = err;

   return log(boostFactor);
}
// Bagging step: assigns every training event a new boost weight drawn with
// TRandom::Rndm() from a generator seeded with the member index, then
// rescales the weights so that they sum to the number of training events.
//
// Always returns 1.0, i.e. all members get the same weight.
Double_t TMVA::MethodCommittee::Bagging( UInt_t imember )
{
   Double_t newSumw = 0;

   // automatic storage: the generator is released on return (it used to be
   // allocated with new and never deleted)
   TRandom trandom( imember );

   Event& event = GetEvent();

   // draw a random weight for every training event
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);

      Double_t newWeight = trandom.Rndm();
      event.SetBoostWeight( newWeight );
      newSumw += newWeight;
   }

   // renormalise; each event has to be loaded first, as in the loop above,
   // otherwise only the event read last would be rescaled (repeatedly)
   for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
      ReadTrainingEvent(ievt);
      event.SetBoostWeight( event.GetBoostWeight() * Data().GetNEvtTrain() / newSumw );
   }

   return 1.0;
}
// Writes the committee to the stream. For every member, in order:
//   - an empty line,
//   - a separator line carrying the member index,
//   - a "boost weight: <value>" line,
//   - the member's own state, via its WriteStateToStream().
void TMVA::MethodCommittee::WriteWeightsToStream( ostream& o ) const
{
   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t idx = 0; idx < nMembers; ++idx) {
      // per-member header
      o << endl
        << "------------------------------ new member: " << idx << " ---------------" << endl
        << "boost weight: " << GetBoostWeights()[idx] << endl;

      // the member writes its own state
      MethodBase* member = (MethodBase*)GetCommittee()[idx];
      member->WriteStateToStream( o );
   }
}
// Rebuilds the committee from a stream. Any existing members are deleted
// first. Then, for each of the fNMembers members, the per-member header
// (separator line with the member index, followed by the boost-weight line)
// is read, a method of type fMemberType is created, and that method reads its
// own state from the stream.
//
// A kFATAL message is logged if the member index found in the stream does not
// match the expected one, or if fMemberType is not a known method type.
void TMVA::MethodCommittee::ReadWeightsFromStream( istream& istr )
{
   // delete the current committee
   std::vector<IMethod*>::iterator member = GetCommittee().begin();
   for (; member != GetCommittee().end(); member++) delete *member;

   GetCommittee().clear();
   GetBoostWeights().clear();

   TString  dummy;
   UInt_t   imember;
   Double_t boostWeight;

   for (UInt_t i=0; i<fNMembers; i++) {

      // Header as produced by WriteWeightsToStream():
      //    "------------------------------ new member: <i> ---------------"
      //    "boost weight: <w>"
      // The first line ends with a dash token after the index. It has to be
      // consumed as well, otherwise the tokens of the second line are shifted
      // by one and "weight:" would be parsed as the boost weight.
      istr >> dummy >> dummy >> dummy >> imember >> dummy;
      istr >> dummy >> dummy >> boostWeight;

      if (imember != i) {
         fLogger << kFATAL << "<ReadWeightsFromStream> fatal error while reading Weight file \n "
                 << ": mismatch imember: " << imember << " != i: " << i << Endl;
      }

      // create the member method of the configured type
      TMVA::IMethod *method = 0;
      switch(fMemberType) {
      case TMVA::Types::kCuts:
         method = new TMVA::MethodCuts           ( Data(), "" ); break;
      case TMVA::Types::kFisher:
         method = new TMVA::MethodFisher         ( Data(), "" ); break;
      case TMVA::Types::kKNN:
         method = new TMVA::MethodKNN            ( Data(), "" ); break;
      case TMVA::Types::kMLP:
         method = new TMVA::MethodMLP            ( Data(), "" ); break;
      case TMVA::Types::kTMlpANN:
         method = new TMVA::MethodTMlpANN        ( Data(), "" ); break;
      case TMVA::Types::kCFMlpANN:
         method = new TMVA::MethodCFMlpANN       ( Data(), "" ); break;
      case TMVA::Types::kLikelihood:
         method = new TMVA::MethodLikelihood     ( Data(), "" ); break;
      case TMVA::Types::kHMatrix:
         method = new TMVA::MethodHMatrix        ( Data(), "" ); break;
      case TMVA::Types::kPDERS:
         method = new TMVA::MethodPDERS          ( Data(), "" ); break;
      case TMVA::Types::kBDT:
         method = new TMVA::MethodBDT            ( Data(), "" ); break;
      case TMVA::Types::kSVM:
         method = new TMVA::MethodSVM            ( Data(), "" ); break;
      case TMVA::Types::kRuleFit:
         method = new TMVA::MethodRuleFit        ( Data(), "" ); break;
      case TMVA::Types::kBayesClassifier:
         method = new TMVA::MethodBayesClassifier( Data(), "" ); break;
      default:
         fLogger << kFATAL << "<ReadWeightsFromStream> fatal error: method: "
                 << fMemberType << " does not exist" << Endl;
      }

      // the member reads its own state
      ((MethodBase*)method)->ReadStateFromStream(istr);

      GetCommittee().push_back(method);
      GetBoostWeights().push_back(boostWeight);
   }
}
// Committee response: the (optionally boost-weighted) average over all
// members of either
//   - +1 / -1 according to the member's IsSignalLike() (fUseMemberDecision), or
//   - the member's own GetMvaValue().
// Returns -999 when the normalisation is zero (e.g. empty committee).
Double_t TMVA::MethodCommittee::GetMvaValue()
{
   Double_t weightedSum = 0;
   Double_t norm        = 0;

   const UInt_t nMembers = GetCommittee().size();
   for (UInt_t idx = 0; idx < nMembers; ++idx) {

      // response of this member
      Double_t response;
      if (fUseMemberDecision) {
         response = ((MethodBase*)GetCommittee()[idx])->IsSignalLike() ? 1.0 : -1.0;
      }
      else {
         response = GetCommittee()[idx]->GetMvaValue();
      }

      // accumulate, weighted by the boost weight or as a plain average
      if (fUseWeightedMembers) {
         const Double_t memberWeight = GetBoostWeights()[idx];
         weightedSum += memberWeight * response;
         norm        += memberWeight;
      }
      else {
         weightedSum += response;
         norm        += 1;
      }
   }

   if (norm != 0) return weightedSum / norm;
   return -999;
}
// Writes the two monitoring histograms and the monitoring tree, then changes
// into the directory returned by BaseDir().
void TMVA::MethodCommittee::WriteMonitoringHistosToFile( void ) const
{
   fLogger << kINFO << "write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;

   // monitoring objects
   fBoostFactorHist->Write();
   fErrFractHist   ->Write();
   fMonitorNtuple  ->Write();

   // note: the directory change happens after the objects were written
   BaseDir()->cd();
}
// Returns a copy of the variable-importance vector after resizing it to
// GetNvar() entries. No importance values are computed here; entries added by
// the resize are value-initialised.
vector< Double_t > TMVA::MethodCommittee::GetVariableImportance()
{
   const UInt_t nVariables = GetNvar();
   fVariableImportance.resize( nVariables );

   return fVariableImportance;
}
// Returns the importance of variable ivar, taken from the vector returned by
// GetVariableImportance(). For an out-of-range index a kFATAL message is
// logged and -1 is returned.
Double_t TMVA::MethodCommittee::GetVariableImportance(UInt_t ivar)
{
   vector<Double_t> importance = this->GetVariableImportance();

   const UInt_t nEntries = (UInt_t)importance.size();
   if (ivar >= nEntries) {
      fLogger << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }

   return importance[ivar];
}
// Creates a new Ranking object, stores it in fRanking, adds one Rank per input
// variable (input expression plus the value from GetVariableImportance()) and
// returns it.
const TMVA::Ranking* TMVA::MethodCommittee::CreateRanking()
{
   fRanking = new Ranking( GetName(), "Variable Importance" );

   vector< Double_t> importance( this->GetVariableImportance() );

   for (Int_t ivar = 0; ivar < GetNvar(); ++ivar) {
      Rank* rank = new Rank( GetInputExp(ivar), importance[ivar] );
      fRanking->AddRank( *rank );
   }

   return fRanking;
}
// Writes the method-specific part of the standalone response class. Nothing is
// implemented for the committee: only a "not implemented" comment naming the
// class and the closing "};" are emitted.
void TMVA::MethodCommittee::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << endl
        << "};" << endl;
}
// Prints the help message: three bold section headings (short description,
// performance optimisation, performance tuning via configuration options),
// each followed by the placeholder text "<None>".
void TMVA::MethodCommittee::GetHelpMessage() const
{
   // section 1: short description
   fLogger << Endl;
   fLogger << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;

   // section 2: performance optimisation
   fLogger << Endl;
   fLogger << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;

   // section 3: tuning via configuration options
   fLogger << Endl;
   fLogger << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
}
Last change: Sat Nov 1 10:21:43 2008
Last generated: 2008-11-01 10:21
This page has been automatically generated. If you have any comments or suggestions about the page layout, send an e-mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.