Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions roofit/histfactory/test/testHistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "../../roofitcore/test/gtest_wrapper.h"

#include <regex>
#include <set>

namespace {
Expand Down Expand Up @@ -709,3 +710,87 @@ INSTANTIATE_TEST_SUITE_P(HistFactory, HFFixtureFit,
testing::Values(false, true), // non-uniform bins or not
testing::Values(ROOFIT_EVAL_BACKENDS_WITH_CODEGEN)),
getNameFromInfo);

// Regression test for the HS3 importer's handling of "shapefactor" modifiers.
// HistFactory's ShapeFactor is exported with type "shapesys" (the only
// modifier type the writer ever emits for a ParamHistFunc), but valid HS3 JSON
// may also use the dedicated "shapefactor" type for an unconstrained
// ParamHistFunc - and the importer must accept it. This test starts from a
// HistFactory model built with MakeModelAndMeasurementFast, rewrites the
// modifier type to "shapefactor" in the JSON, and checks that the importer
// recognises it. Without the fix in JSONFactories_HistFactory.cxx the import
// throws "modifier ... of unknown type 'shapefactor'".
TEST(HistFactory, HS3ImportShapeFactorModifier)
{
using namespace RooStats::HistFactory;
RooHelpers::LocalChangeMsgLevel changeMsgLvl(RooFit::WARNING);

const std::string inputFile = "TestHS3ShapeFactor.root";
{
TFile f(inputFile.c_str(), "RECREATE");
auto *data = new TH1D("data", "data", 2, 1, 2);
auto *signal = new TH1D("signal", "signal", 2, 1, 2);
auto *bkg = new TH1D("background", "background", 2, 1, 2);
data->SetBinContent(1, 220);
data->SetBinContent(2, 230);
signal->SetBinContent(1, 10);
signal->SetBinContent(2, 20);
bkg->SetBinContent(1, 200);
bkg->SetBinContent(2, 200);
for (auto *h : {data, signal, bkg})
f.WriteTObject(h);
}

Measurement meas("meas", "meas");
meas.SetOutputFilePrefix("HS3ShapeFactor");
meas.SetPOI("SigXsecOverSM");
meas.AddConstantParam("Lumi");
meas.SetLumi(1.0);
meas.SetLumiRelErr(0.10);

Channel chan("channel1");
chan.SetData("data", inputFile);

Sample sig("signal", "signal", inputFile);
sig.AddNormFactor("SigXsecOverSM", 1, 0, 3);
chan.AddSample(sig);

// ShapeFactor on the background: an unconstrained, bin-by-bin scaling.
// Make the gammas constant so that the workspace is well-defined for
// re-export (free shapefactor gammas have no constraints attached).
Sample bkg("background", "background", inputFile);
ShapeFactor sf;
sf.SetName("bkgShape");
sf.SetConstant(true);
bkg.AddShapeFactor(sf);
chan.AddSample(bkg);

meas.AddChannel(chan);
meas.CollectHistograms();

std::unique_ptr<RooWorkspace> ws{MakeModelAndMeasurementFast(meas)};
ASSERT_NE(ws, nullptr);

const std::string js = RooJSONFactoryWSTool{*ws}.exportJSONtoString();

// Rewrite the modifier type for "bkgShape" from "shapesys" to "shapefactor".
// The HistFactory exporter always writes "shapesys", but the importer should
// accept the more accurate "shapefactor" type as well.
const std::regex pattern{"\"name\":\"bkgShape\",\"parameters\":\\[([^\\]]*)\\],\"type\":\"shapesys\""};
const std::string jsShapeFactor =
std::regex_replace(js, pattern, "\"name\":\"bkgShape\",\"parameters\":[$1],\"type\":\"shapefactor\"");
ASSERT_NE(js, jsShapeFactor) << "Failed to substitute shapesys -> shapefactor in JSON";

RooWorkspace wsFromJson{"new"};
ASSERT_NO_THROW(RooJSONFactoryWSTool{wsFromJson}.importJSONfromString(jsShapeFactor))
<< "Importer rejected the 'shapefactor' modifier type";

// The imported workspace should expose the same ParamHistFunc gammas.
EXPECT_NE(wsFromJson.var("gamma_bkgShape_bin_0"), nullptr);
EXPECT_NE(wsFromJson.var("gamma_bkgShape_bin_1"), nullptr);

// Re-exporting should give back the original JSON, since the writer emits
// type "shapesys" in both cases.
const std::string js2 = RooJSONFactoryWSTool{wsFromJson}.exportJSONtoString();
EXPECT_EQ(js, js2) << "JSON -> WS -> JSON roundtrip changed the JSON";
}
44 changes: 27 additions & 17 deletions roofit/hs3/src/JSONFactories_HistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,31 @@ ParamHistFunc &createPHF(const std::string &phfname, std::string const &sysname,
{
RooWorkspace &ws = *tool.workspace();

size_t n = std::max(vals.size(), parnames.size());
RooArgList gammas;
for (std::size_t i = 0; i < vals.size(); ++i) {
for (std::size_t i = 0; i < n; ++i) {
const std::string name = parnames.empty() ? defaultGammaName(sysname, i) : parnames[i];
gammas.add(getOrCreate<RooRealVar>(ws, name, 1., gammaMin, gammaMax));
auto *e = dynamic_cast<RooAbsReal *>(ws.obj(name.c_str()));
if (e)
gammas.add(*e);
else
gammas.add(getOrCreate<RooRealVar>(ws, name, 1., gammaMin, gammaMax));
}

auto &phf = tool.wsEmplace<ParamHistFunc>(phfname, observables, gammas);

if (constraintType != "Const") {
auto constraintsInfo = createGammaConstraints(
gammas, vals, minSigma, constraintType == "Poisson" ? Constraint::Poisson : Constraint::Gaussian);
for (auto const &term : constraintsInfo.constraints) {
ws.import(*term, RooFit::RecycleConflictNodes());
constraints.add(*ws.pdf(term->GetName()));
}
} else {
for (auto *gamma : static_range_cast<RooRealVar *>(gammas)) {
gamma->setConstant(true);
if (vals.size() > 0) {
if (constraintType != "Const") {
auto constraintsInfo = createGammaConstraints(
gammas, vals, minSigma, constraintType == "Poisson" ? Constraint::Poisson : Constraint::Gaussian);
for (auto const &term : constraintsInfo.constraints) {
ws.import(*term, RooFit::RecycleConflictNodes());
constraints.add(*ws.pdf(term->GetName()));
}
} else {
for (auto *gamma : static_range_cast<RooRealVar *>(gammas)) {
gamma->setConstant(true);
}
}
}

Expand Down Expand Up @@ -374,19 +381,22 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con
sysname + "High_" + prefixedName, varlist,
RooJSONFactoryWSTool::readBinnedData(data["hi"], sysname + "High_" + prefixedName, varlist)));
constraints.add(getOrCreateConstraint(tool, mod, par, sampleName));
} else if (modtype == "shapesys") {
} else if (modtype == "shapesys" || modtype == "shapefactor") {
std::string funcName = channelName + "_" + sysname + "_ShapeSys";
// funcName should be "<channel_name>_<sysname>_ShapeSys"
std::vector<double> vals;
for (const auto &v : mod["data"]["vals"].children()) {
vals.push_back(v.val_double());
if (mod["data"].has_child("vals")) {
for (const auto &v : mod["data"]["vals"].children()) {
vals.push_back(v.val_double());
}
}
std::vector<std::string> parnames;
for (const auto &v : mod["parameters"].children()) {
parnames.push_back(v.val());
}
if (vals.empty()) {
RooJSONFactoryWSTool::error("unable to instantiate shapesys '" + sysname + "' with 0 values!");
if (vals.empty() && parnames.empty()) {
RooJSONFactoryWSTool::error("unable to instantiate shapesys '" + sysname +
"' with neither values nor parameters!");
}
std::string constraint(mod.has_child("constraint_type") ? mod["constraint_type"].val()
: mod.has_child("constraint") ? mod["constraint"].val()
Expand Down
20 changes: 10 additions & 10 deletions roofit/hs3/src/RooJSONFactoryWSTool.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ std::unique_ptr<RooAbsData> loadData(const JSONNode &p, RooWorkspace &workspace)
*
* This function imports an analysis, represented by the provided JSONNodes 'analysisNode' and 'likelihoodsNode',
* into the workspace represented by the provided RooWorkspace. The analysis information is read from the JSONNodes
* and added to the workspace as one or more RooStats::ModelConfig objects.
* and added to the workspace as one or more RooFit::ModelConfig objects.
*
* @param rootnode The root JSONNode representing the entire JSON file.
* @param analysisNode The JSONNode representing the analysis to be imported.
Expand All @@ -567,8 +567,8 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons
if (workspace.obj(mcname))
return;

workspace.import(RooStats::ModelConfig{mcname.c_str(), mcname.c_str()});
auto *mc = static_cast<RooStats::ModelConfig *>(workspace.obj(mcname));
workspace.import(RooFit::ModelConfig{mcname.c_str(), mcname.c_str()});
auto *mc = static_cast<RooFit::ModelConfig *>(workspace.obj(mcname));
mc->SetWS(workspace);

auto *nllNode = RooJSONFactoryWSTool::findNamedChild(likelihoodsNode, analysisNode["likelihood"].val());
Expand Down Expand Up @@ -1738,7 +1738,7 @@ void RooJSONFactoryWSTool::importDependants(const JSONNode &n)
}
}

void RooJSONFactoryWSTool::exportModelConfig(JSONNode &rootnode, RooStats::ModelConfig const &mc,
void RooJSONFactoryWSTool::exportModelConfig(JSONNode &rootnode, RooFit::ModelConfig const &mc,
const std::vector<CombinedData> &combDataSets,
const std::vector<RooAbsData *> &singleDataSets)
{
Expand Down Expand Up @@ -1773,7 +1773,7 @@ void RooJSONFactoryWSTool::exportModelConfig(JSONNode &rootnode, RooStats::Model
}
}

void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooStats::ModelConfig const &mc,
void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooFit::ModelConfig const &mc,
std::string const &analysisName,
std::map<std::string, std::string> const *dataComponents)
{
Expand Down Expand Up @@ -1936,7 +1936,7 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n)

// export all ModelConfig objects and attached Pdfs
for (TObject *obj : _workspace.allGenericObjects()) {
if (auto mc = dynamic_cast<RooStats::ModelConfig *>(obj)) {
if (auto mc = dynamic_cast<RooFit::ModelConfig *>(obj)) {
exportModelConfig(n, *mc, combData, singleData);
}
}
Expand Down Expand Up @@ -2446,7 +2446,7 @@ RooWorkspace RooJSONFactoryWSTool::cleanWS(const RooWorkspace &ws, bool onlyMode
RooWorkspace tmpWS = RooWorkspace();
if (onlyModelConfig) {
for (auto *obj : ws.allGenericObjects()) {
if (auto *mc = dynamic_cast<RooStats::ModelConfig *>(obj)) {
if (auto *mc = dynamic_cast<RooFit::ModelConfig *>(obj)) {
tmpWS.import(*mc->GetPdf(), RooFit::RecycleConflictNodes(true));
}
}
Expand Down Expand Up @@ -2479,7 +2479,7 @@ RooWorkspace RooJSONFactoryWSTool::cleanWS(const RooWorkspace &ws, bool onlyMode
}

/*
if (auto* mc = dynamic_cast<RooStats::ModelConfig*>(obj)) {
if (auto* mc = dynamic_cast<RooFit::ModelConfig*>(obj)) {
// Import the PDF
tmpWS.import(*mc->GetPdf());

Expand All @@ -2500,7 +2500,7 @@ RooWorkspace RooJSONFactoryWSTool::cleanWS(const RooWorkspace &ws, bool onlyMode
tmpWS.import(*nuis);


RooStats::ModelConfig* mc_new = new RooStats::ModelConfig(mc->GetName(), mc->GetName());
RooFit::ModelConfig* mc_new = new RooFit::ModelConfig(mc->GetName(), mc->GetName());

mc_new->SetPdf(*tmpWS.pdf(mc->GetPdf()->GetName()));
mc_new->SetObservables(*tmpWS.set(obs->GetName()));
Expand Down Expand Up @@ -2627,7 +2627,7 @@ RooWorkspace RooJSONFactoryWSTool::sanitizeWS(const RooWorkspace &ws)
}
}

if (auto *mc = dynamic_cast<RooStats::ModelConfig *>(obj)) {
if (auto *mc = dynamic_cast<RooFit::ModelConfig *>(obj)) {
// Sanitize ModelConfig name
if (!isValidName(mc->GetName())) {
mc->SetName(sanitizeName(mc->GetName()).c_str());
Expand Down
2 changes: 1 addition & 1 deletion roofit/hs3/test/testHS3SimultaneousFit.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ std::unique_ptr<RooFitResult> writeJSONAndFitModel(std::string &jsonStr)
// Simultaneous PDF and model config
ws.factory("SIMUL::simPdf(channelCat[channel_1=0, channel_2=1], channel_1=model_1, channel_2=model_2)");

RooStats::ModelConfig modelConfig{"ModelConfig"};
RooFit::ModelConfig modelConfig{"ModelConfig"};

modelConfig.SetWS(ws);
modelConfig.SetPdf("simPdf");
Expand Down
4 changes: 2 additions & 2 deletions roofit/hs3/test/testRooFitHS3.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include <RooSimultaneous.h>
#include <RooWorkspace.h>
#include <RooFormulaVar.h>
#include <RooStats/ModelConfig.h>
#include <RooFit/ModelConfig.h>

#include <TROOT.h>

Expand Down Expand Up @@ -524,7 +524,7 @@ TEST(RooFitHS3, ModelConfigWithMultiVarGaussian)
ws1.import(mv, RooFit::Silence(), RooFit::RecycleConflictNodes());

// Build a ModelConfig referencing the pdf and its observables
RooStats::ModelConfig mc{"mc", &ws1};
RooFit::ModelConfig mc{"mc", &ws1};
mc.SetPdf(*ws1.pdf("mvgauss"));
mc.SetObservables("x,y");
ws1.import(mc);
Expand Down
Loading