Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 93 additions & 11 deletions desc/external/gx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import shutil
import subprocess
from collections import deque

import numpy as np
from interpax import interp1d
Expand Down Expand Up @@ -467,7 +468,9 @@ def gx(
gx_input_file : str
Path to a template GX TOML input file that specifies physics parameters
(species, domain, time stepping, etc.). The ``geo_file`` path in this
template will be replaced with the generated geometry file.
template will be replaced with the generated geometry file, and
``geo_option`` will be set to ``"eik"`` to match DESC's plain-text
geometry output.
launch_cmd : list of str, optional
Command prefix for launching GX, e.g.
``["srun", "-N", "1", "--gpus-per-task=1"]`` for SLURM GPU allocation or
Expand Down Expand Up @@ -576,7 +579,7 @@ def gx(
timeout=timeout,
)
qflux = _read_gx_output(output_nc)
except (OSError, subprocess.TimeoutExpired, RuntimeError) as e:
except (OSError, RuntimeError) as e:
warnif(
True,
UserWarning,
Expand All @@ -603,6 +606,32 @@ def _write_gx_input(template_path, output_path, geo_path):
with open(template_path) as f:
data = f.read()

# DESC writes GX geometry in the plain-text eik format, so normalize the
# generated input even if the template was copied from an NC-based workflow.
data, geo_option_subs = re.subn(
r"(^\s*geo_option\s*=\s*)(['\"])[^'\"]*\2",
r'\1"eik"',
data,
count=1,
flags=re.MULTILINE,
)
if geo_option_subs == 0:
data, geo_option_subs = re.subn(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the other branch, I switched to using a loop to update the file content. This is mainly because regex operations are very hard to read, and some other developers wanted that.

@rmchurch Can you switch to the other method when resolving the merge conflicts?

r"(^\s*geo_option\s*=\s*)(\S+)",
r'\1"eik"',
data,
count=1,
flags=re.MULTILINE,
)
if geo_option_subs == 0:
data, geo_option_subs = re.subn(
r"(^\s*geo_file\s*=.*$)",
'geo_option = "eik"\n\\1',
data,
count=1,
flags=re.MULTILINE,
)

# replace any existing geo_file reference with the new path
data = re.sub(
r"(geo_file\s*=\s*)(['\"])[^'\"]*\2",
Expand All @@ -620,6 +649,40 @@ def _write_gx_input(template_path, output_path, geo_path):
f.write(data)


def _tail_file(path, max_lines=10):
    """Return the last non-blank lines of a text file.

    This helper is used while building failure messages, so it never raises
    for a missing or unreadable file: an error here would hide the failure
    that is actually being reported.

    Parameters
    ----------
    path : str or path-like
        File to read.
    max_lines : int, optional
        Maximum number of non-blank lines to keep. Default is 10.

    Returns
    -------
    tail : str
        Up to ``max_lines`` trailing non-blank lines, stripped of trailing
        whitespace and joined with newlines. Empty string if the file does
        not exist or cannot be opened/read.

    """
    try:
        # errors="replace" keeps undecodable bytes from raising while reading.
        with open(path, errors="replace") as f:
            stripped = (line.rstrip() for line in f)
            # A bounded deque retains only the most recent ``max_lines`` items.
            tail = deque((line for line in stripped if line), maxlen=max_lines)
    except OSError:
        # Missing file, directory, permission problem, etc.: report no tail.
        return ""
    return "\n".join(tail)


def _format_subprocess_failure(
    cmd, stdout_path, stderr_path, *, returncode=None, timeout=None
):
    """Assemble a multi-line description of a failed GX subprocess call.

    Parameters
    ----------
    cmd : list of str
        Command that was executed; its items are joined with spaces.
    stdout_path, stderr_path : str
        Paths of the files the subprocess output was redirected to.
    returncode : int, optional
        Exit status to report. Omitted from the message when None.
    timeout : float, optional
        Timeout in seconds to report. Omitted from the message when None.

    Returns
    -------
    message : str
        Newline-joined message containing the command, the optional exit
        status / timeout lines, the non-empty log tails, and the log paths.

    """
    lines = ["Command failed: " + " ".join(cmd)]

    if returncode is not None:
        lines.append(f"exit status: {returncode}")
    if timeout is not None:
        lines.append(f"timed out after {timeout:.0f} seconds")

    # Include a log excerpt only when the corresponding tail is non-empty.
    for header, log_path in (
        ("stdout tail:", stdout_path),
        ("stderr tail:", stderr_path),
    ):
        excerpt = _tail_file(log_path)
        if excerpt:
            lines.append(f"{header}\n{excerpt}")

    lines.append(f"logs: stdout={stdout_path}, stderr={stderr_path}")
    return "\n".join(lines)


def _run_gx(dir, exec_path, input_path=None, launch_cmd=None, gx_gpu=None, timeout=300):
"""Run the GX executable.

Expand Down Expand Up @@ -666,15 +729,34 @@ def _run_gx(dir, exec_path, input_path=None, launch_cmd=None, gx_gpu=None, timeo
env["CUDA_VISIBLE_DEVICES"] = str(gx_gpu)

with open(stdout_path, "w") as fout, open(stderr_path, "w") as ferr:
subprocess.run(
cmd,
cwd=dir,
timeout=timeout,
stdout=fout,
stderr=ferr,
check=True,
env=env,
)
try:
subprocess.run(
cmd,
cwd=dir,
timeout=timeout,
stdout=fout,
stderr=ferr,
check=True,
env=env,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
_format_subprocess_failure(
cmd,
stdout_path,
stderr_path,
returncode=e.returncode,
)
) from e
except subprocess.TimeoutExpired as e:
raise RuntimeError(
_format_subprocess_failure(
cmd,
stdout_path,
stderr_path,
timeout=timeout,
)
) from e

# find the output netcdf file
out_files = [f for f in os.listdir(dir) if f.endswith(".out.nc")]
Expand Down
46 changes: 46 additions & 0 deletions tests/test_external_gx.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually don't require writing tests for the desc.external module. But these look good to me. We can decide to keep them or not in the actual PR for the yge/gx branch.

Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Tests for the GX external objective helpers."""

from desc.external.gx import _run_gx, _write_gx_input


def test_write_gx_input_forces_eik_geometry(tmp_path):
    """The written GX input must select the "eik" geometry and the new file."""
    template = tmp_path / "template.in"
    generated = tmp_path / "gx.in"
    geometry = tmp_path / "gx_geo.out"

    # Template that starts out with a different geo_option and a placeholder.
    template.write_text(
        '[Geometry]\ngeo_option = "nc"\ngeo_file = "placeholder.nc"\n'
    )

    _write_gx_input(str(template), str(generated), str(geometry))

    written = generated.read_text()
    assert 'geo_option = "eik"' in written
    assert f"geo_file = '{geometry}'" in written


def test_run_gx_reports_child_logs(tmp_path):
    """A failing GX executable must produce an error quoting its logs."""
    # Stand-in executable: writes to both streams, then exits with status 2.
    fake_gx = tmp_path / "fake_gx.sh"
    script_lines = [
        "#!/bin/sh",
        'echo "starting gx"',
        'echo "fatal child error" >&2',
        "exit 2",
    ]
    fake_gx.write_text("\n".join(script_lines) + "\n")
    fake_gx.chmod(0o755)

    message = None
    try:
        _run_gx(str(tmp_path), str(fake_gx))
    except RuntimeError as err:
        message = str(err)
    if message is None:
        raise AssertionError("_run_gx should have raised RuntimeError")

    expected_fragments = (
        "exit status: 2",
        "starting gx",
        "fatal child error",
        str(tmp_path / "stdout.gx"),
        str(tmp_path / "stderr.gx"),
    )
    for fragment in expected_fragments:
        assert fragment in message
Loading