diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 82dbc6b8da..61bb01a385 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -48,6 +48,7 @@ import inspect import copy import logging +import re from typing import Any, List, Optional, Union, TYPE_CHECKING from psyclone.configuration import Config @@ -57,6 +58,7 @@ ImportInterface, RoutineSymbol, Symbol, SymbolError, UnresolvedInterface) from psyclone.psyir.symbols.intrinsic_symbol import IntrinsicSymbol from psyclone.psyir.symbols.typed_symbol import TypedSymbol +from psyclone.psyir.symbols.datatypes import UnsupportedFortranType if TYPE_CHECKING: from psyclone.psyir.nodes.scoping_node import ScopingNode @@ -704,6 +706,12 @@ def check_for_clashes(self, other_table, symbols_to_skip=()): isinstance(other_sym, IntrinsicSymbol)): continue + # If both symbols have CommonBlockInterface, they represent the + # same shared COMMON-block data. They cannot (and do not need to) + # be renamed, so treat this as a benign clash. + if this_sym.is_commonblock and other_sym.is_commonblock: + continue + if other_sym.is_import and this_sym.is_import: # Both symbols are imported. That's fine as long as they have # the same import interface (are imported from the same @@ -945,6 +953,7 @@ def _add_symbols_from_table(self, other_table, symbols_to_skip=()): already been updated to refer to a Container in this table. ''' + for old_sym in other_table.symbols: if old_sym in symbols_to_skip or isinstance(old_sym, @@ -952,6 +961,26 @@ def _add_symbols_from_table(self, other_table, symbols_to_skip=()): # We've dealt with Container symbols in _add_container_symbols. continue + # Avoid duplicate COMMON-block marker symbols when multiple + # routines sharing the same COMMON blocks are inlined into a + # single caller. Each routine is parsed independently so its + # _PSYCLONE_INTERNAL_COMMONBLOCK_N markers carry different + # numbers and may therefore never trigger the name-clash path; + # we must scan *all* existing markers in self for an identical + # declaration before attempting to add. + if (self._normalize(old_sym.name).startswith( + "_psyclone_internal_commonblock") + and isinstance(old_sym.datatype, UnsupportedFortranType)): + if any( + sym.datatype.declaration == old_sym.datatype.declaration + for sym in self.symbols + if (self._normalize(sym.name).startswith( + "_psyclone_internal_commonblock") + and isinstance(sym.datatype, + UnsupportedFortranType)) + ): + continue + try: self.add(old_sym) @@ -1000,11 +1029,45 @@ def _handle_symbol_clash(self, old_sym, other_table): self_sym = self.lookup(old_sym.name) if old_sym.is_unresolved and self_sym.is_unresolved: + # Update after fixing issue #3392 # The clashing symbols are both unresolved so we ASSUME that # check_for_clashes has previously determined that they must # refer to the same thing and we don't have to do anything. return + if old_sym.is_commonblock and self_sym.is_commonblock: + return + + if (isinstance(old_sym.datatype, UnsupportedFortranType) and + isinstance(self_sym.datatype, UnsupportedFortranType)): + if old_sym.datatype.declaration == self_sym.datatype.declaration: + # Identical COMMON-block markers – already present in self. + return + # Markers have different declarations. Skip the incoming one only + # if its COMMON-block name(s) overlap with those already in self: + # that means the block is already declared and adding a second + # marker for it would produce a "Symbol X is already in a COMMON + # block" compile error. If the block names are different this is + # a genuinely new COMMON block and we fall through to the + # rename-and-add path below. + if self._normalize(old_sym.name).startswith( + "_psyclone_internal_commonblock"): + _blk_re = re.compile(r"/\s*(\w*)\s*/", re.IGNORECASE) + old_blocks = set(_blk_re.findall( + old_sym.datatype.declaration)) + # Check ALL existing commonblock markers in self, not just + # the same-named one, because the numbering may differ when + # the caller already has extra COMMON blocks of its own. + for sym in self.symbols: + if (self._normalize(sym.name).startswith( + "_psyclone_internal_commonblock") + and isinstance(sym.datatype, + UnsupportedFortranType)): + self_blocks = set(_blk_re.findall( + sym.datatype.declaration)) + if old_blocks & self_blocks: + return + # A Symbol with the same name already exists so we attempt to rename # first the one that we are adding and failing that, the existing # symbol in this table. diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 5cd4f3e544..b52e933ea9 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -147,6 +147,7 @@ def apply(self, use_first_callee_and_no_arg_check: bool = False, permit_codeblocks: bool = False, permit_unsupported_type_args: bool = False, + parameter_cloning: bool = True, **kwargs ): ''' @@ -163,6 +164,13 @@ def apply(self, if the target routine contains a CodeBlock. :param permit_unsupported_type_args: If `True` then the target routine is permitted to have arguments of UnsupportedType. + :param parameter_cloning: if `True` (the default), constant + (PARAMETER) symbols from the routine being inlined are always + copied into the call-site symbol table, potentially being renamed + to avoid clashes. If `False`, a constant from the routine is + skipped when an identical constant (same name, same type, and same + value) already exists at the call site, so no duplicate is + created. :raises InternalError: if the merge of the symbol tables fails. In theory this should never happen because validate() should @@ -219,6 +227,15 @@ def apply(self, # just delete the if statement. self._optional_arg_eliminate_ifblock_if_const_condition(routine) + # If parameter_cloning is disabled, identify duplicate constant + # (PARAMETER) symbols and redirect their references *before* the + # routine body is extracted, so that the extracted statements already + # carry references to the call-site symbols. + extra_skip: List[DataSymbol] = [] + if not parameter_cloning: + extra_skip = self._redirect_duplicate_parameters( + table, routine, routine_table) + # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] @@ -231,7 +248,8 @@ def apply(self, # call site. This preserves any references to them. try: table.merge(routine_table, - symbols_to_skip=routine_table.argument_list[:]) + symbols_to_skip=routine_table.argument_list[:] + + extra_skip) except SymbolError as err: raise InternalError( f"Error copying routine symbols to call site. This should " @@ -329,6 +347,116 @@ def apply(self, idx += 1 parent.addchild(child, idx) + def _redirect_duplicate_parameters( + self, + table, + routine: Routine, + routine_table, + ) -> List[DataSymbol]: + ''' + Identifies constant (PARAMETER) symbols in ``routine_table`` that + are identical to constants already present in ``table`` (same name, + same type, and same initial value). For each such symbol, every + :py:class:`~psyclone.psyir.nodes.Reference` to it inside ``routine`` + and inside the datatypes / initial-value expressions of other symbols + in ``routine_table`` is redirected to point to the corresponding + symbol in ``table``. + + Only constants whose initial value is represented as a PSyIR node + (i.e. ``initial_value is not None``) are considered; constants of + ``UnsupportedFortranType`` with an embedded value string are left + unchanged. + + A constant is only considered a duplicate when every routine-local + symbol referenced inside its initial-value expression is itself a + confirmed duplicate. This prevents false positives for expressions + like ``negflag = .NOT. flag`` when ``flag`` has different values in + the caller and the callee (the names would match but the semantics + would differ). + + :param table: the call-site symbol table. + :type table: :py:class:`psyclone.psyir.symbols.SymbolTable` + :param routine: the (copy of the) routine being inlined. + :type routine: :py:class:`psyclone.psyir.nodes.Routine` + :param routine_table: the symbol table of the routine copy. + :type routine_table: :py:class:`psyclone.psyir.symbols.SymbolTable` + + :returns: the list of routine symbols that are duplicates of + call-site constants and should be excluded from the subsequent + table merge. + :rtype: List[:py:class:`psyclone.psyir.symbols.DataSymbol`] + + ''' + # The names of all local data symbols in the routine table (used to + # identify references that point to routine-local constants). + routine_local_names = { + s.name.lower() for s in routine_table.datasymbols + if not s.is_argument + } + + # First pass: collect all constants from the routine whose name, + # datatype, and initial-value tree match a constant in the call-site + # table. The structural comparison uses __eq__, which compares + # Reference nodes by symbol name. This is correct for leaf constants + # (Literals) and is refined for dependent constants in the second + # pass below. + candidates: dict = {} + for rsym in routine_table.datasymbols: + if not rsym.is_constant or rsym.initial_value is None: + # Skip constants whose value is not represented as a PSyIR + # node (e.g. UnsupportedFortranType with embedded value). + continue + tsym = table.lookup(rsym.name, otherwise=None) + if not isinstance(tsym, DataSymbol): + continue + if not tsym.is_constant or tsym.initial_value is None: + continue + if rsym.datatype != tsym.datatype: + continue + if rsym.initial_value != tsym.initial_value: + continue + candidates[rsym.name.lower()] = rsym + + # Second pass: iteratively remove candidates whose initial-value + # expression references a routine-local symbol that is NOT itself + # a confirmed duplicate. Without this step, an expression like + # ``negflag = .NOT. flag`` would compare as equal by name even when + # ``flag`` has different values in the two routines. + changed = True + while changed: + changed = False + to_remove = [ + name for name, rsym in candidates.items() + if any( + dep.name.lower() in routine_local_names + and dep.name.lower() not in candidates + for dep in rsym.initial_value.get_all_accessed_symbols() + ) + ] + for name in to_remove: + del candidates[name] + if to_remove: + changed = True + + duplicates: List[DataSymbol] = list(candidates.values()) + + # Redirect all references from duplicate routine symbols to their + # call-site counterparts. + for rsym in duplicates: + tsym = table.lookup(rsym.name) + # Update all References in the routine body. + routine.replace_symbols_using(tsym) + # Update any references to rsym embedded in the datatypes or + # initial-value expressions of other symbols in routine_table. + for sym in routine_table.symbols: + if sym is rsym: + continue + sym.replace_symbols_using(tsym) + if hasattr(sym, 'datatype') and sym.datatype is not None: + sym.datatype.replace_symbols_using(tsym) + + return duplicates + def _optional_arg_resolve_present_intrinsics(self, routine_node: Routine, arg_match_list: List = []): diff --git a/src/psyclone/tests/psyir/backend/fortran_common_block_test.py b/src/psyclone/tests/psyir/backend/fortran_common_block_test.py index 9a1b8fd084..919f9093ee 100644 --- a/src/psyclone/tests/psyir/backend/fortran_common_block_test.py +++ b/src/psyclone/tests/psyir/backend/fortran_common_block_test.py @@ -61,6 +61,8 @@ def test_fw_common_blocks(fortran_reader, fortran_writer, tmpdir): routine = psyir.walk(Routine)[0] assert routine.symbol_table.lookup("a").is_commonblock # Sanity check + assert routine.symbol_table.lookup("d").is_commonblock # Sanity check + assert routine.symbol_table.lookup("e").is_commonblock # Sanity check code = fortran_writer(routine) assert code == ( diff --git a/src/psyclone/tests/psyir/symbols/symbol_table_test.py b/src/psyclone/tests/psyir/symbols/symbol_table_test.py index 0d73fb26f9..b5ac020d68 100644 --- a/src/psyclone/tests/psyir/symbols/symbol_table_test.py +++ b/src/psyclone/tests/psyir/symbols/symbol_table_test.py @@ -1325,6 +1325,90 @@ def test_handle_symbol_clash_imported_symbols(): "of the same name imported from 'Ridcully'" in str(err.value)) +def test_handle_symbol_clash_commonblock_same_declaration(): + '''Test that _handle_symbol_clash() ignores duplicate COMMON-block + markers with identical declarations.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + decl = "common /keep_me/ a" + marker_name = "_PSYCLONE_INTERNAL_COMMONBLOCK_1" + table1.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType(decl))) + table2.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType(decl))) + + old_sym = table2.lookup(marker_name) + table1._handle_symbol_clash(old_sym, table2) + + assert len(table1.symbols) == 1 + assert old_sym.name == marker_name + + +def test_handle_symbol_clash_commonblock_overlap_with_other_marker(): + '''Test that _handle_symbol_clash() scans all existing COMMON-block + markers and skips incoming marker when block names overlap.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + marker_name = "_PSYCLONE_INTERNAL_COMMONBLOCK_1" + # Clash with same name but different block. + table1.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType("common /other/ b"))) + # Include a non-marker symbol so the marker filter condition also takes + # the false branch while scanning existing symbols. + table1.add(symbols.DataSymbol("plain", symbols.INTEGER_TYPE)) + # Existing marker with different internal number but same block as incoming + # marker. This exercises the scan of all existing markers in table1. + table1.add(symbols.DataSymbol( + "_PSYCLONE_INTERNAL_COMMONBLOCK_2", + symbols.UnsupportedFortranType("common /overlap/ c"))) + table2.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType("common /overlap/ a"))) + + old_sym = table2.lookup(marker_name) + table1._handle_symbol_clash(old_sym, table2) + + assert len(table1.symbols) == 3 + assert old_sym.name == marker_name + + +def test_handle_symbol_clash_commonblock_distinct_blocks_renamed(): + '''Test that _handle_symbol_clash() renames and adds an incoming + COMMON-block marker when block names do not overlap.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + marker_name = "_PSYCLONE_INTERNAL_COMMONBLOCK_1" + table1.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType("common /first/ a"))) + table2.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType("common /second/ b"))) + + old_sym = table2.lookup(marker_name) + table1._handle_symbol_clash(old_sym, table2) + + assert old_sym.name != marker_name + assert any(sym.datatype.declaration == "common /second/ b" + for sym in table1.symbols) + + +def test_handle_symbol_clash_unsupported_fortran_non_commonblock_name(): + '''Test that a clash between UnsupportedFortranType symbols with names + unrelated to common-block markers takes the standard rename-and-add path. + ''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + table1.add(symbols.DataSymbol( + "clash", symbols.UnsupportedFortranType("type(t1) :: clash"))) + table2.add(symbols.DataSymbol( + "clash", symbols.UnsupportedFortranType("type(t2) :: clash"))) + + old_sym = table2.lookup("clash") + table1._handle_symbol_clash(old_sym, table2) + + assert old_sym.name != "clash" + assert any(sym.datatype.declaration == "type(t2) :: clash" + for sym in table1.symbols) + + def test_swap_symbol_properties(): ''' Test the symboltable swap_properties method ''' # pylint: disable=too-many-statements diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index cda89d5229..72495dd9d7 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -2229,6 +2229,37 @@ def test_validate_array_reshape(fortran_reader): "argument, 'x', has rank 1" in str(err.value)) +def test_validate_unknown_type_array_arg(fortran_reader): + '''Test that _validate_inline_of_call_and_routine_argument_pairs rejects + an attempt to inline a call when the actual argument has an unknown type + but the corresponding formal argument is an array.''' + code = """\ +module test_mod +contains +subroutine main + use some_mod, only: mystery + call sub(mystery) +end subroutine +subroutine sub(x) + real, dimension(10), intent(inout) :: x + x(:) = 0.0 +end subroutine +end module +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + sub = psyir.walk(Routine)[1] + inline_trans = InlineTrans() + with pytest.raises(TransformationError) as err: + inline_trans._validate_inline_of_call_and_routine_argument_pairs( + call, call.arguments[0], sub, sub.symbol_table.lookup("x")) + assert ( + "Routine 'sub' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown." in str(err.value) + ) + + def test_validate_array_arg_expression(fortran_reader): ''' Check that validate rejects a call if an argument corresponding to @@ -2843,3 +2874,588 @@ def test_apply_array_access_check_unresolved_override_option( inline_trans.apply( call, use_first_callee_and_no_arg_check=True) # TODO check results + + +def test_apply_common_block_no_duplicate(fortran_reader, fortran_writer): + '''Test that inlining two routines that share a COMMON block does not + produce duplicate COMMON declarations (which would cause a Fortran compile + error "Symbol X is already in a COMMON block").''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + call sub1() + call sub2() + end subroutine caller + subroutine sub1() + real :: volume, lmmpi + COMMON /blk/ volume, lmmpi + volume = 1.0 + end subroutine sub1 + subroutine sub2() + real :: volume, lmmpi + COMMON /blk/ volume, lmmpi + lmmpi = 2.0 + end subroutine sub2 +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # Exactly one COMMON declaration must appear. + assert result.count("COMMON /blk/") == 1 + # Both variables must still be present. + assert "volume" in result + assert "lmmpi" in result + + +def test_apply_common_block_no_duplicate_three_routines( + fortran_reader, fortran_writer): + '''Test that inlining three routines that all share the same COMMON block + does not produce duplicate COMMON declarations. This mirrors the real-world + case of inlining zetabc_tile, u2dbc_tile and v2dbc_tile (each of which + includes the same set of COMMON-block headers) into step2D_FB_tile.''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + call sub1() + call sub2() + call sub3() + end subroutine caller + subroutine sub1() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + zeta = 1.0 + end subroutine sub1 + subroutine sub2() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + ubar = 2.0 + end subroutine sub2 + subroutine sub3() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + vbar = 3.0 + end subroutine sub3 +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # Each COMMON block must appear exactly once. + assert result.count("COMMON /ocean_zeta/") == 1 + assert result.count("COMMON /ocean_ubar/") == 1 + assert result.count("COMMON /ocean_vbar/") == 1 + # All three variables must still be present. + assert "zeta" in result + assert "ubar" in result + assert "vbar" in result + + +def test_apply_common_block_caller_has_extra_block( + fortran_reader, fortran_writer): + '''Test that inlining a routine whose only COMMON block is already present + in the caller does not produce a duplicate COMMON declaration, even when + the caller also has an *additional* COMMON block that the inlined routine + does not declare. This is a regression test derived from the real-world + test.f file: the presence of the extra /comm_setup_mpi1/ block in the + caller was enough to confuse the earlier deduplication logic and caused + "Symbol 'zeta' at (1) is already in a COMMON block".''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + integer :: lmmpi + COMMON /comm_setup_mpi1/ lmmpi + real :: zeta + COMMON /ocean_zeta/ zeta + call subfoo() + end subroutine caller + subroutine subfoo() + real :: zeta + COMMON /ocean_zeta/ zeta + zeta = zeta + 1.0 + end subroutine subfoo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # /ocean_zeta/ must appear exactly once – not duplicated. + assert result.count("COMMON /ocean_zeta/") == 1 + # The extra block from the caller must be preserved. + assert result.count("COMMON /comm_setup_mpi1/") == 1 + assert "zeta" in result + assert "lmmpi" in result + + +# parameter_cloning option + + +def test_apply_parameter_cloning_default(fortran_reader, fortran_writer): + '''Test that the default behaviour (parameter_cloning=True) clones a + constant from the inlined routine into the call-site table, even when + an identical constant already exists there, potentially renaming it.''' + code = """\ +module test_mod +contains + subroutine bar(b) + real, parameter :: constval = 123.0 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo) + + result = fortran_writer(bar) + # With cloning enabled the inlined constant must appear at least once; + # it may be renamed to avoid the clash. + assert "constval" in result + + +def test_apply_parameter_cloning_false_identical(fortran_reader, + fortran_writer): + '''Test that parameter_cloning=False suppresses the duplicate when the + call-site already has an identical constant (same name, type, value). + This is the main use-case from the user request.''' + code = """\ +module test_mod +contains + subroutine bar(b) + real, parameter :: constval = 123.0 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # constval should be declared exactly once (no duplicate parameter). + assert result.count("parameter :: constval") == 1 + # The inlined assignment should still use constval correctly. + assert "constval" in result + + +def test_apply_parameter_cloning_false_different_value(fortran_reader, + fortran_writer): + '''Test that parameter_cloning=False does NOT suppress a parameter when + the values differ between the call site and the inlined routine.''' + code = """\ +module test_mod +contains + subroutine bar(b) + real, parameter :: constval = 42.0 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # Both constant declarations must survive since they have different values. + assert result.count("constval") >= 2 + + +def test_apply_parameter_cloning_false_no_match_in_caller(fortran_reader, + fortran_writer): + '''Test that parameter_cloning=False still adds a constant that does not + exist at the call site.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # constval from foo must be added to bar because bar didn't have it. + assert "constval" in result + + +def test_apply_parameter_cloning_false_used_in_array_dim(fortran_reader, + fortran_writer): + '''Test that parameter_cloning=False correctly handles a constant that + is used as an array-dimension bound inside the inlined routine.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: n = 5 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: n = 5 + real, dimension(n) :: tmp + integer :: a + tmp(1) = real(a) + a = int(tmp(1)) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # n should appear only once as a parameter declaration. + assert result.count(", parameter ::", result.lower().find("n =")) <= 1 \ + or result.count("n = 5") == 1 + # The inlined array tmp should still be present and use n. + assert "tmp" in result + assert "n" in result + + +def test_apply_parameter_cloning_false_multiple_params(fortran_reader, + fortran_writer): + '''Test parameter_cloning=False with multiple constants, some matching + and some not.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: shared = 10 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: shared = 10 + integer, parameter :: local_only = 99 + integer :: a + a = shared + local_only + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # shared must be declared exactly once (no duplicate parameter). + assert result.count("parameter :: shared") == 1 + # local_only is unique to foo, so it must be added to bar. + assert "local_only" in result + + +def test_apply_parameter_cloning_false_complex_rhs_identical(fortran_reader, + fortran_writer): + '''Test parameter_cloning=False with constants whose value is a complex + PSyIR expression (BinaryOperation) that is identical in the caller and the + routine. The duplicate should be suppressed.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 123 + base_val + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 123 + base_val + integer :: a + a = constval + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # Neither base_val nor constval should be duplicated. + assert result.count("parameter :: constval") == 1 + assert result.count("parameter :: base_val") == 1 + # The inlined body should still reference constval. + assert "constval" in result + + +def test_apply_parameter_cloning_false_complex_rhs_different(fortran_reader, + fortran_writer): + '''Test parameter_cloning=False with constants that have identical names + but different complex RHS expressions. Both declarations must be kept.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 100 + base_val + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 123 + base_val + integer :: a + a = constval + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # constval has different values in bar and foo, so both must appear. + assert result.count("parameter :: constval") >= 2 or ( + "constval" in result and "constval_1" in result) + # base_val is identical and should be deduplicated. + assert result.count("parameter :: base_val") == 1 + + +def test_apply_parameter_cloning_false_unary_op_identical(fortran_reader, + fortran_writer): + '''Test parameter_cloning=False with a unary operation (.NOT.) that is + identical in both the caller and the callee. Both the base parameter and + the derived .NOT. parameter should be deduplicated.''' + code = """\ +module test_mod +contains + subroutine bar(b) + logical, parameter :: flag = .true. + logical, parameter :: negflag = .not. flag + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + logical, parameter :: flag = .true. + logical, parameter :: negflag = .not. flag + integer :: a + if (negflag) a = 42 + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # Both flag and negflag are identical so neither should be duplicated. + assert result.count("parameter :: flag") == 1 + assert result.count("parameter :: negflag") == 1 + # The inlined if-body must still reference negflag correctly. + assert "negflag" in result + + +def test_apply_parameter_cloning_false_unary_op_different_base( + fortran_reader, fortran_writer): + '''Test parameter_cloning=False where .NOT. parameters share a name but + their base parameter differs. The derived constant must NOT be deduplicated + because the structural match is only nominal (the base has different + values), and using the caller's copy would produce wrong semantics.''' + code = """\ +module test_mod +contains + subroutine bar(b) + logical, parameter :: flag = .true. + logical, parameter :: negflag = .not. flag + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + logical, parameter :: flag = .false. + logical, parameter :: negflag = .not. flag + integer :: a + if (negflag) a = 42 + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # flag has different values so both must appear (foo's renamed). + assert result.count("parameter :: flag") >= 2 or "flag_1" in result + # negflag depends on flag which differs, so foo's negflag must also + # appear (renamed), and the inlined if must use foo's (renamed) negflag. + assert "negflag_1" in result + assert "if (negflag_1)" in result + + +def test_apply_parameter_cloning_false_caller_has_non_constant( + fortran_reader, fortran_writer): + '''Test that parameter_cloning=False does NOT suppress a routine constant + when the call-site has a symbol with the same name that is not a constant + (i.e. tsym.is_constant is False). This exercises the + ``if not tsym.is_constant or tsym.initial_value is None`` branch in + _redirect_duplicate_parameters.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer :: constval + integer :: b + constval = 7 + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: constval = 10 + integer :: a + a = constval + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # bar's constval is a variable; foo's is a parameter. They are not + # duplicates, so foo's parameter constant must appear (possibly renamed). + assert """\ +subroutine bar(b) + integer, parameter :: constval_1 = 10 + integer :: b + integer :: constval + + constval = 7 + b = constval_1 + +end subroutine bar""" in result + + +def test_apply_parameter_cloning_false_different_datatype( + fortran_reader, fortran_writer): + '''Test that parameter_cloning=False does NOT suppress a routine constant + when the call-site has a constant with the same name but a different + datatype. This exercises the ``if rsym.datatype != tsym.datatype`` + branch in _redirect_duplicate_parameters.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: constval = 10 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 10.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # bar has integer constval=10, foo has real constval=10.0. Different + # types so the routine's parameter must be added (renamed) rather than + # deduplicated. + assert """\ +subroutine bar(b) + integer, parameter :: constval = 10 + real, parameter :: constval_1 = 10.0 + integer :: b + + b = INT(constval_1) + +end subroutine bar""" in result