Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):

s0 = time.perf_counter()

# Disable profiler for the first two runs to avoid duplicate uploads
original_enable_profiler = config.enable_profiler if "enable_profiler" in config.get_keys() else False
config.get_keys()["enable_profiler"] = False

# Using global_batch_size_to_train_on so not to create more config variables
prompt = [config.prompt] * config.global_batch_size_to_train_on
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
Expand Down Expand Up @@ -321,6 +325,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log("\n".join(summary))

s0 = time.perf_counter()
# Restore original profiler setting for the profiling run
config.get_keys()["enable_profiler"] = original_enable_profiler
if max_utils.profiler_enabled(config):
# Injecting user requested XLA tracing flags
xla_flags = os.environ.get("XLA_FLAGS", "")
Expand Down
36 changes: 18 additions & 18 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,31 +154,31 @@ def stop(self):
if _jax_profiler_enabled(self.config):
jax.profiler.stop_trace()

trace_dir = self.config.tensorboard_dir
if trace_dir.startswith("gs://"):
local_dir = os.path.join("/tmp/profiler_traces", self.config.run_name)
if os.path.exists(local_dir):
max_logging.log(f"Uploading profiler traces from {local_dir} to {trace_dir}...")
client = storage.Client()
bucket_name, prefix = parse_gcs_bucket_and_prefix(trace_dir)
bucket = client.bucket(bucket_name)

for root, _, files in os.walk(local_dir):
for file in files:
local_file = os.path.join(root, file)
rel_path = os.path.relpath(local_file, local_dir)
blob_name = os.path.join(prefix, rel_path)
blob = bucket.blob(blob_name)
blob.upload_from_filename(local_file)
max_logging.log(f"Uploaded {local_file} to gs://{bucket_name}/{blob_name}")

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()

trace_dir = self.config.tensorboard_dir
if trace_dir.startswith("gs://"):
local_dir = os.path.join("/tmp/profiler_traces", self.config.run_name)
if os.path.exists(local_dir):
max_logging.log(f"Uploading profiler traces from {local_dir} to {trace_dir}...")
client = storage.Client()
bucket_name, prefix = parse_gcs_bucket_and_prefix(trace_dir)
bucket = client.bucket(bucket_name)

for root, _, files in os.walk(local_dir):
for file in files:
local_file = os.path.join(root, file)
rel_path = os.path.relpath(local_file, local_dir)
blob_name = os.path.join(prefix, rel_path)
blob = bucket.blob(blob_name)
blob.upload_from_filename(local_file)
max_logging.log(f"Uploaded {local_file} to gs://{bucket_name}/{blob_name}")


def initialize_summary_writer(config):
return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def scan_body(carry, t):

if config and max_utils.profiler_enabled(config) and step == last_profiling_step:
if profiler:
latents.block_until_ready()
profiler.stop()

return latents
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,5 +712,6 @@ def scan_body(carry, t):

if config and max_utils.profiler_enabled(config) and step == last_profiling_step:
if profiler:
latents.block_until_ready()
profiler.stop()
return latents
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,5 +478,6 @@ def scan_body(carry, t):

if config and max_utils.profiler_enabled(config) and step == last_profiling_step:
if profiler:
latents.block_until_ready()
profiler.stop()
return latents
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,5 +820,6 @@ def scan_body(carry, t):

if config and max_utils.profiler_enabled(config) and step == last_profiling_step:
if profiler:
latents.block_until_ready()
profiler.stop()
return latents
79 changes: 79 additions & 0 deletions src/maxdiffusion/tests/profiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

import unittest
from unittest.mock import patch
import os

real_exists = os.path.exists
from maxdiffusion import max_utils


Expand Down Expand Up @@ -80,6 +83,82 @@ def test_profiler_disabled(self, mock_ml_run):
max_utils.ensure_machinelearning_job_runs(config)
mock_ml_run.assert_not_called()

@patch("maxdiffusion.max_utils.storage.Client")
@patch("maxdiffusion.max_utils.os.path.exists")
@patch("maxdiffusion.max_utils.os.walk", return_value=[("/tmp/profiler_traces/test_run", [], ["file1.trace"])])
@patch("jax.profiler.start_trace")
@patch("jax.profiler.stop_trace")
@patch("jax.process_index", return_value=0)
def test_jax_profiler_manual_gcs(
self,
mock_process_index,
mock_stop_trace,
mock_start_trace,
mock_os_walk,
mock_os_exists,
mock_storage_client,
):
"""Tests manual start/stop with GCS upload."""
mock_os_exists.side_effect = lambda path: True if path == "/tmp/profiler_traces/test_run" else real_exists(path)
config = MockConfig(
enable_ml_diagnostics=False,
enable_profiler=True,
tensorboard_dir="gs://test-bucket/tensorboard",
run_name="test_run",
)

profiler = max_utils.Profiler(config)
profiler.start()
mock_start_trace.assert_called_once()

profiler.stop()
mock_stop_trace.assert_called_once()

# Verify GCS upload was attempted
mock_storage_client.assert_called_once()
mock_bucket = mock_storage_client.return_value.bucket
mock_bucket.assert_called_once_with("test-bucket")
mock_blob = mock_bucket.return_value.blob
mock_blob.assert_called_once_with("tensorboard/file1.trace")
mock_blob.return_value.upload_from_filename.assert_called_once_with("/tmp/profiler_traces/test_run/file1.trace")

@patch("maxdiffusion.max_utils.storage.Client")
@patch("maxdiffusion.max_utils.os.path.exists")
@patch("maxdiffusion.max_utils.os.walk", return_value=[("/tmp/profiler_traces/test_run", [], ["file1.trace"])])
@patch("jax.profiler.start_trace")
@patch("jax.profiler.stop_trace")
@patch("jax.process_index", return_value=0)
def test_jax_profiler_context_gcs(
self,
mock_process_index,
mock_stop_trace,
mock_start_trace,
mock_os_walk,
mock_os_exists,
mock_storage_client,
):
"""Tests context manager with GCS upload."""
mock_os_exists.side_effect = lambda path: True if path == "/tmp/profiler_traces/test_run" else real_exists(path)
config = MockConfig(
enable_ml_diagnostics=False,
enable_profiler=True,
tensorboard_dir="gs://test-bucket/tensorboard",
run_name="test_run",
)

with max_utils.Profiler(config):
mock_start_trace.assert_called_once()

mock_stop_trace.assert_called_once()

# Verify GCS upload was attempted
mock_storage_client.assert_called_once()
mock_bucket = mock_storage_client.return_value.bucket
mock_bucket.assert_called_once_with("test-bucket")
mock_blob = mock_bucket.return_value.blob
mock_blob.assert_called_once_with("tensorboard/file1.trace")
mock_blob.return_value.upload_from_filename.assert_called_once_with("/tmp/profiler_traces/test_run/file1.trace")


if __name__ == "__main__":
unittest.main()
Loading