-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathinline_trans.py
More file actions
1400 lines (1254 loc) · 61.6 KB
/
inline_trans.py
File metadata and controls
1400 lines (1254 loc) · 61.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -----------------------------------------------------------------------------
# BSD 3-Clause License
#
# Copyright (c) 2022-2026, Science and Technology Facilities Council.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Authors: A. R. Porter, R. W. Ford, A. Chalk and S. Siso, STFC Daresbury Lab
'''
This module contains the InlineTrans transformation.
'''
from typing import Dict, List, Optional
from psyclone.core import SymbolicMaths
from psyclone.errors import LazyString, InternalError
from psyclone.psyGen import Kern, Transformation
from psyclone.psyir.nodes import (
ArrayReference, ArrayOfStructuresReference, Assignment, BinaryOperation,
Call, CodeBlock, DataNode, IfBlock, IntrinsicCall, Literal, Loop, Node,
Range, Routine, Reference, Return, Schedule, ScopingNode, Statement,
StructureMember, StructureReference)
from psyclone.psyir.nodes.array_mixin import ArrayMixin
from psyclone.psyir.symbols import (
ArrayType,
BOOLEAN_TYPE,
DataSymbol,
INTEGER_TYPE,
StructureType,
SymbolError,
UnresolvedType,
UnsupportedType,
UnsupportedFortranType,
)
from psyclone.psyir.transformations.callee_transformation_mixin import (
CalleeTransformationMixin)
from psyclone.psyir.transformations.reference2arrayrange_trans import (
Reference2ArrayRangeTrans)
from psyclone.psyir.transformations.transformation_error import (
TransformationError)
from psyclone.psyir.nodes.call import CallMatchingArgumentsNotFound
from psyclone.utils import transformation_documentation_wrapper
# Shared PSyIR Literal for the integer value 1, used as the default lower
# bound of an array dimension when no other lower bound is known. As this is
# a single module-level node it must never be inserted into a PSyIR tree
# directly - only copies of it should be attached (via .copy()).
_ONE = Literal("1", INTEGER_TYPE)
@transformation_documentation_wrapper
class InlineTrans(Transformation, CalleeTransformationMixin):
'''
This transformation takes a Call (which may have a return value)
and replaces it with the body of the target routine. It is used as
follows:
>>> from psyclone.psyir.backend.fortran import FortranWriter
>>> from psyclone.psyir.frontend.fortran import FortranReader
>>> from psyclone.psyir.nodes import Call, Routine
>>> from psyclone.psyir.transformations import InlineTrans
>>> code = """
... module test_mod
... contains
... subroutine run_it()
... integer :: i
... real :: a(10)
... do i=1,10
... a(i) = 1.0
... call sub(a(i))
... end do
... end subroutine run_it
... subroutine sub(x)
... real, intent(inout) :: x
... x = 2.0*x
... end subroutine sub
... end module test_mod"""
>>> psyir = FortranReader().psyir_from_source(code)
>>> call = psyir.walk(Call)[0]
>>> inline_trans = InlineTrans()
>>> inline_trans.apply(call)
>>> # Uncomment the following line to see a text view of the schedule
>>> # print(psyir.walk(Routine)[0].view())
>>> print(FortranWriter()(psyir.walk(Routine)[0]))
subroutine run_it()
integer :: i
real, dimension(10) :: a
<BLANKLINE>
do i = 1, 10, 1
a(i) = 1.0
a(i) = 2.0 * a(i)
enddo
<BLANKLINE>
end subroutine run_it
<BLANKLINE>
The target of the call must already be present in the same Container
(module) as the call site. This may be achieved using
KernelModuleInlineTrans.
.. warning::
Routines/calls with any of the following characteristics are not
supported and will result in a TransformationError:
* the routine contains an early Return statement;
* the routine contains a variable with UnknownInterface;
* the routine contains a variable with StaticInterface;
* the routine contains an UnsupportedType variable with
ArgumentInterface;
* the shape of any array arguments as declared inside the routine does
not match the shape of the arrays being passed as arguments;
* the routine accesses an un-resolved symbol;
* the routine accesses a symbol declared in the Container to which it
belongs.
Some of these restrictions will be lifted by #924.
'''
def apply(self,
          node: Call,
          routine: Optional[Routine] = None,
          use_first_callee_and_no_arg_check: bool = False,
          permit_codeblocks: bool = False,
          permit_unsupported_type_args: bool = False,
          parameter_cloning: bool = True,
          **kwargs
          ):
    '''
    Takes the body of the routine that is the target of the supplied
    call and replaces the call with it.

    :param node: the Call node to inline.
    :param routine: Optional Routine to be inlined. (By default, PSyclone
        will search for a target routine with a matching signature).
    :param use_first_callee_and_no_arg_check: if True, simply use the
        first potential callee routine. No argument type-checking is
        performed.
    :param permit_codeblocks: If `False` (the default), raise an Exception
        if the target routine contains a CodeBlock.
    :param permit_unsupported_type_args: If `True` then the target routine
        is permitted to have arguments of UnsupportedType.
    :param parameter_cloning: if `True` (the default), constant
        (PARAMETER) symbols from the routine being inlined are always
        copied into the call-site symbol table, potentially being renamed
        to avoid clashes. If `False`, a constant from the routine is
        skipped when an identical constant (same name, same type, and same
        value) already exists at the call site, so no duplicate is
        created.

    :raises InternalError: if the merge of the symbol tables fails.
        In theory this should never happen because validate() should
        catch such a situation.
    :raises TransformationError: if a formal argument that is used in the
        body of the routine is not supplied by the call.
    '''
    self.validate(
        node, routine=routine,
        use_first_callee_and_no_arg_check=(
            use_first_callee_and_no_arg_check),
        permit_codeblocks=permit_codeblocks,
        permit_unsupported_type_args=permit_unsupported_type_args)

    # The table associated with the scoping region holding the Call.
    table = node.ancestor(Routine).symbol_table

    # Compare against None explicitly: a supplied Routine must never be
    # ignored just because the node happens to evaluate as falsy.
    if routine is None:
        # No target Routine has been provided so we search for one with
        # a matching signature.
        (orig_routine, arg_match_list) = node.get_callee(
            use_first_callee_and_no_arg_check=(
                use_first_callee_and_no_arg_check))
    else:
        # Target Routine supplied to this transformation directly.
        orig_routine = routine
        arg_match_list = node.get_argument_map(orig_routine)

    if not orig_routine.children or isinstance(orig_routine.children[0],
                                               Return):
        # Called routine is empty so just remove the call.
        node.detach()
        return

    # Ensure we don't modify the original Routine by working with a
    # copy of it.
    routine = orig_routine.copy()
    routine_table = routine.symbol_table

    # Next, we remove all optional arguments which are not used.
    # Step 1)
    # - Build lookup dictionary for all optional arguments:
    # - For all `PRESENT(...)`:
    #   - Lookup variable in dictionary
    #   - Replace with `True` or `False`, depending on whether
    #     it's provided or not.
    self._optional_arg_resolve_present_intrinsics(
        routine, arg_match_list
    )

    # Step 2)
    # - For all If-Statements, handle constant conditions:
    #   - `True`: Replace If-Block with If-Body
    #   - `False`: Replace If-Block with Else-Body. If it doesn't exist
    #     just delete the if statement.
    self._optional_arg_eliminate_ifblock_if_const_condition(routine)

    # If parameter_cloning is disabled, identify duplicate constant
    # (PARAMETER) symbols and redirect their references *before* the
    # routine body is extracted, so that the extracted statements already
    # carry references to the call-site symbols.
    extra_skip: List[DataSymbol] = []
    if not parameter_cloning:
        extra_skip = self._redirect_duplicate_parameters(
            table, routine, routine_table)

    # Construct lists of the nodes that will be inserted and all of the
    # References that they contain.
    new_stmts = []
    refs = []
    for child in routine.pop_all_children():
        new_stmts.append(child)
        refs.extend(new_stmts[-1].walk(Reference))

    # Shallow copy the symbols from the routine into the table at the
    # call site. This preserves any references to them.
    try:
        table.merge(routine_table,
                    symbols_to_skip=routine_table.argument_list[:]
                    + extra_skip)
    except SymbolError as err:
        raise InternalError(
            f"Error copying routine symbols to call site. This should "
            f"have been caught by the validate() method. Original error "
            f"was {err}") from err

    # Replace any references to formal arguments with copies of the
    # actual arguments.
    formal_args = routine_table.argument_list
    for ref in refs[:]:
        self._replace_formal_args_in_expr(
            ref, node, formal_args, routine_node=routine,
            use_first_callee_and_no_arg_check=(
                use_first_callee_and_no_arg_check)
        )

    def _shape_with_actual_args(shape):
        '''
        :returns: a new shape list in which any References to formal
            arguments within explicit array bounds have been replaced by
            the corresponding actual arguments. ArrayType.Extent entries
            (e.g. the DEFERRED extent of an allocatable array) and any
            bound that is not a PSyIR Node are kept unchanged.
        '''
        new_shape = []
        for dim in shape:
            if isinstance(dim, ArrayType.Extent):
                # No explicit bounds so there is nothing to replace.
                new_shape.append(dim)
                continue
            bounds = []
            for bound in (dim.lower, dim.upper):
                if isinstance(bound, Node):
                    bound = self._replace_formal_args_in_expr(
                        bound, node, formal_args,
                        routine_node=routine,
                        use_first_callee_and_no_arg_check=(
                            use_first_callee_and_no_arg_check),
                    )
                bounds.append(bound)
            new_shape.append(ArrayType.ArrayBounds(*bounds))
        return new_shape

    # Ensure any references to Symbols within the shape-specification of
    # other Symbols are updated. Note, we don't have to worry about
    # initialisation expressions here as they imply that a variable is
    # static. We don't support inlining routines with static variables.
    for sym in table.automatic_datasymbols:
        if not isinstance(sym.datatype, ArrayType):
            continue
        new_shape = _shape_with_actual_args(sym.datatype.shape)
        sym.datatype = ArrayType(sym.datatype.elemental_type, new_shape)

    # Do the same for array components of derived types. These may have
    # a DEFERRED extent (allocatable components), which the shared helper
    # handles in the same way as for the symbols above.
    for sym in table.datatypesymbols:
        if not isinstance(sym.datatype, StructureType):
            continue
        # Iterate over a snapshot as entries are replaced within the loop.
        for name, ctype in list(sym.datatype.components.items()):
            if not isinstance(ctype.datatype, ArrayType):
                continue
            new_shape = _shape_with_actual_args(ctype.datatype.shape)
            sym.datatype.components[name] = (
                StructureType.ComponentType(
                    name=name,
                    datatype=ArrayType(ctype.datatype.elemental_type,
                                       new_shape),
                    visibility=ctype.visibility,
                    initial_value=ctype.initial_value))

    # Copy the nodes from the Routine into the call site.
    # TODO #924 - while doing this we should ensure that any References
    # to common/shared Symbols in the inlined code are updated to point
    # to the ones at the call site.
    if routine.return_symbol:
        # This is a function
        assignment = node.ancestor(Statement, excluding=Call)
        parent = assignment.parent
        idx = assignment.position-1
        for child in new_stmts:
            idx += 1
            parent.addchild(child, idx)
        # Avoid a potential name clash with the original function
        table.rename_symbol(
            routine.return_symbol, table.next_available_name(
                f"inlined_{routine.return_symbol.name}"))
        node.replace_with(Reference(routine.return_symbol))
    else:
        # This is a call
        parent = node.parent
        idx = node.position
        node.replace_with(new_stmts[0])
        for child in new_stmts[1:]:
            idx += 1
            parent.addchild(child, idx)
