diff --git a/brian2/equations/equations.py b/brian2/equations/equations.py index 9b987e587..555e771a1 100644 --- a/brian2/equations/equations.py +++ b/brian2/equations/equations.py @@ -737,6 +737,276 @@ def _substitute(self, replacements): def substitute(self, **kwds): return Equations(list(self._substitute(kwds).values())) + def prefix(self, prefix_str): + """ + Return a copy of the equations with all variable names prefixed. + + Parameters + ---------- + prefix_str : str + String to prepend to variable names + + Returns + ------- + Equations + New Equations object with prefixed variables + + Raises + ------ + ValueError + If prefix_str is not a valid Python identifier + + Notes + ----- + This method prefixes all user-defined variables (differential equations, + parameters, subexpressions) while protecting built-in variables (t, dt, xi, + i, N, etc.) and external namespace references. + + The prefixing is recursive, so `eqs.prefix('a_').prefix('b_')` would + result in variables prefixed with `b_a_` (nested). + + Examples + -------- + >>> eqs = Equations(''' + ... dv/dt = -v/tau : volt + ... I = g*E : amp + ... ''') + >>> eqs_exc = eqs.prefix('exc_') + >>> 'exc_v' in eqs_exc + True + """ + import keyword + + # Validate prefix + if not isinstance(prefix_str, str): + raise ValueError("Prefix must be a string") + if prefix_str == "": + # Return a copy, not self + return Equations(list(self._equations.values())) + if not prefix_str.isidentifier(): + raise ValueError( + f"Invalid prefix '{prefix_str}': must be a valid Python identifier" + ) + if keyword.iskeyword(prefix_str): + raise ValueError(f"Prefix cannot be a Python keyword: '{prefix_str}'") + + # Built-in variables to protect (never rename these) + protected = { + "t", + "dt", + "xi", + "i", + "N", + "not_refractory", + "refractory", + "refractory_until", + "time", + "clock", + } + + # Collect all identifiers from expressions (to find external references) + external_refs = set() + for eq in self._equations.values(): + if eq.expr is not None: + for identifier in eq.expr.identifiers: + if identifier not in self._equations and identifier not in protected: + external_refs.add(identifier) + + # Create new equations dict + new_equations = {} + + # First, add all renamed equations + for varname, eq in self._equations.items(): + # Skip built-in variables + if varname in protected: + new_equations[varname] = eq + continue + + # Create new name with prefix + new_name = prefix_str + varname + + # Check for conflict + if new_name in self._equations and new_name != varname: + logger.warning( + f"'{new_name}' already exists in equations, " + f"will be overwritten by prefixing '{varname}'" + ) + + # Update expression if it has one + new_expr = None + if eq.expr is not None: + expr_code = eq.expr.code + # Replace all variable references with prefixed versions + # Only replace variables that are actually defined in equations + for old_var in self._equations.keys(): + if old_var not in protected: + # Use regex to replace whole words only + expr_code = re.sub( + r"\b" + re.escape(old_var) + r"\b", + prefix_str + old_var, + expr_code, + ) + try: + new_expr = Expression(expr_code) + except ValueError as ex: + raise ValueError( + f"Failed to prefix expression for '{varname}': {ex}" + ) from ex + + # Create new SingleEquation with prefixed name + new_equations[new_name] = SingleEquation( + eq.type, + new_name, + eq.dim, + var_type=eq.var_type, + expr=new_expr, + flags=eq.flags, + ) + + # Add external references as unchanged parameters (if they were in original) + for ref in external_refs: + if ref in self._equations: + # This was defined as an equation/parameter, already handled above + continue + # External reference - create as parameter to preserve it + # Find which equation referenced it to get dimensions + ref_dim = None + for eq in self._equations.values(): + if eq.expr is not None and ref in eq.expr.identifiers: + # Try to get dimension from context + # For simplicity, we'll skip this optimization + pass + + return Equations(list(new_equations.values())) + + def postfix(self, postfix_str): + """ + Return a copy of the equations with all variable names postfixed. + + Parameters + ---------- + postfix_str : str + String to append to variable names + + Returns + ------- + Equations + New Equations object with postfixed variables + + Raises + ------ + ValueError + If postfix_str is not a valid Python identifier + + Notes + ----- + This method postfixes all user-defined variables (differential equations, + parameters, subexpressions) while protecting built-in variables (t, dt, xi, + i, N, etc.) and external namespace references. + + The postfixing is recursive, so `eqs.postfix('_a').postfix('_b')` would + result in variables postfixed with `_a_b` (nested). + + Examples + -------- + >>> eqs = Equations(''' + ... dv/dt = -v/tau : volt + ... I = g*E : amp + ... ''') + >>> eqs_pop = eqs.postfix('_pop') + >>> 'v_pop' in eqs_pop + True + """ + import keyword + + # Validate postfix + if not isinstance(postfix_str, str): + raise ValueError("Postfix must be a string") + if postfix_str == "": + # Return a copy, not self + return Equations(list(self._equations.values())) + if not postfix_str.isidentifier(): + raise ValueError( + f"Invalid postfix '{postfix_str}': must be a valid Python identifier" + ) + if keyword.iskeyword(postfix_str): + raise ValueError(f"Postfix cannot be a Python keyword: '{postfix_str}'") + + # Built-in variables to protect (never rename these) + protected = { + "t", + "dt", + "xi", + "i", + "N", + "not_refractory", + "refractory", + "refractory_until", + "time", + "clock", + } + + # Collect all identifiers from expressions (to find external references) + external_refs = set() + for eq in self._equations.values(): + if eq.expr is not None: + for identifier in eq.expr.identifiers: + if identifier not in self._equations and identifier not in protected: + external_refs.add(identifier) + + # Create new equations dict + new_equations = {} + + # First, add all renamed equations + for varname, eq in self._equations.items(): + # Skip built-in variables + if varname in protected: + new_equations[varname] = eq + continue + + # Create new name with postfix + new_name = varname + postfix_str + + # Check for conflict + if new_name in self._equations and new_name != varname: + logger.warning( + f"'{new_name}' already exists in equations, " + f"will be overwritten by postfixing '{varname}'" + ) + + # Update expression if it has one + new_expr = None + if eq.expr is not None: + expr_code = eq.expr.code + # Replace all variable references with postfixed versions + # Only replace variables that are actually defined in equations + for old_var in self._equations.keys(): + if old_var not in protected: + # Use regex to replace whole words only + expr_code = re.sub( + r"\b" + re.escape(old_var) + r"\b", + old_var + postfix_str, + expr_code, + ) + try: + new_expr = Expression(expr_code) + except ValueError as ex: + raise ValueError( + f"Failed to postfix expression for '{varname}': {ex}" + ) from ex + + # Create new SingleEquation with postfixed name + new_equations[new_name] = SingleEquation( + eq.type, + new_name, + eq.dim, + var_type=eq.var_type, + expr=new_expr, + flags=eq.flags, + ) + + return Equations(list(new_equations.values())) + def __iter__(self): return iter(self._equations) diff --git a/brian2/tests/test_equations.py b/brian2/tests/test_equations.py index ed985c465..7098ee33a 100644 --- a/brian2/tests/test_equations.py +++ b/brian2/tests/test_equations.py @@ -9,6 +9,7 @@ import pytest from brian2 import Equations, Expression, Hz, Unit, farad, metre, ms, mV, second, volt +from brian2 import NeuronGroup, run from brian2.core.namespace import DEFAULT_UNITS from brian2.equations.equations import ( BOOLEAN, @@ -567,6 +568,210 @@ def test_extract_subexpressions(): assert constant["s2"].type == SUBEXPRESSION +@pytest.mark.codegen_independent +def test_prefix_basic(): + # Test basic prefix functionality + eqs = Equations(""" + dv/dt = -v/tau : volt + I = g*E : amp + """) + new_eqs = eqs.prefix('pop_') + assert 'pop_v' in new_eqs + assert 'pop_I' in new_eqs + # g and E are identifiers in the expression I = g*E + # but they're not defined as separate equations, so they're not in eqs + assert 'tau' not in new_eqs # External, not included in equations dict + + +@pytest.mark.codegen_independent +def test_postfix_basic(): + # Test basic postfix functionality + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + new_eqs = eqs.postfix('_pop') + assert 'v_pop' in new_eqs + assert 'tau' not in new_eqs # External, not included in equations dict + + +@pytest.mark.codegen_independent +def test_prefix_recursive(): + # Test recursive prefix behavior + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + eqs2 = eqs.prefix('a_') + assert 'a_v' in eqs2 + eqs3 = eqs2.prefix('b_') + assert 'b_a_v' in eqs3 # Nested! + + +@pytest.mark.codegen_independent +def test_postfix_recursive(): + # Test recursive postfix behavior + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + eqs2 = eqs.postfix('_a') + assert 'v_a' in eqs2 + eqs3 = eqs2.postfix('_b') + assert 'v_a_b' in eqs3 # Nested! + + +@pytest.mark.codegen_independent +def test_prefix_builtin_protection(): + # Test that built-in variables are protected + # Built-in variables like 't' and 'dt' are never in the equations dict + # They're implicit, not explicit equation definitions + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + new_eqs = eqs.prefix('pop_') + # t and dt are never in the equations dict to begin with + assert 't' not in new_eqs + assert 'dt' not in new_eqs + assert 'pop_t' not in new_eqs + assert 'pop_dt' not in new_eqs + + +@pytest.mark.codegen_independent +def test_prefix_validation(): + # Test validation of prefix string + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + + # Must be a string + with pytest.raises(ValueError, match="must be a string"): + eqs.prefix(123) + + # Cannot start with digit + with pytest.raises(ValueError, match="must be a valid Python identifier"): + eqs.prefix('123_') + + # Invalid characters + with pytest.raises(ValueError, match="must be a valid Python identifier"): + eqs.prefix('my-var') + + # Cannot be a keyword + with pytest.raises(ValueError, match="cannot be a Python keyword"): + eqs.prefix('for') + + +@pytest.mark.codegen_independent +def test_postfix_validation(): + # Test validation of postfix string + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + + # Must be a string + with pytest.raises(ValueError, match="must be a string"): + eqs.postfix(123) + + # Invalid characters + with pytest.raises(ValueError, match="must be a valid Python identifier"): + eqs.postfix('-invalid') + + +@pytest.mark.codegen_independent +def test_prefix_case_preservation(): + # Test that case is preserved + eqs = Equations(""" + dv/dt = -v/tau : volt + dV/dt = -V/tau : volt + """) + new_eqs = eqs.prefix('Pop_') + assert 'Pop_v' in new_eqs + assert 'Pop_V' in new_eqs + assert 'Pop_v' != 'Pop_V' # Different! + + +@pytest.mark.codegen_independent +def test_prefix_external_references(): + # Test that external namespace references are NOT added to equations + tau = 10*ms + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + new_eqs = eqs.prefix('pop_') + assert 'pop_v' in new_eqs + # tau is an external reference, not in equations dict + assert 'tau' not in eqs # Not in original + assert 'tau' not in new_eqs # Not in new either + + +@pytest.mark.codegen_independent +def test_prefix_empty(): + # Test that empty prefix returns a copy + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + new_eqs = eqs.prefix('') + assert set(new_eqs.names) == set(eqs.names) + assert new_eqs is not eqs # Different object + + +@pytest.mark.codegen_independent +def test_postfix_empty(): + # Test that empty postfix returns a copy + eqs = Equations(""" + dv/dt = -v/tau : volt + """) + new_eqs = eqs.postfix('') + assert set(new_eqs.names) == set(eqs.names) + assert new_eqs is not eqs # Different object + + +@pytest.mark.codegen_independent +def test_prefix_subexpressions(): + # Test that subexpressions are prefixed + eqs = Equations(""" + dv/dt = (-v + I)/tau : volt + I = g*E : amp + """) + new_eqs = eqs.prefix('pop_') + assert 'pop_v' in new_eqs + assert 'pop_I' in new_eqs + # g and E are in the expression but not separate equations + + +@pytest.mark.codegen_independent +def test_prefix_parameters(): + # Test that parameters are prefixed + eqs = Equations(""" + dv/dt = -v/tau : volt + tau : second + """) + new_eqs = eqs.prefix('pop_') + assert 'pop_v' in new_eqs + assert 'pop_tau' in new_eqs + + +@pytest.mark.standalone_compatible +def test_prefix_with_neurongroup(): + # Test that prefixed equations work with NeuronGroup + eqs = Equations(""" + dv/dt = -v/(10*ms) : volt + """) + eqs_exc = eqs.prefix('exc_') + G = NeuronGroup(10, eqs_exc) + assert 'exc_v' in G.variables + run(1*ms) + + +@pytest.mark.standalone_compatible +def test_postfix_with_neurongroup(): + # Test that postfixed equations work with NeuronGroup + eqs = Equations(""" + dv/dt = -v/(10*ms) : volt + """) + eqs_exc = eqs.postfix('_exc') + G = NeuronGroup(10, eqs_exc) + assert 'v_exc' in G.variables + run(1*ms) + + @pytest.mark.codegen_independent def test_repeated_construction(): eqs1 = Equations("dx/dt = x : 1")