diff --git a/source/MaterialXCore/Document.cpp b/source/MaterialXCore/Document.cpp index ec8bb54e84..97d3e34b3e 100644 --- a/source/MaterialXCore/Document.cpp +++ b/source/MaterialXCore/Document.cpp @@ -69,6 +69,13 @@ class Document::Cache return (it != _implementationMap.end()) ? it->second : vector(); } + ImplementationPtr getImplementationForNodeGraph(const string& nodeGraphName) + { + auto lock = refreshWithLock(); + auto it = _nodeGraphImplMap.find(nodeGraphName); + return (it != _nodeGraphImplMap.end()) ? it->second : ImplementationPtr(); + } + private: std::shared_lock refreshWithLock() { @@ -103,6 +110,7 @@ class Document::Cache _portElementMap.clear(); _nodeDefMap.clear(); _implementationMap.clear(); + _nodeGraphImplMap.clear(); // Traverse the document to build a new cache. for (ElementPtr elem : doc->traverseTree()) @@ -121,6 +129,14 @@ class Document::Cache _portElementMap[portElem->getQualifiedName(portKey)].push_back(portElem); } } + if (!nodeGraphName.empty()) + { + ImplementationPtr impl = elem->asA(); + if (impl) + { + _nodeGraphImplMap[impl->getQualifiedName(nodeGraphName)] = impl; + } + } if (!nodeString.empty()) { NodeDefPtr nodeDef = elem->asA(); @@ -152,6 +168,7 @@ class Document::Cache std::unordered_map> _portElementMap; std::unordered_map> _nodeDefMap; std::unordered_map> _implementationMap; + std::unordered_map _nodeGraphImplMap; }; // @@ -378,9 +395,7 @@ vector Document::getMaterialOutputs() const vector Document::getMatchingNodeDefs(const string& nodeName) const { // Recurse to data library if present. - vector matchingNodeDefs = hasDataLibrary() ? - getDataLibrary()->getMatchingNodeDefs(nodeName) : - vector(); + vector matchingNodeDefs = hasDataLibrary() ? getDataLibrary()->getMatchingNodeDefs(nodeName) : vector(); // Append all nodedefs matching the given node name. vector localNodeDefs = _cache->getMatchingNodeDefs(nodeName); @@ -392,9 +407,7 @@ vector Document::getMatchingNodeDefs(const string& nodeName) const vector Document::getMatchingImplementations(const string& nodeDef) const { // Recurse to data library if present. - vector matchingImplementations = hasDataLibrary() ? - getDataLibrary()->getMatchingImplementations(nodeDef) : - vector(); + vector matchingImplementations = hasDataLibrary() ? getDataLibrary()->getMatchingImplementations(nodeDef) : vector(); // Append all implementations matching the given nodedef string. vector localImpls = _cache->getMatchingImplementations(nodeDef); @@ -403,6 +416,19 @@ vector Document::getMatchingImplementations(const string& n return matchingImplementations; } +ImplementationPtr Document::getImplementationForNodeGraph(const string& nodeGraphName) const +{ + if (hasDataLibrary()) + { + ImplementationPtr impl = getDataLibrary()->getImplementationForNodeGraph(nodeGraphName); + if (impl) + { + return impl; + } + } + return _cache->getImplementationForNodeGraph(nodeGraphName); +} + bool Document::validate(string* message) const { bool res = true; diff --git a/source/MaterialXCore/Document.h b/source/MaterialXCore/Document.h index ab01bbbc11..84b0f5c4c4 100644 --- a/source/MaterialXCore/Document.h +++ b/source/MaterialXCore/Document.h @@ -551,6 +551,10 @@ class MX_CORE_API Document : public GraphElement /// Implementation element or NodeGraph element. vector getMatchingImplementations(const string& nodeDef) const; + /// Return the Implementation, if any, whose nodegraph attribute matches + /// the given fully-qualified nodegraph name. + ImplementationPtr getImplementationForNodeGraph(const string& nodeGraphName) const; + /// @} /// @name UnitDef Elements /// @{ diff --git a/source/MaterialXCore/Node.cpp b/source/MaterialXCore/Node.cpp index f5923f337c..2ac50f951d 100644 --- a/source/MaterialXCore/Node.cpp +++ b/source/MaterialXCore/Node.cpp @@ -192,7 +192,7 @@ vector Node::getDownstreamPorts() const } } std::sort(downstreamPorts.begin(), downstreamPorts.end(), [](const ConstElementPtr& a, const ConstElementPtr& b) - { + { return a->getName() > b->getName(); }); return downstreamPorts; @@ -745,12 +745,10 @@ NodeDefPtr NodeGraph::getNodeDef() const // If not directly defined look for an implementation which has a nodedef association if (!nodedef) { - for (auto impl : getDocument()->getImplementations()) + ImplementationPtr impl = getDocument()->getImplementationForNodeGraph(getQualifiedName(getName())); + if (impl) { - if (impl->getNodeGraph() == getQualifiedName(getName())) - { - nodedef = impl->getNodeDef(); - } + nodedef = impl->getNodeDef(); } } return nodedef; @@ -775,7 +773,7 @@ vector NodeGraph::getDownstreamPorts() const } } std::sort(downstreamPorts.begin(), downstreamPorts.end(), [](const ConstElementPtr& a, const ConstElementPtr& b) - { + { return a->getName() > b->getName(); }); return downstreamPorts; diff --git a/source/MaterialXGenShader/Syntax.cpp b/source/MaterialXGenShader/Syntax.cpp index f33a24fd30..9700586dd4 100644 --- a/source/MaterialXGenShader/Syntax.cpp +++ b/source/MaterialXGenShader/Syntax.cpp @@ -153,7 +153,7 @@ void Syntax::makeValidName(string& name) const { std::replace_if(name.begin(), name.end(), isInvalidChar, '_'); name = replaceSubstrings(name, _invalidTokens); - if (std::find(_reservedWords.begin(), _reservedWords.end(), name) != _reservedWords.end()) + if (_reservedWords.find(name) != _reservedWords.end()) { // We append "1" here because thats the prior behavior from makeIdentifier() below when // the reservedWords were added to the identifiers list. diff --git a/source/MaterialXTest/MaterialXCore/Performance.cpp b/source/MaterialXTest/MaterialXCore/Performance.cpp new file mode 100644 index 0000000000..8886e40943 --- /dev/null +++ b/source/MaterialXTest/MaterialXCore/Performance.cpp @@ -0,0 +1,73 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#ifdef CATCH_CONFIG_ENABLE_BENCHMARKING + + #include + #include + #include + +namespace mx = MaterialX; + +// Build a document that scales the two axes feeding NodeGraph::getNodeDef +// from PortElement::validate: +// * numImplStubs - document-root elements; sizes the +// Document::Cache nodegraph -> implementation map. +// * numNodegraphRefs - inputs that resolve through a nodegraph; each one +// performs a cache lookup. +static mx::DocumentPtr buildScalingDocument(size_t numImplStubs, size_t numNodegraphRefs) +{ + mx::DocumentPtr doc = mx::createDocument(); + + // A nodegraph with one internal node and one output, so that + // PortElement::getConnectedNode() resolves transitively through the + // output (see Input::getConnectedNode in Interface.cpp). + mx::NodeGraphPtr ng = doc->addNodeGraph("ng"); + mx::NodePtr inner = ng->addNode("constant", "inner", "color3"); + mx::OutputPtr out = ng->addOutput("out", "color3"); + out->setNodeName(inner->getName()); + + // A host node parenting many inputs that reference the nodegraph. + mx::NodePtr host = doc->addNode("surface", "host", "surfaceshader"); + for (size_t i = 0; i < numNodegraphRefs; ++i) + { + mx::InputPtr in = host->addInput("in" + std::to_string(i), "color3"); + in->setNodeGraphString("ng"); + in->setOutputString("out"); + } + + for (size_t i = 0; i < numImplStubs; ++i) + { + mx::ImplementationPtr impl = doc->addImplementation("i" + std::to_string(i)); + impl->setNodeDefString("n" + std::to_string(i)); + } + + return doc; +} + +TEST_CASE("NodeGraph getNodeDef performance", "[node][performance]") +{ + BENCHMARK("validate, stubs=200 refs=100") + { + mx::DocumentPtr doc = buildScalingDocument(200, 100); + return doc->validate(); + }; + + BENCHMARK("validate, stubs=1000 refs=100") + { + mx::DocumentPtr doc = buildScalingDocument(1000, 100); + return doc->validate(); + }; + + BENCHMARK("validate, stubs=10000 refs=100") + { + mx::DocumentPtr doc = buildScalingDocument(10000, 100); + return doc->validate(); + }; +} + +#endif // CATCH_CONFIG_ENABLE_BENCHMARKING