def _redirect_duplicate_parameters(
        self,
        table,
        routine: Routine,
        routine_table,
) -> List[DataSymbol]:
    '''
    Finds the constant (PARAMETER) symbols of the routine being inlined
    that duplicate a constant found by ``table.lookup`` at the call site,
    i.e. one with the same name, an equal datatype and an equal
    initial-value expression. Every
    :py:class:`~psyclone.psyir.nodes.Reference` to such a routine constant
    - both in the body of ``routine`` and in the datatypes / initial-value
    expressions of the other symbols in ``routine_table`` - is re-pointed
    at the call-site symbol.

    Constants without a PSyIR initial value (``initial_value is None``,
    e.g. an ``UnsupportedFortranType`` carrying its value as text) are
    never treated as duplicates.

    Initial values are compared structurally, and References compare by
    symbol name. A constant is therefore only accepted when every
    routine-local symbol its initial value depends on has itself been
    accepted. This rejects e.g. ``negflag = .NOT. flag`` when ``flag``
    has different values in the caller and the callee.

    :param table: the call-site symbol table.
    :type table: :py:class:`psyclone.psyir.symbols.SymbolTable`
    :param routine: the (copy of the) routine being inlined.
    :type routine: :py:class:`psyclone.psyir.nodes.Routine`
    :param routine_table: the symbol table of the routine copy.
    :type routine_table: :py:class:`psyclone.psyir.symbols.SymbolTable`

    :returns: the routine symbols found to duplicate call-site constants;
        these must be excluded from the subsequent table merge.
    :rtype: List[:py:class:`psyclone.psyir.symbols.DataSymbol`]
    '''
    # Lower-cased names of the routine's non-argument data symbols. A
    # dependency on one of these is only acceptable if it is itself a
    # confirmed duplicate.
    local_names = set()
    for local_sym in routine_table.datasymbols:
        if not local_sym.is_argument:
            local_names.add(local_sym.name.lower())

    def _matches_call_site(rsym) -> bool:
        '''True if rsym is a constant equal (by name, datatype and
        initial-value tree) to a constant found at the call site.'''
        if not rsym.is_constant or rsym.initial_value is None:
            # The value is not held as a PSyIR node (e.g. an
            # UnsupportedFortranType with an embedded value).
            return False
        tsym = table.lookup(rsym.name, otherwise=None)
        if not isinstance(tsym, DataSymbol):
            return False
        if not tsym.is_constant or tsym.initial_value is None:
            return False
        if (rsym.datatype != tsym.datatype or
                rsym.initial_value != tsym.initial_value):
            return False
        return True

    # Initial (name-based) matches, keyed by lower-cased name.
    candidates: dict = {
        rsym.name.lower(): rsym
        for rsym in routine_table.datasymbols
        if _matches_call_site(rsym)
    }

    def _depends_on_unconfirmed_local(rsym) -> bool:
        '''True if rsym's initial value uses a routine-local symbol that
        is not (or is no longer) a candidate duplicate.'''
        for dep in rsym.initial_value.get_all_accessed_symbols():
            dep_name = dep.name.lower()
            if dep_name in local_names and dep_name not in candidates:
                return True
        return False

    # Prune until a fixed point is reached: removing one candidate can
    # invalidate others that depend on it.
    while True:
        rejected = [name for name, rsym in candidates.items()
                    if _depends_on_unconfirmed_local(rsym)]
        if not rejected:
            break
        for name in rejected:
            del candidates[name]

    duplicates: List[DataSymbol] = list(candidates.values())

    # Point everything that used a duplicate at the call-site symbol.
    for dup in duplicates:
        call_site_sym = table.lookup(dup.name)
        # References in the body of the routine.
        routine.replace_symbols_using(call_site_sym)
        # References embedded in other symbols of the routine table.
        for other in routine_table.symbols:
            if other is dup:
                continue
            other.replace_symbols_using(call_site_sym)
            dtype = getattr(other, 'datatype', None)
            if dtype is not None:
                dtype.replace_symbols_using(call_site_sym)
    return duplicates
