#include <string>
#include <stdlib.h>
#include "TFile.h"
#include "TMath.h"
#include "TTree.h"
#include "TMVA/MethodKNN.h"
#include "TMVA/Tools.h"
#include "TMVA/Ranking.h"
ClassImp(TMVA::MethodKNN)
using std::endl;
////////////////////////////////////////////////////////////////////////////////
/// Standard constructor used for training: creates the kNN module and then
/// declares, parses and processes the configuration options, in that order.
/// @param jobName      name of the training job
/// @param methodTitle  title of this method instance
/// @param theData      dataset the method is trained on
/// @param theOption    option string with the method configuration
/// @param theTargetDir target directory for method output
TMVA::MethodKNN::MethodKNN(const TString& jobName,
const TString& methodTitle,
DataSet& theData,
const TString& theOption,
TDirectory* theTargetDir)
:TMVA::MethodBase(jobName, methodTitle, theData, theOption, theTargetDir),
fModule(0)
{
// order matters: options must be declared before they can be parsed
InitKNN();
SetConfigName( TString("Method") + GetMethodName() );
DeclareOptions();
ParseOptions();
ProcessOptions();
}
////////////////////////////////////////////////////////////////////////////////
/// Constructor used when the method is instantiated from an existing weight
/// file (application phase): options are declared but not parsed here.
/// @param theData       dataset the method is applied to
/// @param theWeightFile weight file to read the trained state from
/// @param theTargetDir  target directory for method output
TMVA::MethodKNN::MethodKNN(DataSet& theData,
const TString& theWeightFile,
TDirectory* theTargetDir)
:TMVA::MethodBase(theData, theWeightFile, theTargetDir),
fModule(0)
{
InitKNN();
DeclareOptions();
}
////////////////////////////////////////////////////////////////////////////////
/// Destructor: release the kNN module owned by this method.
TMVA::MethodKNN::~MethodKNN()
{
   // deleting a null pointer is a well-defined no-op, so no guard is needed
   delete fModule;
}
////////////////////////////////////////////////////////////////////////////////
/// Declare the configuration options understood by the kNN method, together
/// with their default values. Validation happens later in ProcessOptions().
void TMVA::MethodKNN::DeclareOptions()
{
// number of nearest neighbours used for the classification
DeclareOptionRef(fnkNN = 40, "nkNN", "Number of k-nearest neighbors");
// depth to which the binary (kd-)tree is optimised when it is filled
DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
// fraction of events used to compute the variable scaling (metric)
DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used for scaling");
// weight neighbours by a polynomial kernel of their distance instead of uniformly
DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
// use equal numbers of signal and background events when filling the tree
DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
}
////////////////////////////////////////////////////////////////////////////////
/// Validate the user-supplied options, warning about and correcting any
/// out-of-range values, then log the final configuration at verbose level.
/// Fix: the upper clamp of ScaleFrac was previously silent, unlike the lower
/// clamp — both now emit a warning for consistency.
void TMVA::MethodKNN::ProcessOptions()
{
   MethodBase::ProcessOptions();

   // the number of neighbours must be a positive integer
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      fLogger << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   // the scale fraction is clamped to the valid range [0, 1]
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      fLogger << kWARNING << "ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
      fLogger << kWARNING << "ScaleFrac can not exceed 1: set ScaleFrac = " << fScaleFrac << Endl;
   }
   // the tree optimisation depth must be a positive integer
   if (!(fTreeOptDepth > 0)) {
      fTreeOptDepth = 6;
      fLogger << kWARNING << "Optimize must be a positive integer: set Optimize = " << fTreeOptDepth << Endl;
   }

   // summary of the final (validated) option values
   fLogger << kVERBOSE
           << "kNN options: " << Endl
           << "  kNN = " << fnkNN << Endl
           << "  UseKernel = " << fUseKernel << Endl
           << "  ScaleFrac = " << fScaleFrac << Endl
           << "  Trim = " << fTrim << Endl
           << "  Optimize = " << fTreeOptDepth << Endl;
}
////////////////////////////////////////////////////////////////////////////////
/// Common initialisation called from both constructors: set the method
/// identity and create the kNN module that owns the kd-tree.
void TMVA::MethodKNN::InitKNN()
{
SetMethodName("KNN");
SetMethodType(TMVA::Types::kKNN);
SetTestvarName();
// fModule is owned by this object and deleted in the destructor
fModule = new kNN::ModulekNN();
// running weight sums, filled during Train()
fSumOfWeightsS = 0;
fSumOfWeightsB = 0;
}
////////////////////////////////////////////////////////////////////////////////
/// (Re)build the kd-tree from the events currently stored in fEvent.
/// The accumulated option string is handed to ModulekNN::Fill; "metric"
/// requests variable scaling and "trim" requests signal/background trimming.
void TMVA::MethodKNN::MakeKNN()
{
if (!fModule) fLogger << kFATAL << "ModulekNN is not created" << Endl;
fModule->Clear();
// options are accumulated as substrings of one string
// (presumably matched via find() inside ModulekNN::Fill — TODO confirm)
std::string option;
if (fScaleFrac > 0.0) {
option += "metric";
}
if (fTrim) {
option += "trim";
}
fLogger << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;
for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
fModule->Add(*event);
}
// the scale fraction is passed as an integer percentage
fModule->Fill(static_cast<UInt_t>(fTreeOptDepth),
static_cast<UInt_t>(100.0*fScaleFrac),
option);
}
////////////////////////////////////////////////////////////////////////////////
/// Polynomial kernel weight (1 - |x|^3)^3, which is zero for |x| >= 1.
/// @param value normalised distance of a neighbour
/// @return kernel weight in [0, 1]
Double_t TMVA::MethodKNN::PolKernel(const Double_t value) const
{
   const Double_t dist = TMath::Abs(value);

   // outside the unit interval the kernel vanishes
   // (negated comparison also catches non-finite input)
   if (!(dist < 1.0)) return 0.0;

   const Double_t base = 1.0 - dist*dist*dist;
   return base*base*base;
}
////////////////////////////////////////////////////////////////////////////////
/// Read all training events from the DataSet into the local event store,
/// accumulating the signal and background weight sums, then build the
/// kd-tree used for the nearest-neighbour search.
void TMVA::MethodKNN::Train()
{
fLogger << kINFO << "<Train> start..." << Endl;
if (!CheckSanity()) {
fLogger << kFATAL << "Sanity check failed" << Endl;
}
// normalised inputs make the additional variable scaling redundant
if (IsNormalised()) {
fLogger << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
fScaleFrac = 0.0;
}
// discard events left over from any previous training
if (!fEvent.empty()) {
fLogger << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
fEvent.clear();
}
// the variable count is fixed by the first event; all events must agree
Int_t nvar = -1;
fLogger << kINFO << "Reading " << Data().GetNEvtTrain() << " events" << Endl;
for (Int_t ievt = 0; ievt < Data().GetNEvtTrain(); ++ievt) {
// loads the event into the MethodBase accessors used below
ReadTrainingEvent(ievt);
if (nvar < 0) {
nvar = GetNvar();
}
if (nvar != GetNvar() || nvar < 1) {
fLogger << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;
}
const Double_t weight = GetEventWeight();
// copy the variable values of the current event
kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
for (Int_t ivar = 0; ivar < nvar; ++ivar) {
vvec[ivar] = GetEventVal(ivar);
}
// event type convention: 1 = signal, 2 = background
if (IsSignalEvent()) {
fSumOfWeightsS += weight;
fEvent.push_back(kNN::Event(vvec, weight, 1));
}
else {
fSumOfWeightsB += weight;
fEvent.push_back(kNN::Event(vvec, weight, 2));
}
}
// NOTE(review): these are weight sums, not raw event counts
fLogger << kINFO << "Number of signal events " << fSumOfWeightsS << Endl;
fLogger << kINFO << "Number of background events " << fSumOfWeightsB << Endl;
MakeKNN();
}
////////////////////////////////////////////////////////////////////////////////
/// Compute the MVA output for the current event: the signal weight fraction
/// among the k nearest neighbours found in the kd-tree, with each neighbour
/// optionally weighted by a polynomial kernel of its distance.
/// @return signal weight fraction in [0, 1], or -100.0 on error
Double_t TMVA::MethodKNN::GetMvaValue()
{
const Int_t nvar = GetNvar();
const Double_t evweight = GetEventWeight();
// copy the current event's variable values into a query vector
kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
for (Int_t ivar = 0; ivar < nvar; ++ivar) {
vvec[ivar] = GetEventVal(ivar);
}
// request one extra neighbour: a zero-distance match is skipped in the
// loops below, which still leaves knn usable neighbours
const UInt_t knn = static_cast<UInt_t>(fnkNN);
fModule->Find(kNN::Event(vvec, evweight, 3), knn + 1);
const kNN::List &rlist = fModule->GetkNNList();
if (rlist.size() != knn + 1) {
fLogger << kFATAL << "kNN result list is empty" << Endl;
return -100.0;
}
// for kernel weighting: find the largest distance among the first knn
// valid neighbours, used to normalise distances into [0, 1)
Double_t maxradius = -1.0;
if (fUseKernel) {
UInt_t kcount = 0;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
// skip zero or negative distances
if (!(lit->second > 0.0)) {
continue;
}
++kcount;
if (maxradius < lit->second || maxradius < 0.0) {
maxradius = lit->second;
}
if (kcount == knn) {
break;
}
}
if (!(maxradius > 0.0)) {
fLogger << kFATAL << "kNN radius is not positive" << Endl;
return -100.0;
}
// inverse radius, applied to TMath::Sqrt(distance) below — so lit->second
// is presumably a squared distance; verify against ModulekNN
maxradius = 1.0/TMath::Sqrt(maxradius);
}
// accumulate (kernel-weighted) signal and background weights over the
// first knn valid neighbours
UInt_t all_count = 0;
Double_t all_weight = 0, sig_weight = 0, bac_weight = 0;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
const kNN::Node<kNN::Event> &node = *(lit->first);
if (!(lit->second > 0.0)) {
continue;
}
Double_t weight = node.GetWeight();
if (fUseKernel) {
// normalised distance in [0, 1) maps to a kernel weight in (0, 1]
weight *= PolKernel(TMath::Sqrt(lit->second)*maxradius);
}
++all_count;
all_weight += weight;
// type convention from Train(): 1 = signal, 2 = background
if (node.GetEvent().GetType() == 1) {
sig_weight += weight;
}
else if (node.GetEvent().GetType() == 2) {
bac_weight += weight;
}
else {
fLogger << kFATAL << "Unknown type for training event" << Endl;
}
if (all_count == knn) {
break;
}
}
if (all_count < 1 || all_count != knn) {
fLogger << kFATAL << "kNN result list is empty or has wrong size" << Endl;
return -100.0;
}
if (!(all_weight > 0.0)) {
fLogger << kFATAL << "kNN result total weight is not positive" << Endl;
return -100.0;
}
// NOTE(review): bac_weight is accumulated but never read after this point
return sig_weight/all_weight;
}
////////////////////////////////////////////////////////////////////////////////
/// Variable ranking is not implemented for the kNN method.
/// @return always 0 (no ranking available)
const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
return 0;
}
////////////////////////////////////////////////////////////////////////////////
/// Write the stored training events to a text weight stream, one CSV line
/// per event: event number, type, weight, then the variable values.
/// @param os output stream the weights are written to
void TMVA::MethodKNN::WriteWeightsToStream(ostream& os) const
{
   fLogger << kINFO << "Starting WriteWeightsToStream(ostream& os) function..." << Endl;

   if (fEvent.empty()) {
      fLogger << kWARNING << "MethodKNN contains no events " << Endl;
      return;
   }

   // header lines are marked with '#' and skipped by the reader
   os << "# MethodKNN will write " << fEvent.size() << " events " << endl;
   os << "# event number, type, weight, variable values" << endl;

   const std::string delim = ", ";

   UInt_t index = 0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it, ++index) {
      // bookkeeping fields first, each followed by a delimiter
      os << index << delim << it->GetType() << delim << it->GetWeight() << delim;

      // variable values, delimiter-separated, newline after the last one
      const UInt_t nvar = it->GetNVar();
      for (UInt_t ivar = 0; ivar < nvar; ++ivar) {
         os << it->GetVar(ivar);
         if (ivar + 1 < nvar) os << delim;
         else                 os << endl;
      }
   }
}
////////////////////////////////////////////////////////////////////////////////
/// Read kNN training events from a text weight stream produced by
/// WriteWeightsToStream(ostream&): one CSV line per event holding the event
/// number, type, weight and variable values. Rebuilds the kd-tree afterwards.
/// Fix: corrected typos in the fatal error messages
/// ("delimeter" -> "delimiter", "Wrog" -> "Wrong").
/// @param is input stream the weights are read from
void TMVA::MethodKNN::ReadWeightsFromStream(istream& is)
{
   fLogger << kINFO << "Starting ReadWeightsFromStream(istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      fLogger << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      // skip empty lines and '#' comment (header) lines
      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      // count comma delimiters: two separate the bookkeeping fields
      // (number, type, weight), the rest separate the variable values
      UInt_t count = 0;
      std::string::size_type pos = 0;
      while ((pos = line.find(',', pos)) != std::string::npos) { count++; pos++; }

      if (nvar == 0) {
         // the first event line fixes the variable count for the whole file
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         fLogger << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      Int_t ievent = -1, type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      // split the line at each comma (and at the line end) and convert the
      // fields in order: event number, type, weight, then the variables
      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) continue;

         if (!(ipos > prev)) fLogger << kFATAL << "Wrong substring limits" << Endl;

         std::string vstring = line.substr(prev, ipos - prev);
         // the last field ends at the line end, so keep the final character
         if (ipos + 1 == line.size()) vstring = line.substr(prev, ipos - prev + 1);

         if (vstring.empty()) fLogger << kFATAL << "Failed to parse string" << Endl;

         // ievent is parsed for format validation but not otherwise used
         if (vcount == 0) ievent = atoi(vstring.c_str());
         else if (vcount == 1) type = atoi(vstring.c_str());
         else if (vcount == 2) weight = atof(vstring.c_str());
         else if (vcount - 3 < vvec.size()) vvec[vcount - 3] = atof(vstring.c_str());
         else fLogger << kFATAL << "Wrong variable count" << Endl;

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   fLogger << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // rebuild the kd-tree from the events just read
   MakeKNN();
}
////////////////////////////////////////////////////////////////////////////////
/// Write the stored training events into a ROOT file as a TTree named "knn"
/// with a single branch holding kNN::Event objects.
/// @param rf ROOT file the tree is written to
void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
fLogger << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;
if (fEvent.empty()) {
fLogger << kWARNING << "MethodKNN contains no events " << Endl;
return;
}
// buffer object the branch fills from; deleted at the end
kNN::Event *event = new kNN::Event();
TTree *tree = new TTree("knn", "event tree");
// detach the tree from the current directory so this function controls
// its lifetime (deleted explicitly below)
tree->SetDirectory(0);
tree->Branch("event", "TMVA::kNN::Event", &event);
// Fill() returns the number of bytes committed; accumulate for reporting
Double_t size = 0.0;
for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
(*event) = (*it);
size += tree->Fill();
}
rf.WriteTObject(tree, "knn", "Overwrite");
// convert bytes to MB
size /= 1048576.0;
fLogger << kINFO << "Wrote " << size << "MB and " << fEvent.size()
<< " events to ROOT file" << Endl;
delete tree;
delete event;
}
////////////////////////////////////////////////////////////////////////////////
/// Read training events back from the "knn" TTree in a ROOT file and
/// rebuild the kd-tree.
/// @param rf ROOT file the tree is read from
void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
fLogger << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;
if (!fEvent.empty()) {
fLogger << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
fEvent.clear();
}
// NOTE(review): the tree obtained from the file is not deleted here —
// presumably owned by the file's directory; confirm against ROOT ownership rules
TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
if (!tree) {
fLogger << kFATAL << "Failed to find knn tree" << Endl;
return;
}
// buffer object the branch reads into; deleted at the end
kNN::Event *event = new kNN::Event();
tree->SetBranchAddress("event", &event);
const Int_t nevent = tree->GetEntries();
// GetEntry() returns the number of bytes read; accumulate for reporting
Double_t size = 0.0;
for (Int_t i = 0; i < nevent; ++i) {
size += tree->GetEntry(i);
fEvent.push_back(*event);
}
// convert bytes to MB
size /= 1048576.0;
fLogger << kINFO << "Read " << size << "MB and " << fEvent.size()
<< " events from ROOT file" << Endl;
delete event;
MakeKNN();
}
////////////////////////////////////////////////////////////////////////////////
/// Standalone C++ response-class generation is not supported for kNN:
/// emit a placeholder comment and close the class body.
/// @param fout      stream the class code is written to
/// @param className name of the class being generated
void TMVA::MethodKNN::MakeClassSpecific(std::ostream& fout, const TString& className) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << endl
        << "};" << endl;
}
////////////////////////////////////////////////////////////////////////////////
/// Print the help message for this method to the logger. The individual
/// sections are currently placeholders ("Sorry - not available").
void TMVA::MethodKNN::GetHelpMessage() const
{
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "Sorry - not available" << Endl;
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "Sorry - not available" << Endl;
fLogger << Endl;
fLogger << gTools().Color("bold") << "--- Performance tuning via configuration options:"
<< gTools().Color("reset") << Endl;
fLogger << Endl;
fLogger << "Sorry - not available" << Endl;
}
// Last change: Sat Nov 1 10:21:48 2008
// Last generated: 2008-11-01 10:21
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.