#include "Riostream.h"
#include "TMath.h"
#include "TVectorD.h"
#include "TH1.h"
#include "TH2.h"
#include "TProfile.h"
#include "TMVA/VariableTransformBase.h"
#include "TMVA/Ranking.h"
#include "TMVA/Config.h"
#include "TMVA/Tools.h"
// ROOT ClassImp macro for TMVA::VariableTransformBase (presumably paired with a
// ClassDef in the class header -- confirm there).
ClassImp(TMVA::VariableTransformBase)
TMVA::VariableTransformBase::VariableTransformBase( std::vector<VariableInfo>& varinfo,
                                                    Types::EVariableTransform tf )
   : TObject(),
     fEvent            ( 0 ),
     fEventRaw         ( 0 ),
     fVariableTransform( tf ),
     fEnabled          ( kTRUE ),
     fCreated          ( kFALSE ),
     fNormalise        ( kFALSE ),
     fTransformName    ( "TransBase" ),
     fVariables        ( varinfo ),
     fCurrentTree      ( 0 ),
     fCurrentEvtIdx    ( 0 ),
     fOutputBaseDir    ( 0 ),
     fRanking          ( 0 ),
     fLogger           ( GetName() )
{
   // Constructor: initialise the transformation of type 'tf' from the variable
   // definitions 'varinfo'. No events, tree, output directory or ranking are
   // attached yet; the transformation starts enabled but not created.
   // Finally the min/max range of every variable is reset via ResetMinMax().
   for (std::vector<VariableInfo>::size_type ivar = 0; ivar < fVariables.size(); ++ivar) {
      fVariables[ivar].ResetMinMax();
   }
}
TMVA::VariableTransformBase::~VariableTransformBase()
{
   // Destructor: release the transformed event, the raw event and the ranking.
   // fEvent may point to the same object as fEventRaw; in that case it is
   // deleted only once, through fEventRaw.
   const Bool_t separateEvent = (fEvent != 0) && (fEvent != fEventRaw);
   if (separateEvent) {
      delete fEvent;
      fEvent = 0;
   }
   if (fEventRaw) {
      delete fEventRaw;
      fEventRaw = 0;
   }
   delete fRanking; // deleting a null pointer is a no-op
}
void TMVA::VariableTransformBase::ResetBranchAddresses( TTree* tree ) const
{
   // Forget the cached current tree, detach all branches of 'tree' from their
   // buffers, and re-attach them to the event returned by GetEventRaw().
   // 'tree' must not be null.
   fCurrentTree = 0;
   tree->ResetBranchAddresses();
   GetEventRaw().SetBranchAddresses( tree );
}
void TMVA::VariableTransformBase::CreateEvent() const
{
   // Allocate the (transformed) event from the variable definitions, with
   // external links disallowed (second constructor argument).
   // NOTE(review): a previously allocated fEvent is not released here.
   fEvent = new Event( fVariables, /* allowExternalLinks = */ kFALSE );
}
Bool_t TMVA::VariableTransformBase::ReadEvent( TTree* tr, UInt_t evidx, Types::ESBType type ) const
{
// Read entry 'evidx' of tree 'tr' into the raw event (through the branches
// returned by fEventRaw->Branches()) and apply this transformation to it with
// ApplyTransformation(type). If 'type' is Types::kTrueType, the class is taken
// from the raw event itself (kSignal if IsSignal(), otherwise kBackground).
// Returns kTRUE on every path visible here.
//
// NOTE(review): in the code visible in this file the tree/entry cache below is
// effectively disabled. ResetBranchAddresses() sets fCurrentTree = 0, and it is
// called right after "fCurrentTree = tr", so fCurrentTree is 0 again when this
// function returns for a non-null tree. Consequently the "tree changed" branch
// is taken on every call, "fCurrentTree->ResetBranchAddresses()" never runs,
// and the early "return kTRUE" is never reached: branch addresses are re-set
// and the entry is re-read on each call. Do not simply swap those statements to
// "fix" the cache: PlotVariables ends with theTree->ResetBranchAddresses()
// without clearing fCurrentTree, and 'type' (which ApplyTransformation uses) is
// not part of the cache key, so a working cache could hand back stale data.
if (tr == 0) fLogger << kFATAL << "<ReadEvent> zero Tree Pointer encountered" << Endl;
Bool_t needRead = kFALSE;
// first call: obtain the raw event (GetEventRaw() is called for its side effect,
// presumably creating fEventRaw) and attach the tree's branches to it
if (fEventRaw == 0) {
needRead = kTRUE;
GetEventRaw();
ResetBranchAddresses( tr );
}
// tree switch: detach the previously cached tree and attach the new one
if (tr != fCurrentTree) {
needRead = kTRUE;
if (fCurrentTree!=0) fCurrentTree->ResetBranchAddresses();
fCurrentTree = tr;
ResetBranchAddresses( tr );
}
// different entry requested
if (evidx != fCurrentEvtIdx) {
needRead = kTRUE;
fCurrentEvtIdx = evidx;
}
if (!needRead) return kTRUE;
// load entry 'evidx' of every branch attached to the raw event
std::vector<TBranch*>::iterator brIt = fEventRaw->Branches().begin();
for (;brIt!=fEventRaw->Branches().end(); brIt++) (*brIt)->GetEntry(evidx);
if (type == Types::kTrueType ) type = fEventRaw->IsSignal() ? Types::kSignal : Types::kBackground;
ApplyTransformation(type);
return kTRUE;
}
void TMVA::VariableTransformBase::UpdateNorm ( Int_t ivar, Double_t x )
{
   // Widen the stored [min, max] range of variable 'ivar' so that it contains x.
   // The two bounds are checked independently of each other.
   VariableInfo& var = fVariables[ivar];
   if (var.GetMin() > x) var.SetMin( x );
   if (var.GetMax() < x) var.SetMax( x );
}
void TMVA::VariableTransformBase::CalcNorm( TTree* tr )
{
   // Loop over all entries of 'tr' and, for every (transformed) variable,
   //  - widen its [min, max] range (UpdateNorm),
   //  - store its weighted mean and RMS in fVariables.
   // Does nothing if the transformation has not been created or 'tr' is null.
   // If the total event weight is zero (e.g. empty tree), mean and RMS are
   // undefined; they are then left unchanged and a warning is issued instead
   // of storing NaN.
   if (!IsCreated()) return;
   if (tr == 0) return;

   ResetBranchAddresses( tr );

   const UInt_t nvar  = GetNVariables();
   const UInt_t nevts = (UInt_t)tr->GetEntries();

   // weighted sums of x and x^2 per variable
   TVectorD x2( nvar ); x2 *= 0;
   TVectorD x0( nvar ); x0 *= 0;
   Double_t sumOfWeights = 0;

   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      ReadEvent( tr, ievt, Types::kSignal );
      Double_t weight = GetEvent().GetWeight();
      sumOfWeights += weight;
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t x = GetEvent().GetVal(ivar);
         UpdateNorm( ivar, x );
         x0(ivar) += x*weight;
         x2(ivar) += x*x*weight;
      }
   }

   if (sumOfWeights == 0) {
      fLogger << kWARNING << "<CalcNorm> Sum of event weights in tree '" << tr->GetName()
              << "' is zero ==> mean and RMS of the variables are not updated" << Endl;
   }
   else {
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t mean     = x0(ivar)/sumOfWeights;
         Double_t variance = x2(ivar)/sumOfWeights - mean*mean;
         // E[x^2] - E[x]^2 can come out slightly negative through round-off
         if (variance < 0) variance = 0;
         fVariables[ivar].SetMean( mean );
         fVariables[ivar].SetRMS( TMath::Sqrt( variance ) );
      }
   }

   fLogger << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
   fLogger << setprecision(3);
   for (UInt_t ivar=0; ivar<nvar; ivar++)
      fLogger << "    " << fVariables[ivar].GetInternalVarName()
              << "\t: [" << fVariables[ivar].GetMin() << "\t, " << fVariables[ivar].GetMax() << "\t] " << Endl;
   fLogger << setprecision(5);
}
void TMVA::VariableTransformBase::PlotVariables( TTree* theTree )
{
   // Create control plots of the (transformed) input variables in 'theTree':
   //  - 1D distributions of every variable, separately for signal and background,
   //  - 2D scatter plots and profiles for every pair of variables, only if the
   //    number of variables does not exceed
   //    gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots.
   // The variables are ranked by the separation of their signal and background
   // distributions (result kept in fRanking). All plots are written to the
   // directory "InputVariables_<name>" below the output base directory, and the
   // scatter/profile plots to its subdirectory "CorrelationPlots".
   // Does nothing if the transformation has not been created or 'theTree' is null.
   if (!IsCreated()) return;
   if (theTree == 0) return;

   ResetBranchAddresses( theTree );

   fLogger << kVERBOSE << "Plot input variables from '" << theTree->GetName() << "'" << Endl;

   TString transfType = "_"; transfType += GetName();

   const UInt_t nvar  = GetNVariables();
   const UInt_t nevts = (UInt_t)theTree->GetEntries();

   // scatter and profile plots are produced only up to this number of variables
   const Bool_t doScatterPlots =
      (nvar <= (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots);

   // ---- first pass: weighted sums per class, used to choose the 1D histogram ranges
   TVectorD x2S( nvar ); x2S *= 0;
   TVectorD x2B( nvar ); x2B *= 0;
   TVectorD x0S( nvar ); x0S *= 0;
   TVectorD x0B( nvar ); x0B *= 0;
   TVectorD rmsS( nvar ), meanS( nvar );
   TVectorD rmsB( nvar ), meanB( nvar );

   Double_t nS = 0, nB = 0;
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      ReadEvent( theTree, ievt, Types::kSignal );
      Double_t weight = GetEvent().GetWeight();
      if (GetEvent().IsSignal()) nS += weight;
      else                       nB += weight;
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t x = GetEvent().GetVal(ivar);
         if (GetEvent().IsSignal()) {
            x0S(ivar) += x*weight;
            x2S(ivar) += x*x*weight;
         }
         else {
            x0B(ivar) += x*weight;
            x2B(ivar) += x*x*weight;
         }
      }
   }

   // A class with zero total weight has undefined mean and RMS (0/0 = NaN).
   // NaN would slip past the "<= 0" checks below and, through TMath::Min/Max,
   // turn the histogram ranges into NaN.
   const Bool_t hasS = (nS != 0);
   const Bool_t hasB = (nB != 0);
   if (!hasS) fLogger << kWARNING << "<PlotVariables> Sum of signal event weights in tree '"
                      << theTree->GetName() << "' is zero" << Endl;
   if (!hasB) fLogger << kWARNING << "<PlotVariables> Sum of background event weights in tree '"
                      << theTree->GetName() << "' is zero" << Endl;

   for (UInt_t ivar=0; ivar<nvar; ivar++) {
      meanS(ivar) = 0; rmsS(ivar) = 0;
      meanB(ivar) = 0; rmsB(ivar) = 0;
      if (hasS) {
         meanS(ivar) = x0S(ivar)/nS;
         rmsS(ivar)  = x2S(ivar)/nS - x0S(ivar)*x0S(ivar)/nS/nS;
         if (rmsS(ivar) <= 0) {
            fLogger << kWARNING << "Variable \"" << Variable(ivar).GetExpression()
                    << "\" has zero or negative RMS^2 for signal "
                    << "==> set to zero. Please check the variable content" << Endl;
            rmsS(ivar) = 0;
         }
         rmsS(ivar) = TMath::Sqrt( rmsS(ivar) );
      }
      if (hasB) {
         meanB(ivar) = x0B(ivar)/nB;
         rmsB(ivar)  = x2B(ivar)/nB - x0B(ivar)*x0B(ivar)/nB/nB;
         if (rmsB(ivar) <= 0) {
            fLogger << kWARNING << "Variable \"" << Variable(ivar).GetExpression()
                    << "\" has zero or negative RMS^2 for background "
                    << "==> set to zero. Please check the variable content" << Endl;
            rmsB(ivar) = 0;
         }
         rmsB(ivar) = TMath::Sqrt( rmsB(ivar) );
      }
      // a class without events borrows the statistics of the other class, so
      // that the histogram range is determined by the available data only
      if (!hasS && hasB) { meanS(ivar) = meanB(ivar); rmsS(ivar) = rmsB(ivar); }
      if (!hasB && hasS) { meanB(ivar) = meanS(ivar); rmsB(ivar) = rmsS(ivar); }
   }

   // ---- book histograms
   std::vector<TH1F*> vS ( nvar );
   std::vector<TH1F*> vB ( nvar );
   std::vector<std::vector<TH2F*> >     mycorrS( nvar );
   std::vector<std::vector<TH2F*> >     mycorrB( nvar );
   std::vector<std::vector<TProfile*> > myprofS( nvar );
   std::vector<std::vector<TProfile*> > myprofB( nvar );
   for (UInt_t ivar=0; ivar < nvar; ivar++) {
      mycorrS[ivar].resize(nvar);
      mycorrB[ivar].resize(nvar);
      myprofS[ivar].resize(nvar);
      myprofB[ivar].resize(nvar);
   }

   if (!doScatterPlots) {
      Int_t nhists = nvar*(nvar - 1)/2;
      fLogger << kINFO << gTools().Color("dgreen") << Endl;
      fLogger << kINFO << "<PlotVariables> Will not produce scatter plots ==> " << Endl;
      fLogger << kINFO
              << "|  The number of " << nvar << " input variables would require " << nhists << " two-dimensional" << Endl;
      fLogger << kINFO
              << "|  histograms, which would flood the computer's memory. Note that this" << Endl;
      fLogger << kINFO
              << "|  suppression does not have any consequences for your analysis, other" << Endl;
      fLogger << kINFO
              << "|  than not disposing of these scatter plots. You can modify the maximum" << Endl;
      fLogger << kINFO
              << "|  number of input variables allowed to generate scatter plots in your" << Endl;
      fLogger << kINFO << "|  script via the command line:" << Endl;
      fLogger << kINFO
              << "|  \"(TMVA::gConfig().GetVariablePlotting()).fMaxNumOfAllowedVariablesForScatterPlots = <some int>;\""
              << gTools().Color("reset") << Endl;
      fLogger << Endl;
      fLogger << kINFO << "Some more output" << Endl;
   }

   Float_t timesRMS = gConfig().GetVariablePlotting().fTimesRMS;
   UInt_t  nbins1D  = gConfig().GetVariablePlotting().fNbins1D;
   UInt_t  nbins2D  = gConfig().GetVariablePlotting().fNbins2D;

   for (UInt_t i=0; i<nvar; i++) {
      TString myVari = Variable(i).GetInternalVarName();

      if (Variable(i).GetVarType() == 'I') {
         // integer variable: one bin per integer value
         Int_t xmin = TMath::Nint( Variable(i).GetMin() );
         Int_t xmax = TMath::Nint( Variable(i).GetMax() + 1 );
         Int_t nbins = xmax - xmin;
         vS[i] = new TH1F( Form("%s__S%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins, xmin, xmax );
         vB[i] = new TH1F( Form("%s__B%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins, xmin, xmax );
      }
      else {
         // range: mean +- timesRMS*RMS of both classes, limited to the observed [min, max]
         Double_t xmin = TMath::Max( Variable(i).GetMin(), TMath::Min( meanS(i) - timesRMS*rmsS(i), meanB(i) - timesRMS*rmsB(i) ) );
         Double_t xmax = TMath::Min( Variable(i).GetMax(), TMath::Max( meanS(i) + timesRMS*rmsS(i), meanB(i) + timesRMS*rmsB(i) ) );
         if (xmin >= xmax) xmax = xmin*1.1;
         if (xmin >= xmax) xmax = xmin + 1;
         // add one extra bin on the upper side
         xmax += (xmax - xmin)/nbins1D;
         vS[i] = new TH1F( Form("%s__S%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins1D, xmin, xmax );
         vB[i] = new TH1F( Form("%s__B%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins1D, xmin, xmax );
      }

      vS[i]->SetXTitle(Variable(i).GetExpression());
      vB[i]->SetXTitle(Variable(i).GetExpression());
      vS[i]->SetLineColor(4);
      vB[i]->SetLineColor(2);

      if (doScatterPlots) {
         for (UInt_t j=i+1; j<nvar; j++) {
            TString myVarj = Variable(j).GetInternalVarName();

            mycorrS[i][j] = new TH2F( Form( "scat_%s_vs_%s_sig%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                      Form( "%s versus %s (signal)%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                      nbins2D, Variable(i).GetMin(), Variable(i).GetMax(),
                                      nbins2D, Variable(j).GetMin(), Variable(j).GetMax() );
            mycorrS[i][j]->SetXTitle(Variable(i).GetExpression());
            mycorrS[i][j]->SetYTitle(Variable(j).GetExpression());

            mycorrB[i][j] = new TH2F( Form( "scat_%s_vs_%s_bgd%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                      Form( "%s versus %s (background)%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                      nbins2D, Variable(i).GetMin(), Variable(i).GetMax(),
                                      nbins2D, Variable(j).GetMin(), Variable(j).GetMax() );
            mycorrB[i][j]->SetXTitle(Variable(i).GetExpression());
            mycorrB[i][j]->SetYTitle(Variable(j).GetExpression());

            myprofS[i][j] = new TProfile( Form( "prof_%s_vs_%s_sig%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                          Form( "profile %s versus %s (signal)%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                          nbins1D, Variable(i).GetMin(), Variable(i).GetMax() );
            myprofB[i][j] = new TProfile( Form( "prof_%s_vs_%s_bgd%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                          Form( "profile %s versus %s (background)%s", myVarj.Data(), myVari.Data(), transfType.Data() ),
                                          nbins1D, Variable(i).GetMin(), Variable(i).GetMax() );
         }
      }
   }

   // ---- second pass: fill the histograms
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      ReadEvent( theTree, ievt, Types::kSignal );
      Float_t weight = GetEvent().GetWeight();

      for (UInt_t i=0; i<nvar; i++) {
         Float_t vali = GetEvent().GetVal(i);

         if (GetEvent().IsSignal()) vS[i]->Fill( vali, weight );
         else                       vB[i]->Fill( vali, weight );

         if (doScatterPlots) {
            for (UInt_t j=i+1; j<nvar; j++) {
               Float_t valj = GetEvent().GetVal(j);
               if (GetEvent().IsSignal()) {
                  mycorrS[i][j]->Fill( vali, valj, weight );
                  myprofS[i][j]->Fill( vali, valj, weight );
               }
               else {
                  mycorrB[i][j]->Fill( vali, valj, weight );
                  myprofB[i][j]->Fill( vali, valj, weight );
               }
            }
         }
      }
   }

   // ---- rank the variables by signal/background separation
   if (fRanking) delete fRanking;
   fRanking = new Ranking( GetName(), "Separation" );
   for (UInt_t i=0; i<nvar; i++) {
      Double_t sep = gTools().GetSeparation( *vS[i], *vB[i] );
      fRanking->AddRank( Rank( vS[i]->GetTitle(), sep ) );
   }

   // ---- write the plots and release them
   TString outputDir = TString("InputVariables_") + GetName();
   TObject* o = GetOutputBaseDir()->FindObject(outputDir);
   if (o != 0) {
      fLogger << kFATAL << "A " << o->ClassName() << " already exists in "
              << GetOutputBaseDir()->GetPath() << Endl;
   }

   TDirectory* localDir = GetOutputBaseDir()->mkdir( outputDir );
   localDir->cd();
   fLogger << kVERBOSE << "Create and switch to directory " << localDir->GetPath() << Endl;
   for (UInt_t i=0; i<nvar; i++) {
      vS[i]->Write();
      vB[i]->Write();
      vS[i]->SetDirectory(0);
      vB[i]->SetDirectory(0);
      delete vS[i];
      delete vB[i];
   }

   if (doScatterPlots) {
      localDir = localDir->mkdir( "CorrelationPlots" );
      localDir->cd();
      fLogger << kINFO << "Create scatter and profile plots in target-file directory: " << Endl;
      fLogger << kINFO << localDir->GetPath() << Endl;

      for (UInt_t i=0; i<nvar; i++) {
         for (UInt_t j=i+1; j<nvar; j++) {
            mycorrS[i][j]->Write();
            mycorrB[i][j]->Write();
            myprofS[i][j]->Write();
            myprofB[i][j]->Write();
            mycorrS[i][j]->SetDirectory(0);
            mycorrB[i][j]->SetDirectory(0);
            myprofS[i][j]->SetDirectory(0);
            myprofB[i][j]->SetDirectory(0);
            delete mycorrS[i][j];
            delete mycorrB[i][j];
            delete myprofS[i][j];
            delete myprofB[i][j];
         }
      }
   }

   GetOutputBaseDir()->cd();
   theTree->ResetBranchAddresses();
}
void TMVA::VariableTransformBase::PrintVariableRanking() const
{
fLogger << kINFO << "Ranking input variables..." << Endl;
fRanking->Print();
}
void TMVA::VariableTransformBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
{
o << prefix << "NVar " << GetNVariables() << endl;
std::vector<VariableInfo>::const_iterator varIt = fVariables.begin();
for (; varIt!=fVariables.end(); varIt++) { o << prefix; varIt->WriteToStream(o); }
}
void TMVA::VariableTransformBase::ReadVarsFromStream( std::istream& istr )
{
TString dummy;
UInt_t readNVar;
istr >> dummy >> readNVar;
if (readNVar!=fVariables.size()) {
fLogger << kFATAL << "You declared "<< fVariables.size() << " variables in the Reader"
<< " while there are " << readNVar << " variables declared in the file"
<< Endl;
}
VariableInfo varInfo;
std::vector<VariableInfo>::iterator varIt = fVariables.begin();
int varIdx = 0;
for (; varIt!=fVariables.end(); varIt++, varIdx++) {
varInfo.ReadFromStream(istr);
if (varIt->GetExpression() == varInfo.GetExpression()) {
varInfo.SetExternalLink((*varIt).GetExternalLink());
(*varIt) = varInfo;
}
else {
fLogger << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
fLogger << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
fLogger << kINFO << "the correct working of the classifier):" << Endl;
fLogger << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
fLogger << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
fLogger << kFATAL << "The expression declared to the Reader needs to be checked (name and/or order is wrong)" << Endl;
}
}
}
Last change: Sat Nov 1 10:22:02 2008
Last generated: 2008-11-01 10:22
This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.