def _optional_arg_resolve_present_intrinsics(
        self,
        routine_node: Routine,
        arg_match_list: Optional[List[int]] = None):
    """Replace PRESENT(some_argument) intrinsics in routine with constant
    booleans depending on whether `some_argument` has been provided
    (`True`) or not (`False`).

    An optional argument is recognised as an argument whose datatype is
    an UnsupportedFortranType whose text contains ", OPTIONAL". A PRESENT()
    whose argument is not recognised in this way is replaced by `False`.

    :param routine_node: Routine to be inlined.
    :param arg_match_list: for each actual argument of the call, the index
        of the matching formal argument in the routine's argument list
        (as returned by Call.get_argument_map()). `None` (the default) is
        treated as an empty list, i.e. no argument is provided.
    """
    if arg_match_list is None:
        arg_match_list = []
    # Indices into the routine's *argument list* that are supplied by the
    # call. These must be compared with positions in argument_list, not
    # with positions among all data symbols: local variables may be
    # declared before (or interleaved with) the arguments.
    provided_arg_indices = set(arg_match_list)

    # Build a lookup table of all optional arguments recording whether
    # each one is present.
    optional_sym_present_dict: Dict[str, bool] = {}
    for arg_idx, arg_sym in enumerate(
            routine_node.symbol_table.argument_list):
        if not isinstance(arg_sym.datatype, UnsupportedFortranType):
            continue
        if ", OPTIONAL" not in str(arg_sym.datatype):
            continue
        optional_sym_present_dict[arg_sym.name.lower()] = (
            arg_idx in provided_arg_indices)

    # Nothing to do if there are no optional arguments.
    if not optional_sym_present_dict:
        return

    # Replace every PRESENT() with the corresponding boolean constant.
    for intrinsic_call in routine_node.walk(IntrinsicCall):
        if intrinsic_call.intrinsic is not IntrinsicCall.Intrinsic.PRESENT:
            continue
        present_arg: Reference = intrinsic_call.arguments[0]
        is_present = optional_sym_present_dict.get(
            present_arg.name.lower(), False)
        intrinsic_call.replace_with(
            Literal("true" if is_present else "false", BOOLEAN_TYPE))
