Commit e2e080e

[RF] Disable redundant dirty-flag propagation during minimization
When a likelihood is evaluated with the new `"cpu"` backend, the `RooFit::Evaluator` fully manages dependency tracking and re-evaluation of the computation graph. In this case, RooFit’s built-in dirty flag propagation in RooAbsArg becomes redundant and introduces significant overhead for large models.

This patch disables regular dirty state propagation for all non-fundamental nodes in the Evaluator's computation graph by setting their OperMode to `RooAbsArg::ADirty`. Fundamental nodes (e.g. RooRealVar, RooCategory) are excluded because they are often shared with other computation graphs outside the Evaluator (usually the original pdf in the RooWorkspace).

To set the OperMode of *all* RooAbsArgs to `ADirty` during minimization, while avoiding side effects outside the minimization scope, the dirty flag propagation for the fundamental nodes is only disabled temporarily in the RooMinimizer.

This commit drastically speeds up fits with AD in particular (up to 2x for large models), because with fast gradients, the dirty flag propagation that determines which part of the compute graph needs to be recomputed becomes the bottleneck. It was also redundant with a faster "dirty state" bookkeeping mechanism in the `RooFit::Evaluator` class itself.

At this point, there is no performance regression anymore when disabling recursive dirty flag propagation for all evaluated nodes, so the old comment in the code about test 14 in stressRooFit being slow doesn't apply anymore.

(cherry picked from commit fa97774)
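To make the overhead argument concrete, here is a minimal toy model of recursive dirty-flag propagation. It is a sketch, not RooFit code: `ToyNode`, `OperMode`, `propagationCost`, and `demoPropagationCost` are hypothetical stand-ins, and the only behavior assumed is that a node outside `Auto` mode does not forward dirty flags to its value clients.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the operation modes of a graph node.
enum class OperMode { Auto, AClean, ADirty };

// Hypothetical stand-in for a node in a computation graph.
struct ToyNode {
   OperMode mode = OperMode::Auto;
   bool valueDirty = false;
   std::vector<ToyNode *> valueClients; // nodes whose value depends on this one

   // Mark this node dirty and recurse into its value clients. Nodes that are
   // not in Auto mode do not take part in the propagation (assumption of this
   // toy model). Returns the number of nodes that were visited.
   std::size_t setValueDirty()
   {
      if (mode != OperMode::Auto)
         return 0;
      valueDirty = true;
      std::size_t visited = 1;
      for (ToyNode *client : valueClients)
         visited += client->setValueDirty();
      return visited;
   }
};

// Build a chain param -> n1 -> ... -> n_depth, put every node in the given
// mode, and count how many nodes a single parameter update touches.
std::size_t propagationCost(std::size_t depth, OperMode modeForAll)
{
   std::vector<ToyNode> nodes(depth + 1);
   for (std::size_t i = 0; i < depth; ++i)
      nodes[i].valueClients.push_back(&nodes[i + 1]);
   for (ToyNode &node : nodes)
      node.mode = modeForAll;
   return nodes[0].setValueDirty();
}

// Every parameter update walks the whole chain in Auto mode, but touches
// nothing once all nodes are ADirty.
void demoPropagationCost()
{
   assert(propagationCost(1000, OperMode::Auto) == 1001);
   assert(propagationCost(1000, OperMode::ADirty) == 0);
}
```

In this toy, a minimizer that updates parameters on every step pays the full walk each time in `Auto` mode, which is the kind of cost the commit message describes as the bottleneck once gradients are fast.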
1 parent ed098c2 commit e2e080e

6 files changed

Lines changed: 102 additions & 21 deletions

roofit/roofitcore/inc/RooEvaluatorWrapper.h

Lines changed: 2 additions & 1 deletion
@@ -71,9 +71,10 @@ class RooEvaluatorWrapper final : public RooAbsReal {
    void generateHessian();
 
    void setUseGeneratedFunctionCode(bool);
-
    void writeDebugMacro(std::string const &) const;
 
+   std::stack<std::unique_ptr<ChangeOperModeRAII>> setOperModes(RooAbsArg::OperMode opMode);
+
 protected:
    double evaluate() const override;
 

roofit/roofitcore/inc/RooFit/Evaluator.h

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ class Evaluator {
 
    void setOffsetMode(RooFit::EvalContext::OffsetMode);
 
+   std::stack<std::unique_ptr<ChangeOperModeRAII>> setOperModes(RooAbsArg::OperMode opMode);
+
 private:
    void processVariable(NodeInfo &nodeInfo);
    void processCategory(NodeInfo &nodeInfo);

roofit/roofitcore/inc/RooMinimizer.h

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ class RooMinimizer : public TObject {
    void fillCorrMatrix(RooFitResult &fitRes);
    void updateErrors();
 
+   RooAbsReal &_function;
    ROOT::Fit::FitConfig _config; ///< fitter configuration (options and parameter settings)
    std::unique_ptr<FitResult> _result; ///<! pointer to the object containing the result of the fit
    std::unique_ptr<ROOT::Math::Minimizer> _minimizer; ///<! pointer to used minimizer

roofit/roofitcore/src/RooEvaluatorWrapper.cxx

Lines changed: 5 additions & 0 deletions
@@ -741,6 +741,11 @@ void RooEvaluatorWrapper::writeDebugMacro(std::string const &filename) const
    return _funcWrapper->writeDebugMacro(filename);
 }
 
+std::stack<std::unique_ptr<ChangeOperModeRAII>> RooEvaluatorWrapper::setOperModes(RooAbsArg::OperMode opMode)
+{
+   return _evaluator->setOperModes(opMode);
+}
+
 } // namespace RooFit::Experimental
 
 /// \endcond

roofit/roofitcore/src/RooFit/Evaluator.cxx

Lines changed: 55 additions & 9 deletions
@@ -46,6 +46,7 @@ RooAbsPdf::fitTo() is called and gets destroyed when the fitting ends.
 #include <iomanip>
 #include <numeric>
 #include <thread>
+#include <unordered_set>
 
 namespace RooFit {
 
@@ -325,16 +326,16 @@ void Evaluator::updateOutputSizes()
    for (auto &info : _nodes) {
       info.outputSize = outputSizeMap.at(info.absArg);
       info.isDirty = true;
-
-      // In principle we don't need dirty flag propagation because the driver
-      // takes care of deciding which node needs to be re-evaluated. However,
-      // disabling it also for scalar mode results in very long fitting times
-      // for specific models (test 14 in stressRooFit), which still needs to be
-      // understood. TODO.
-      if (!info.isScalar()) {
+      // We don't need dirty flag propagation because the evaluator takes care
+      // of deciding what needs to be re-evaluated. We can disable the regular
+      // dirty state propagation. However, fundamental variables like
+      // RooRealVars and RooCategories are usually shared with other
+      // computation graphs outside the evaluator, so they can't be mutated.
+      // See also the code of the RooMinimizer, which ensures that dirty state
+      // propagation is temporarily disabled during minimization to really
+      // eliminate any overhead from the dirty flag propagation.
+      if (!info.absArg->isFundamental()) {
          setOperMode(info.absArg, RooAbsArg::ADirty);
-      } else {
-         setOperMode(info.absArg, info.originalOperMode);
       }
    }
 
@@ -632,6 +633,51 @@ void Evaluator::setOperMode(RooAbsArg *arg, RooAbsArg::OperMode opMode)
    }
 }
 
+// Change the operation modes of all RooAbsArgs in the computation graph.
+// The changes are reset when the returned RAII object goes out of scope.
+//
+// We also walk transitively through value clients of the nodes to cover any
+// node that RooAbsReal::doEval (the fallback scalar implementation) might
+// inadvertently propagate the ADirty mode to via its recursive restore: that
+// helper sets servers temporarily to AClean and then calls
+// setOperMode(oldOperMode) to restore, which recurses to value clients when
+// oldOperMode is ADirty. If we did not protect those clients here, any node
+// outside the computation graph that shares a fundamental (e.g. a parameter
+// like a RooRealVar) would be left permanently in ADirty after the first
+// minimization, dramatically slowing down later scalar evaluations (for
+// example on pdfs held by the legacy test statistics' internal cache).
+std::stack<std::unique_ptr<ChangeOperModeRAII>> Evaluator::setOperModes(RooAbsArg::OperMode opMode)
+{
+   std::stack<std::unique_ptr<ChangeOperModeRAII>> out{};
+   std::unordered_set<RooAbsArg *> visited;
+
+   std::vector<RooAbsArg *> queue;
+   queue.reserve(_nodes.size());
+   for (auto &info : _nodes) {
+      queue.push_back(info.absArg);
+   }
+
+   while (!queue.empty()) {
+      RooAbsArg *node = queue.back();
+      queue.pop_back();
+      if (!visited.insert(node).second)
+         continue;
+
+      if (opMode != node->operMode()) {
+         out.emplace(std::make_unique<ChangeOperModeRAII>(node, opMode));
+      }
+
+      // Only follow value-client links: that is exactly the propagation path
+      // used by RooAbsArg::setOperMode with mode==ADirty.
+      if (opMode == RooAbsArg::ADirty) {
+         for (auto *client : node->valueClients()) {
+            queue.push_back(client);
+         }
+      }
+   }
+   return out;
+}
+
 void Evaluator::print(std::ostream &os)
 {
    std::cout << "--- RooFit BatchMode evaluation ---\n";
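The traversal in `Evaluator::setOperModes` can be illustrated in isolation. The following is a simplified sketch with hypothetical names (`GraphNode`, `collectWithValueClients`, `demoReachableCount`), not the RooFit types: it starts from the nodes of a computation graph, follows value-client links transitively, and uses a visited set so that nodes reachable along several paths are handled only once.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a graph node; only the client links matter here.
struct GraphNode {
   std::vector<GraphNode *> valueClients;
};

// Return the seed nodes plus every node reachable from them through
// value-client links. Each node appears exactly once in the result.
std::vector<GraphNode *> collectWithValueClients(std::vector<GraphNode *> const &seeds)
{
   std::vector<GraphNode *> result;
   std::unordered_set<GraphNode *> visited;
   std::vector<GraphNode *> pending(seeds.begin(), seeds.end());

   while (!pending.empty()) {
      GraphNode *node = pending.back();
      pending.pop_back();
      assert(node != nullptr);
      // insert() reports via .second whether the node was new.
      if (!visited.insert(node).second)
         continue;
      result.push_back(node);
      for (GraphNode *client : node->valueClients)
         pending.push_back(client);
   }
   return result;
}

// A parameter shared between a pdf inside the evaluated graph and a pdf
// outside of it, both feeding a common downstream node.
std::size_t demoReachableCount()
{
   GraphNode param, inGraphPdf, outsidePdf, downstream;
   param.valueClients = {&inGraphPdf, &outsidePdf};
   inGraphPdf.valueClients = {&downstream};
   outsidePdf.valueClients = {&downstream};
   // Seeds are only the nodes known to the evaluated graph.
   return collectWithValueClients({&param, &inGraphPdf}).size();
}
```

In this toy setup, `outsidePdf` and `downstream` are not seeds but are still collected, which mirrors the reason given in the code comment for walking the value clients: nodes outside the graph that share a parameter need a guard as well.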

roofit/roofitcore/src/RooMinimizer.cxx

Lines changed: 37 additions & 11 deletions
@@ -40,27 +40,30 @@ automatic PDF optimization.
 #include "RooMinimizer.h"
 
 #include "RooAbsMinimizerFcn.h"
-#include "RooArgSet.h"
-#include "RooArgList.h"
 #include "RooAbsReal.h"
+#include "RooArgList.h"
+#include "RooArgSet.h"
+#include "RooCategory.h"
 #include "RooDataSet.h"
-#include "RooRealVar.h"
-#include "RooSentinel.h"
+#include "RooEvaluatorWrapper.h"
+#include "RooFit/TestStatistics/RooAbsL.h"
+#include "RooFit/TestStatistics/RooRealL.h"
+#include "RooFitResult.h"
+#include "RooHelpers.h"
+#include "RooMinimizerFcn.h"
 #include "RooMsgService.h"
-#include "RooCategory.h"
 #include "RooMultiPdf.h"
 #include "RooPlot.h"
-#include "RooHelpers.h"
-#include "RooMinimizerFcn.h"
-#include "RooFitResult.h"
-#include "RooFit/TestStatistics/RooAbsL.h"
-#include "RooFit/TestStatistics/RooRealL.h"
+#include "RooRealVar.h"
+#include "RooSentinel.h"
 #ifdef ROOFIT_MULTIPROCESS
 #include "TestStatistics/MinuitFcnGrad.h"
 #include "RooFit/MultiProcess/Config.h"
 #include "RooFit/MultiProcess/ProcessTimer.h"
 #endif
 
+#include "RooFitImplHelpers.h"
+
 #include <Fit/BasicFCN.h>
 #include <Math/Minimizer.h>
 #include <TClass.h>
@@ -120,6 +123,22 @@ void reorderCombinations(std::vector<std::vector<int>> &combos, const std::vecto
    }
 }
 
+// The RooEvaluatorWrapper uses its own logic to decide what needs to be
+// re-evaluated. We can therefore disable the regular dirty state propagation
+// temporarily during minimization. However, some RooAbsArgs shared with other
+// regular RooFit computation graphs outside the minimized likelihood, so we
+// have to make sure that the operation mode is reset after the minimization.
+//
+// This should be called before running any routine via the _minimizer data
+// member. The RAII object should only be destructed after the routine is done.
+std::stack<std::unique_ptr<ChangeOperModeRAII>> setOperModesDirty(RooAbsReal &function)
+{
+   if (auto *wrapper = dynamic_cast<RooFit::Experimental::RooEvaluatorWrapper *>(&function)) {
+      return wrapper->setOperModes(RooAbsArg::ADirty);
+   }
+   return {};
+}
+
 } // namespace
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -135,7 +154,7 @@ void reorderCombinations(std::vector<std::vector<int>> &combos, const std::vecto
 /// value of the input function.
 
 /// Constructor that accepts all configuration in struct with RooAbsReal likelihood
-RooMinimizer::RooMinimizer(RooAbsReal &function, Config const &cfg) : _cfg(cfg)
+RooMinimizer::RooMinimizer(RooAbsReal &function, Config const &cfg) : _function{function}, _cfg(cfg)
 {
    initMinimizerFirstPart();
    auto nll_real = dynamic_cast<RooFit::TestStatistics::RooRealL *>(&function);
@@ -692,6 +711,7 @@ RooPlot *RooMinimizer::contour(RooRealVar &var1, RooRealVar &var2, double n1, do
    n[4] = n5;
    n[5] = n6;
 
+   auto operModeRAII = setOperModesDirty(_function);
    for (int ic = 0; ic < 6; ic++) {
       if (n[ic] > 0) {
 
@@ -906,6 +926,8 @@ bool RooMinimizer::fitFCN()
    // fit a user provided FCN function
    // create fit parameter settings
 
+   auto operModeRAII = setOperModesDirty(_function);
+
    // Check number of parameters
    unsigned int npar = getNPar();
    if (npar == 0) {
@@ -1045,6 +1067,8 @@ bool RooMinimizer::calculateHessErrors()
    // compute the Hesse errors according to configuration
    // set in the parameters and append value in fit result
 
+   auto operModeRAII = setOperModesDirty(_function);
+
    // update minimizer (recreate if not done or if name has changed
    if (!updateMinimizerOptions()) {
       coutE(Minimization) << "RooMinimizer::calculateHessErrors() Error re-initializing the minimizer" << std::endl;
@@ -1079,6 +1103,8 @@ bool RooMinimizer::calculateMinosErrors()
    // (in DoMinimization) aftewr minimizing if the
    // FitConfig::MinosErrors() flag is set
 
+   auto operModeRAII = setOperModesDirty(_function);
+
    // update minimizer (but cannot re-create in this case). Must use an existing one
    if (!updateMinimizerOptions(false)) {
       coutE(Minimization) << "RooMinimizer::calculateHessErrors() Error re-initializing the minimizer" << std::endl;
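Each minimizer routine above holds the returned guards in a local variable, so the original operation modes come back on every exit path, including the early error returns. Below is a minimal sketch of that scoped-guard pattern. All names (`Arg`, `Mode`, `ScopedMode`, `setModes`, `runRoutine`) are hypothetical stand-ins and only assume that a guard records the old mode on construction and restores it on destruction; this is not the `ChangeOperModeRAII` implementation.

```cpp
#include <cassert>
#include <memory>
#include <stack>
#include <vector>

// Hypothetical stand-ins for an operation mode and an object that carries one.
enum class Mode { Auto, AClean, ADirty };

struct Arg {
   Mode mode = Mode::Auto;
};

// Guard that switches the mode of one object and restores the previous mode
// when the guard is destroyed.
class ScopedMode {
public:
   ScopedMode(Arg &arg, Mode mode) : _arg{arg}, _oldMode{arg.mode} { _arg.mode = mode; }
   ~ScopedMode() { _arg.mode = _oldMode; }
   ScopedMode(ScopedMode const &) = delete;
   ScopedMode &operator=(ScopedMode const &) = delete;

private:
   Arg &_arg;
   Mode _oldMode;
};

using ModeGuards = std::stack<std::unique_ptr<ScopedMode>>;

// Switch every object that is not already in the requested mode and hand the
// guards to the caller, who decides how long the change stays in effect.
ModeGuards setModes(std::vector<Arg *> const &args, Mode mode)
{
   ModeGuards guards;
   for (Arg *arg : args) {
      if (arg->mode != mode) {
         guards.emplace(std::make_unique<ScopedMode>(*arg, mode));
      }
   }
   return guards;
}

// Hypothetical routine with an early error return. The guards are destroyed
// on both exit paths, so the caller never sees the temporary mode.
bool runRoutine(std::vector<Arg *> const &args, bool failEarly)
{
   auto guards = setModes(args, Mode::ADirty);
   for (Arg *arg : args) {
      assert(arg->mode == Mode::ADirty);
   }
   if (failEarly) {
      return false;
   }
   return true;
}
```

Returning the guards by value keeps the lifetime decision with the caller. In this sketch, an object that started in `AClean` goes back to `AClean` rather than to a fixed default, because each guard stores the mode it found.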
