4343import pytest
4444from psyclone .configuration import Config
4545from psyclone .domain .gocean .transformations import GOceanLoopFuseTrans
46+ from psyclone .domain .common .transformations import KernelModuleInlineTrans
4647from psyclone .errors import GenerationError
4748from psyclone .gocean1p0 import GOKern
4849from psyclone .parse import ModuleManager
@@ -103,7 +104,7 @@ def test_loop_fuse_error():
103104 assert "Both nodes must be of the same GOLoop class." in str (err .value )
104105
105106
106- def test_omp_parallel_loop (tmpdir , fortran_writer ):
107+ def test_omp_paralleldo_loop (tmpdir , fortran_writer ):
107108 '''Test that we can generate an OMP PARALLEL DO correctly,
108109 independent of whether or not we are generating constant loop bounds '''
109110 psy , invoke = get_invoke ("single_invoke_three_kernels.f90" , API , idx = 0 ,
@@ -176,7 +177,7 @@ def test_omp_region_with_single_loop(tmpdir):
176177 within_omp_region = False
177178 call_count = 0
178179 for line in gen .split ('\n ' ):
179- if '!$omp parallel' in line :
180+ if '!$omp parallel default(shared) private(i,j) ' in line :
180181 within_omp_region = True
181182 if '!$omp end parallel' in line :
182183 within_omp_region = False
@@ -192,7 +193,7 @@ def test_omp_region_with_single_loop(tmpdir):
192193 within_omp_region = False
193194 call_count = 0
194195 for line in gen .split ('\n ' ):
195- if '!$omp parallel' in line :
196+ if '!$omp parallel default(shared) private(i,j) ' in line :
196197 within_omp_region = True
197198 if '!$omp end parallel' in line :
198199 within_omp_region = False
@@ -222,7 +223,7 @@ def test_omp_region_with_slice(tmpdir):
222223 within_omp_region = False
223224 call_count = 0
224225 for line in gen .split ('\n ' ):
225- if '!$omp parallel' in line :
226+ if '!$omp parallel default(shared) private(i,j) ' in line :
226227 within_omp_region = True
227228 if '!$omp end parallel' in line :
228229 within_omp_region = False
@@ -288,7 +289,7 @@ def test_omp_region_no_slice(tmpdir):
288289 within_omp_region = False
289290 call_count = 0
290291 for line in gen .split ('\n ' ):
291- if '!$omp parallel' in line :
292+ if '!$omp parallel default(shared) private(i,j) ' in line :
292293 within_omp_region = True
293294 if '!$omp end parallel' in line :
294295 within_omp_region = False
@@ -319,7 +320,7 @@ def test_omp_region_no_slice_const_bounds(tmpdir):
319320 within_omp_region = False
320321 call_count = 0
321322 for line in gen .split ('\n ' ):
322- if '!$omp parallel' in line :
323+ if '!$omp parallel default(shared) private(i,j) ' in line :
323324 within_omp_region = True
324325 if '!$omp end parallel' in line :
325326 within_omp_region = False
@@ -452,6 +453,14 @@ def test_omp_region_retains_kernel_order3(tmpdir):
452453
453454 # Kernels should be in order {compute_cu, compute_cv, time_smooth}
454455 assert cu_idx < cv_idx < ts_idx
456+
457+ # Check that the two directive are different statements in above the
458+ # second loop (iterates over cv_fld internal) and that the private
459+ # clause (now on the parallel directive) only has i and j.
460+ assert ("!$omp parallel default(shared) private(i,j)\n "
461+ " !$omp do schedule(static)\n "
462+ " do j = cv_fld%internal%ystart" in gen )
463+
455464 assert GOceanBuild (tmpdir ).code_compiles (psy )
456465
457466
@@ -482,7 +491,7 @@ def test_omp_region_before_loops_trans(tmpdir):
482491 omp_region_idx = - 1
483492 omp_do_idx = - 1
484493 for idx , line in enumerate (gen .split ('\n ' )):
485- if '!$omp parallel' in line :
494+ if '!$omp parallel default(shared) private(i,j) ' in line :
486495 omp_region_idx = idx
487496 if '!$omp do' in line :
488497 omp_do_idx = idx
@@ -502,6 +511,11 @@ def test_omp_region_after_loops_trans(tmpdir):
502511 dist_mem = False )
503512 schedule = invoke .schedule
504513
514+ # We test with inlining because in the past we had an error when
515+ # producing the clauses if the calls were inlined.
516+ for kern in schedule .kernels ():
517+ KernelModuleInlineTrans ().apply (kern )
518+
505519 # Put an OpenMP do directive around each loop contained
506520 # in the schedule
507521 ompl = GOceanOMPLoopTrans ()
@@ -519,7 +533,7 @@ def test_omp_region_after_loops_trans(tmpdir):
519533 omp_region_idx = - 1
520534 omp_do_idx = - 1
521535 for idx , line in enumerate (gen .split ('\n ' )):
522- if '!$omp parallel' in line :
536+ if '!$omp parallel default(shared) private(i,j) ' in line :
523537 omp_region_idx = idx
524538 if '!$omp do' in line :
525539 omp_do_idx = idx
0 commit comments