def _optional_arg_eliminate_ifblock_if_const_condition(
        self, routine_node: Routine
):
    """Eliminate if-block where condition is a boolean Literal.

    Each IfBlock condition is first simplified with SymbolicMaths.expand().
    If the result is a Literal then the IfBlock is replaced by its if-body
    (for "true") or by its else-body (otherwise). An IfBlock with a
    non-true Literal condition and no else-body is removed.

    :param routine_node: the Routine in which to eliminate if blocks.
    """
    def if_else_replace(main_schedule: Schedule,
                        if_block: IfBlock,
                        if_body_schedule: Schedule):
        """Little helper routine to eliminate one branch of an IfBlock.

        :param main_schedule: Schedule containing the if_block.
        :param if_block: the IfBlock itself.
        :param if_body_schedule: the body (if or else) that replaces it.
        """
        # Use the identity-based position of this node. list.index() would
        # use structural equality and could find an identical sibling.
        idx = if_block.position
        if_block.detach()
        # Insert the children of the selected body in its place.
        for child in if_body_schedule.pop_all_children():
            main_schedule.addchild(child, idx)
            idx += 1

    sym_maths = SymbolicMaths.get()
    for if_block in routine_node.walk(IfBlock):
        # Ensure any expressions in the condition are simplified. expand()
        # may replace the condition node within the tree, so the condition
        # must be re-fetched afterwards. Otherwise we would test the stale,
        # unsimplified node (e.g. `.true. .and. .false.`).
        sym_maths.expand(if_block.condition)
        condition = if_block.condition

        # Make sure we only handle a Boolean Literal as a condition
        # TODO #2802
        if not isinstance(condition, Literal):
            continue

        if condition.value == "true":
            # Only keep the if-body.
            if_else_replace(if_block.parent, if_block, if_block.if_body)
        elif if_block.else_body:
            # Only keep the else-body.
            if_else_replace(if_block.parent, if_block, if_block.else_body)
        else:
            # Condition is never satisfied and there is no else-body.
            if_block.detach()
def _replace_formal_args_in_expr(
        self,
        expression: Node,
        call_node: Call,
        formal_args: List[DataSymbol],
        routine_node: Routine,
        use_first_callee_and_no_arg_check: bool = False,
) -> Node:
    '''
    Recursively combines any References to formal arguments in the supplied
    PSyIR expression with the corresponding Reference (actual argument)
    from the call site to make a new Reference for use in the inlined code.
    If the supplied node is not a Reference to a formal argument then it is
    just returned (after we have recursed to any children).

    Recursion only descends through nodes that are not References. The
    children of a Reference to a non-formal argument (e.g. the indices of
    an ArrayReference) are not visited.

    :param expression: the expression to update.
    :param call_node: the call site.
    :param formal_args: the formal arguments of the called routine.
    :param routine_node: the Routine being inlined.
    :param use_first_callee_and_no_arg_check: if True, the actual argument
        is taken from the call's argument list by position without any
        argument matching. Defaults to False.

    :returns: the replacement expression (the supplied node itself if it
        is not a Reference to a formal argument).

    :raises TransformationError: if a formal argument used in the
        expression is not supplied by the call.
    '''
    if not isinstance(expression, Reference):
        # Recurse down in case this is e.g. an Operation or Range.
        for child in expression.children[:]:
            self._replace_formal_args_in_expr(
                child, call_node, formal_args, routine_node,
                use_first_callee_and_no_arg_check=(
                    use_first_callee_and_no_arg_check))
        return expression

    ref = expression
    if ref.symbol not in formal_args:
        # The supplied reference is not to a formal argument.
        return ref

    # Lookup index in routine argument
    routine_arg_idx = formal_args.index(ref.symbol)

    if use_first_callee_and_no_arg_check:
        # We're not attempting to match argument types.
        actual_arg = call_node.arguments[routine_arg_idx]
    else:
        # Lookup index of actual argument
        # If this is an optional argument, but not used, this index lookup
        # shouldn't fail
        try:
            arg_match_list = call_node.get_argument_map(routine_node)
            actual_arg_idx = arg_match_list.index(routine_arg_idx)
        except ValueError as err:
            arg_list = routine_node.symbol_table.argument_list
            arg_name = arg_list[routine_arg_idx].name
            raise TransformationError(
                f"Subroutine argument '{arg_name}' is not provided by "
                f"'{call_node.debug_string().strip()}', but used in the "
                f"subroutine. If this is correct code, this is likely due "
                f"to some non-eliminated if-branches using `PRESENT(...)` "
                f"as conditional (TODO #2802).") from err

        # Lookup the actual argument that corresponds to this formal
        # argument.
        actual_arg = call_node.arguments[actual_arg_idx]

    # Generate a expression that replaces the formal argument using
    # the actual argument.
    new_ref = self._generate_formal_arg_replacement(
        actual_arg, ref, call_node, formal_args,
        routine_node=routine_node,
        use_first_callee_and_no_arg_check=(
            use_first_callee_and_no_arg_check))

    # If the local reference we are replacing has a parent then we must
    # ensure the parent's child list is updated. (It may not have a parent
    # if we are in the process of constructing a brand new reference.)
    if ref.parent:
        ref.replace_with(new_ref)
    return new_ref
def _create_inlined_idx(
        self,
        call_node: Call,
        formal_args: List[DataSymbol],
        local_idx: DataNode,
        decln_start: DataNode,
        actual_start: DataNode,
        routine_node: Routine,
        use_first_callee_and_no_arg_check: bool = False,
) -> DataNode:
    '''
    Builds the PSyIR for an array-index expression once it has been
    inlined. A formal argument may be declared with bounds that are offset
    from those of the actual argument, so the local index must be shifted
    accordingly. The quantities involved are:

    * ``local_idx``: the index used inside the routine;
    * ``decln_start``: the lower bound of that dimension as declared
      inside the routine;
    * ``actual_start``: the lower bound of the dimension at the call site
      (from the array declaration or from a slice).

    The inlined index is then::

        inlined_idx = local_idx - decln_start + 1 + actual_start - 1
                    = local_idx - decln_start + actual_start

    For a Range, the start and stop are shifted in this way. The step only
    has references to formal arguments replaced. When ``decln_start`` and
    ``actual_start`` compare equal, no shift is applied.

    :param call_node: the Call that we are inlining.
    :param formal_args: the formal arguments of the routine being called.
    :param local_idx: the array-index expression as it appears within the
        routine being inlined.
    :param decln_start: the declared lower bound of the corresponding
        array dimension inside the routine being inlined.
    :param actual_start: the lower bound of the corresponding array
        dimension at the call site.
    :param routine_node: the Routine that is being inlined.
    :param use_first_callee_and_no_arg_check: use the first potential
        callee that is found without checking for argument types. Defaults
        to False.

    :returns: PSyIR for the corresponding inlined array index.
    '''
    # Keyword arguments forwarded unchanged to every helper call below.
    common = dict(
        routine_node=routine_node,
        use_first_callee_and_no_arg_check=(
            use_first_callee_and_no_arg_check),
    )

    if isinstance(local_idx, Range):
        # Shift each end of the slice independently.
        start = self._create_inlined_idx(
            call_node, formal_args, local_idx.start,
            decln_start, actual_start, **common)
        stop = self._create_inlined_idx(
            call_node, formal_args, local_idx.stop,
            decln_start, actual_start, **common)
        stride = self._replace_formal_args_in_expr(
            local_idx.step, call_node, formal_args, **common)
        return Range.create(start.copy(), stop.copy(), stride.copy())

    inlined = self._replace_formal_args_in_expr(
        local_idx, call_node, formal_args, **common)

    # Identical starting indices mean no offset is required.
    if decln_start == actual_start:
        return inlined

    shift = self._replace_formal_args_in_expr(
        decln_start, call_node, formal_args, **common)
    offset = BinaryOperation.create(BinaryOperation.Operator.SUB,
                                    inlined.copy(), shift.copy())
    return BinaryOperation.create(BinaryOperation.Operator.ADD,
                                  offset, actual_start.copy())
def _update_actual_indices(
        self,
        actual_arg: ArrayMixin,
        local_ref: Reference,
        call_node: Call,
        formal_args: List[DataSymbol],
        routine_node: Routine,
        use_first_callee_and_no_arg_check: bool = False,
) -> List[Node]:
    '''
    Create a new list of indices for the supplied actual argument
    (ArrayMixin) by replacing any Ranges with the appropriate expressions
    from the local access in the called routine. If there are no Ranges
    then the returned list of indices just contains copies of the inputs.

    The returned list has one entry per index of ``actual_arg``. The n-th
    Range in ``actual_arg`` is paired with the n-th index (and n-th
    declared dimension) of the formal argument referenced by
    ``local_ref``.

    :param actual_arg: (part of) the actual argument to the routine.
    :param local_ref: the corresponding Reference in the called routine.
    :param call_node: the call site.
    :param formal_args: the formal arguments of the called routine.
    :param routine_node: the Routine being inlined.
    :param use_first_callee_and_no_arg_check: use the first potential
        callee that is found without checking for argument types. Defaults
        to False.

    :returns: new indices for the actual argument.
    '''
    # NOTE(review): local_indices is only bound when local_ref is an
    # ArrayMixin, yet it (and local_ref.is_full_range) is used below
    # whenever actual_arg contains a Range. Presumably callers only pass
    # array references in that case - confirm.
    if isinstance(local_ref, ArrayMixin):
        local_indices = [idx.copy() for idx in local_ref.indices]
    # Get the locally-declared shape of the formal argument in case its
    # bounds are shifted relative to the caller.
    if isinstance(local_ref.symbol.datatype, ArrayType):
        local_decln_shape = local_ref.symbol.datatype.shape
    else:
        local_decln_shape = []

    new_indices = [idx.copy() for idx in actual_arg.indices]
    # Position of the formal-argument dimension that corresponds to the
    # next Range found in the actual argument.
    local_idx_posn = 0
    for pos, idx in enumerate(new_indices[:]):
        if not isinstance(idx, Range):
            # Scalar indices of the actual argument are kept as they are.
            continue

        # Starting index of slice of actual argument.
        if actual_arg.is_lower_bound(pos):
            # Range starts at lower bound of argument so that's what
            # we store.
            actual_start = actual_arg.get_lbound_expression(pos)
        else:
            actual_start = idx.start

        # Determine the declared lower bound of this dimension of the
        # formal argument. local_shape is set on every path below before it
        # is read: to the declared ArrayBounds, or to None when no explicit
        # bounds are available.
        local_decln_start = None
        if local_decln_shape:
            if isinstance(local_decln_shape[local_idx_posn],
                          ArrayType.ArrayBounds):
                # The formal argument declaration has a shape.
                local_shape = local_decln_shape[local_idx_posn]
                local_decln_start = local_shape.lower
                if isinstance(local_decln_start, Node):
                    # Ensure any references to formal arguments within
                    # the declared array lower bound are updated.
                    local_decln_start = self._replace_formal_args_in_expr(
                        local_decln_start,
                        call_node,
                        formal_args,
                        routine_node=routine_node,
                        use_first_callee_and_no_arg_check=(
                            use_first_callee_and_no_arg_check),
                    )
            elif (local_decln_shape[local_idx_posn] ==
                  ArrayType.Extent.DEFERRED):
                # The formal argument is declared to be allocatable and
                # therefore has the same bounds as the actual argument.
                local_shape = None
                local_decln_start = actual_start
        if not local_decln_start:
            # No usable declared lower bound (no declared shape, or a
            # dimension that is neither ArrayBounds nor DEFERRED), so
            # fall back to the Fortran default lower bound of 1.
            local_shape = None
            local_decln_start = _ONE

        if local_ref.is_full_range(local_idx_posn):
            # If the local Range is for the full extent of the formal
            # argument then the actual Range is defined by that of the
            # actual argument and no change is required unless the formal
            # argument is declared as having a Range with an extent that is
            # less than that supplied. In general we're not going to know
            # that so we have to be conservative.
            if local_shape:
                # Use the declared bounds of the formal argument, shifted
                # into the index space of the actual argument.
                new = Range.create(local_shape.lower.copy(),
                                   local_shape.upper.copy())
                new_indices[pos] = self._create_inlined_idx(
                    call_node,
                    formal_args,
                    new,
                    local_decln_start,
                    actual_start,
                    routine_node=routine_node,
                    use_first_callee_and_no_arg_check=(
                        use_first_callee_and_no_arg_check),
                )
        else:
            # Otherwise, the local index expression replaces the Range.
            new_indices[pos] = self._create_inlined_idx(
                call_node,
                formal_args,
                local_indices[local_idx_posn],
                local_decln_start,
                actual_start,
                routine_node=routine_node,
                use_first_callee_and_no_arg_check=(
                    use_first_callee_and_no_arg_check),
            )
        # Each Range corresponds to one dimension of the formal argument.
        local_idx_posn += 1
    return new_indices
def _generate_formal_arg_replacement(
self,
actual_arg: Reference,
ref: Reference,
call_node: Call,
formal_args: List[DataSymbol],
routine_node: Routine,
use_first_callee_and_no_arg_check: bool = False,
) -> Reference:
'''
Called by _replace_formal_args_in_expr() whenever a reference to
the formal argument is found. This will need to be replaced with
the actual argument (accounting for possible index offsets between
the two). For example:
.. code-block:: fortran
call my_sub(my_struct%grid(:,2,:), 10)
subroutine my_sub(grid, ngrids)
...
do igrid = 1, ngrids
do jgrid = ...
do i = 1, 10
do j = 1, 10
grid(igrid, jgrid)%data(i,j) = 0.0
The assignment in the inlined code should become
.. code-block:: fortran
my_struct%grid(igrid,2,jgrid)%data(i,j) = 0.0
This routine therefore recursively combines any References to formal
arguments in the supplied Reference (including any array-index
expressions) with the corresponding Reference from the call site to
make a new Reference for use in the inlined code.
:param actual_arg: an actual argument to the routine being inlined.
:param ref: the corresponding reference to a formal argument.
:param call_node: the call site.
:param formal_args: the formal arguments of the called routine.
:param routine_node: Routine node to be inlined.
:param use_first_callee_and_no_arg_check: Just use the first possible
callee and do not check argument types. Defaults to False.
:returns: the replacement reference.
'''
actual_arg = actual_arg.copy()
# If the local reference is a simple Reference then we can just
# replace it with a copy of the actual argument, e.g.
#
# call my_sub(my_struct%data(i,j))
#
# subroutine my_sub(var)
# ...
# var = 0.0
#
# pylint: disable=unidiomatic-typecheck
if type(ref) is Reference:
return actual_arg
# Below this point we need to know if the reference is to an Array
# and if so, what are their boundaries. To do so we can use the
# Reference2ArrayRangeTrans
if type(actual_arg) is Reference:
if isinstance(actual_arg.datatype, ArrayType):
dummy = Assignment()
dummy.addchild(actual_arg)
Reference2ArrayRangeTrans().apply(actual_arg)
actual_arg = dummy.children[0]
# If the local reference is not simple but the actual argument is
# guaranteed to not be an array, e.g.:
#
# call my_sub(my_struct)
#
# subroutine my_sub(var)
# ...
# var%data(i,j) = 0.0
#
# we just need to replicate the local expression but using the
# actual argument symbol.
if type(actual_arg) is Reference:
new_ref = ref.copy()
new_ref.symbol = actual_arg.symbol
for child in new_ref.children[:]:
self._replace_formal_args_in_expr(
child,
call_node,
formal_args,
routine_node,
use_first_callee_and_no_arg_check=(
use_first_callee_and_no_arg_check),
)
return new_ref
# If we reach this point we need to create the appropriate
# [Array][Of][Structure][s]Reference so we have to collect the indices
# and members as we walk down both the actual and local references.
local_indices = None
members = []
# Actual arg could be var, var(:)%index, var(i,j)%grid(:) or
# var(j)%data(i) etc. Any Ranges must correspond to dimensions of the
# formal argument. The validate() method has already ensured that we
# do not have any indirect accesses or non-unit strides.
if isinstance(ref, ArrayMixin):
local_indices = [idx.copy() for idx in ref.indices]
# Since a Range can occur at any level of a Structure access in the
# actual argument, we walk down it and check each Member. Any Ranges
# are updated according to how that dimension is accessed by the
# reference inside the routine.
cursor = actual_arg
while True:
if isinstance(cursor, ArrayMixin):
new_indices = self._update_actual_indices(
cursor,
ref,
call_node,
formal_args,
routine_node=routine_node,
use_first_callee_and_no_arg_check=(
use_first_callee_and_no_arg_check),
)
members.append((cursor.name, new_indices))
else:
members.append(cursor.name)
if not isinstance(cursor, (StructureMember, StructureReference)):
break
cursor = cursor.member
if not actual_arg.walk(Range) and local_indices:
# There are no Ranges in the actual argument but the local
# reference is an array access.
# Create updated index expressions for that access.
new_indices = []
for idx in local_indices:
new_indices.append(
self._replace_formal_args_in_expr(