
feat: Added MLX GPU Acceleration to Tensor, MultiVector, and Topology crate #433

Merged
marvin-hansen merged 55 commits into deepcausality-rs:main from marvin-hansen:main
Dec 31, 2025

Conversation

@marvin-hansen (Member) commented Dec 31, 2025

User description

Describe your changes

This PR adds MLX hardware acceleration to the causal tensor, causal multivector, and causal topology crates.
Furthermore, the Multi-Backend implementation has been prepared for a future NVIDIA CUDA backend.

Issue ticket number and link

Closes #432
Closes #428

Code checklist before requesting a review

  • I have signed the DCO.
  • All tests pass when running make test.
  • No errors or security vulnerabilities are reported by make check.

For details on make, please see BUILD.md

Note: The CI runs all of the above, and fixing issues before they reach CI speeds
up the review and merge process. Thank you.


PR Type

Enhancement, Tests


Description

  • MLX GPU Acceleration: Implements TensorBackend and LinearAlgebraBackend traits for MLX GPU acceleration with comprehensive tensor operations (create, reshape, permute, slice, stack, einsum) and linear algebra operations (matmul, QR, SVD, inverse, Cholesky)

  • Tensor Architecture Refactoring: Refactors tensor implementation from CausalTensor<T> to InternalCpuTensor<T> with simplified trait bounds and improved API consistency across CPU and backend implementations

  • Multivector Enhancements: Adds Higher-Kinded Type (HKT) operations for CausalMultiField enabling functional programming patterns (Functor, Applicative, Monad, Comonad)

  • Physics Implementations: Introduces Particle Data Group (PDG) database, Lund string fragmentation kernel for QCD hadronization, and Regge geometry metric signature computation

  • Comprehensive Test Coverage: Adds extensive test suites for multivector algebra operations, Clifford algebra products, grade projections, differential operators, metric operations, and arithmetic operations

  • Example Applications: Includes SCUBA diving decompression planner demonstrating physics simulation and modular multi-physics pipeline using Causal Monad pattern

  • Backend Tensor Operations: Implements arithmetic trait operations (Add, Sub, Mul, Div) and einsum AST execution for backend tensors
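To make the multi-backend design above concrete, here is a minimal self-contained sketch of the pattern: a backend trait with associated tensor type, one CPU implementation, and generic user code written once against the trait. The names (TensorBackend, CpuBackend) echo the PR summary, but the signatures below are illustrative assumptions, not the crate's actual API.

```rust
// Hypothetical sketch of the multi-backend pattern described above.
// Trait and type names mirror the PR summary; signatures are assumptions.
trait TensorBackend {
    type Tensor;
    fn create(data: &[f32], shape: &[usize]) -> Self::Tensor;
    fn add(a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
    fn to_vec(t: &Self::Tensor) -> Vec<f32>;
}

struct CpuBackend;

struct CpuTensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

impl TensorBackend for CpuBackend {
    type Tensor = CpuTensor;
    fn create(data: &[f32], shape: &[usize]) -> CpuTensor {
        CpuTensor { data: data.to_vec(), shape: shape.to_vec() }
    }
    fn add(a: &CpuTensor, b: &CpuTensor) -> CpuTensor {
        assert_eq!(a.shape, b.shape); // elementwise add requires equal shapes
        let data = a.data.iter().zip(&b.data).map(|(x, y)| x + y).collect();
        CpuTensor { data, shape: a.shape.clone() }
    }
    fn to_vec(t: &CpuTensor) -> Vec<f32> {
        t.data.clone()
    }
}

// User code is written once against the trait; a GPU backend (e.g. MLX,
// or a future CUDA backend) can be swapped in behind the same interface.
fn sum_elementwise<B: TensorBackend>(a: &B::Tensor, b: &B::Tensor) -> Vec<f32> {
    B::to_vec(&B::add(a, b))
}

fn main() {
    let a = CpuBackend::create(&[1.0, 2.0], &[2]);
    let b = CpuBackend::create(&[3.0, 4.0], &[2]);
    println!("{:?}", sum_elementwise::<CpuBackend>(&a, &b));
}
```

The CUDA preparation mentioned above would then amount to adding another implementor of the same trait.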


Diagram Walkthrough

flowchart LR
  A["CausalTensor<br/>Legacy API"] -->|"Refactor"| B["InternalCpuTensor<br/>Simplified Bounds"]
  B -->|"Implement"| C["TensorBackend<br/>Trait"]
  C -->|"CPU"| D["CpuBackend<br/>Tensor"]
  C -->|"GPU"| E["MLXBackend<br/>Tensor"]
  E -->|"Supports"| F["Linear Algebra<br/>Operations"]
  G["CausalMultiField"] -->|"Add HKT"| H["Functional<br/>Programming"]
  I["Physics<br/>Kernels"] -->|"New"| J["PDG Database<br/>Lund Fragmentation<br/>Regge Geometry"]
  K["Test Suites"] -->|"Cover"| L["Algebra, Products<br/>Grades, Differential<br/>Metrics"]

File Walkthrough

Relevant files
Enhancement
10 files
ein_sum_impl.rs
Refactor EinSum operations to InternalCpuTensor with improved
validation

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs

  • Refactored implementation from CausalTensor to InternalCpuTensor with
    simplified trait bounds
  • Replaced recursive execute_ein_sum calls with direct
    EinSumOp::TensorSource pattern matching in operand extraction
  • Expanded contract() method with comprehensive validation and
    multi-dimensional index iteration logic
  • Updated method visibility from pub(in crate::types::causal_tensor) to
    pub(crate) and pub(in crate::types::cpu_tensor)
+266/-244
mlx_backend_tensor.rs
MLX GPU Backend Implementation for Tensor Operations         

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs

  • Implements TensorBackend trait for MLX GPU acceleration with 588 lines
    of new code
  • Provides core tensor operations: create, reshape, permute, slice,
    stack, ravel
  • Implements arithmetic operations: add, sub, mul, div, sum, max, mean
  • Implements advanced operations: einsum with AST evaluation,
    broadcasting, arg_sort
  • Includes fallback to CPU for operations not natively supported by MLX
    (slice, arg_sort, shifted_view)
+588/-0 
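The CPU-fallback strategy noted for slice, arg_sort, and shifted_view follows a common pattern: when a kernel is not natively available on the GPU, the data round-trips through host memory. A minimal sketch of that shape (function names are mine; the device copy is simulated with a plain clone):

```rust
// Sketch of the "fall back to CPU" pattern the MLX backend reportedly uses
// for ops like arg_sort. Names are illustrative, not the crate's API.
fn arg_sort_cpu(data: &[f32]) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..data.len()).collect();
    // Sort indices by the values they point at (NaN-free input assumed).
    idx.sort_by(|&a, &b| data[a].partial_cmp(&data[b]).unwrap());
    idx
}

// An op with no native GPU kernel round-trips through host memory:
// download, compute on CPU, then (optionally) upload the result again.
fn arg_sort_with_fallback(device_data: &[f32]) -> Vec<usize> {
    let host: Vec<f32> = device_data.to_vec(); // stands in for a GPU->CPU copy
    arg_sort_cpu(&host)
}

fn main() {
    println!("{:?}", arg_sort_with_fallback(&[3.0, 1.0, 2.0]));
}
```

The trade-off is correctness first, speed later: such ops work on every backend today and can be replaced with native kernels when MLX grows them.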
mod.rs
Refactor CPU Tensor API to Use InternalCpuTensor Type       

deep_causality_tensor/src/types/cpu_tensor/api/mod.rs

  • Updates trait implementations to use InternalCpuTensor instead of
    CausalTensor
  • Adds new trait methods: qr(), svd(), and stack() for linear algebra
    operations
  • Simplifies trait bounds by using TensorData instead of multiple
    individual bounds
  • Updates documentation examples to reflect the new type naming
+82/-51 
ops.rs
Backend Tensor Arithmetic and Operations Implementation   

deep_causality_tensor/src/types/backend_tensor/ops.rs

  • Implements arithmetic trait operations for BackendTensor: Add, Sub,
    Mul, Div
  • Provides both owned and reference variants for all arithmetic
    operations
  • Implements assignment arithmetic: AddAssign, SubAssign, MulAssign,
    DivAssign
  • Adds scalar arithmetic operations via macros for multiple numeric
    types
  • Implements einsum AST execution with recursive mapping
+444/-0 
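The "scalar arithmetic operations via macros" bullet describes generating one impl per numeric type from a single macro. A self-contained sketch of that technique, using a simplified non-generic tensor stand-in (the real BackendTensor is generic over a backend, which is omitted here):

```rust
use std::ops::AddAssign;

// Minimal stand-in for the wrapper type; illustrative only.
struct Tensor {
    data: Vec<f64>,
}

// Macro generating scalar assignment ops for several numeric types,
// echoing the macro-based approach the summary describes.
macro_rules! impl_scalar_add_assign {
    ($($t:ty),*) => {$(
        impl AddAssign<$t> for Tensor {
            fn add_assign(&mut self, rhs: $t) {
                let r = rhs as f64;
                for x in self.data.iter_mut() {
                    *x += r; // broadcast the scalar over every element
                }
            }
        }
    )*};
}

impl_scalar_add_assign!(f64, f32, i32);

fn main() {
    let mut t = Tensor { data: vec![1.0, 2.0] };
    t += 1i32; // works for every type listed in the macro invocation
    println!("{:?}", t.data);
}
```

One macro invocation thus covers the whole family of Add/Sub/Mul/Div scalar variants per type without hand-written duplication.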
main.rs
Modular Multi-Physics Pipeline via Causal Monad Pattern   

examples/physics_examples/multi_physics_pipeline/main.rs

  • Refactored pipeline from monolithic structure to modular stages using
    the Causal Monad pattern
  • Replaced single hadronization() call with
    lund_string_fragmentation_kernel() for realistic QCD string
    fragmentation
  • Decomposed pipeline into 5 independent stages: field-to-partons, Lund
    fragmentation, thermalization, and quantum detection
  • Enhanced documentation with physics concepts and design patterns for
    modular composition
+306/-82
mlx_backend_linear_algebra.rs
MLX GPU Linear Algebra Backend Implementation                       

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs

  • Implements LinearAlgebraBackend trait for MLX GPU backend with
    operations: matmul, qr, svd, inverse, cholesky_decomposition,
    solve_least_squares_cholsky, tensor_product
  • Adds explicit 4x4 matrix inversion using Jacobi block decomposition
    for GPU compatibility (fallback to CPU for non-4x4 cases)
  • Provides helper functions for 2x2 block operations and matrix
    reshaping on MLX arrays
+257/-0 
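The 4x4 inversion via 2x2 blocks mentioned above is, in textbook form, Schur-complement block inversion; the PR's "Jacobi block decomposition" may differ in detail. A self-contained sketch of the standard scheme:

```rust
// Standard Schur-complement block inversion of a 4x4 matrix via 2x2 blocks.
// A sketch of the technique, not the crate's MLX implementation.
type M2 = [[f64; 2]; 2];

fn m2_mul(a: &M2, b: &M2) -> M2 {
    [
        [a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]],
        [a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]],
    ]
}
fn m2_sub(a: &M2, b: &M2) -> M2 {
    [[a[0][0] - b[0][0], a[0][1] - b[0][1]], [a[1][0] - b[1][0], a[1][1] - b[1][1]]]
}
fn m2_add(a: &M2, b: &M2) -> M2 {
    [[a[0][0] + b[0][0], a[0][1] + b[0][1]], [a[1][0] + b[1][0], a[1][1] + b[1][1]]]
}
fn m2_neg(a: &M2) -> M2 {
    [[-a[0][0], -a[0][1]], [-a[1][0], -a[1][1]]]
}
fn m2_inv(m: &M2) -> M2 {
    let det = m[0][0] * m[1][1] - m[0][1] * m[1][0];
    [[m[1][1] / det, -m[0][1] / det], [-m[1][0] / det, m[0][0] / det]]
}

/// M = [[A, B], [C, D]]; with Schur complement S = D - C A^-1 B the inverse is
/// [[A^-1 + A^-1 B S^-1 C A^-1, -A^-1 B S^-1], [-S^-1 C A^-1, S^-1]].
fn block_inverse_4x4(m: &[[f64; 4]; 4]) -> [[f64; 4]; 4] {
    let blk = |r: usize, c: usize| -> M2 {
        [[m[r][c], m[r][c + 1]], [m[r + 1][c], m[r + 1][c + 1]]]
    };
    let (a, b, c, d) = (blk(0, 0), blk(0, 2), blk(2, 0), blk(2, 2));

    let ai = m2_inv(&a);
    let s = m2_sub(&d, &m2_mul(&c, &m2_mul(&ai, &b)));
    let si = m2_inv(&s);

    let tl = m2_add(&ai, &m2_mul(&m2_mul(&ai, &b), &m2_mul(&si, &m2_mul(&c, &ai))));
    let tr = m2_neg(&m2_mul(&ai, &m2_mul(&b, &si)));
    let bl = m2_neg(&m2_mul(&si, &m2_mul(&c, &ai)));
    let br = si;

    let mut out = [[0.0; 4]; 4];
    for i in 0..2 {
        for j in 0..2 {
            out[i][j] = tl[i][j];
            out[i][j + 2] = tr[i][j];
            out[i + 2][j] = bl[i][j];
            out[i + 2][j + 2] = br[i][j];
        }
    }
    out
}

fn main() {
    let m = [
        [4.0, 0.0, 1.0, 0.0],
        [0.0, 4.0, 0.0, 1.0],
        [1.0, 0.0, 3.0, 0.0],
        [0.0, 1.0, 0.0, 3.0],
    ];
    println!("inv[0][0] = {}", block_inverse_4x4(&m)[0][0]); // 3/11
}
```

The requirement is that A (and S) be invertible, which is why a CPU fallback for the general case, as the summary notes, is sensible.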
mod.rs
Higher-Kinded Type Operations for CausalMultiField             

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs

  • Implements Higher-Kinded Type (HKT) operations for CausalMultiField:
    fmap, pure, apply, bind, extract, extend
  • Provides witness type CausalMultiFieldWitness to work around HKT trait
    incompatibilities with TensorData constraints
  • Includes type aliases for CPU and MLX backends
  • Enables functional programming patterns (Functor, Applicative, Monad,
    Comonad) on fields
+338/-0 
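The witness-type workaround mentioned above is a standard Rust trick for emulating higher-kinded types: a zero-sized marker stands in for the type constructor, and a generic associated type recovers the applied type. A minimal sketch with a Vec witness (the crate's CausalMultiFieldWitness and its trait bounds will differ):

```rust
// Minimal witness-type HKT encoding (requires GATs, stable since Rust 1.65).
// Illustrative only; the crate's Functor/Monad traits are richer.
trait Functor {
    type Wrapped<T>;
    fn fmap<A, B>(fa: Self::Wrapped<A>, f: impl Fn(A) -> B) -> Self::Wrapped<B>;
}

// Witness for Vec: a zero-sized type standing in for the `Vec` constructor.
struct VecWitness;

impl Functor for VecWitness {
    type Wrapped<T> = Vec<T>;
    fn fmap<A, B>(fa: Vec<A>, f: impl Fn(A) -> B) -> Vec<B> {
        fa.into_iter().map(f).collect()
    }
}

fn main() {
    // fmap lifts a plain function into the wrapped context.
    let doubled = VecWitness::fmap(vec![1, 2, 3], |x| x * 2);
    println!("{:?}", doubled);
}
```

pure, apply, bind, extract, and extend extend the same encoding to Applicative, Monad, and Comonad; the witness also lets the impls carry extra bounds (e.g. TensorData) that a direct HKT trait could not express.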
pdg.rs
Particle Data Group (PDG) Database Implementation               

deep_causality_physics/src/nuclear/pdg.rs

  • Implements Particle Data Group (PDG) database with ParticleData struct
    containing PDG ID, mass, charge, spin, and name
  • Provides comprehensive particle catalog: pseudoscalar mesons, vector
    mesons, light/strange/charmed baryons, Delta and Omega baryons
  • Includes lookup functions: pdg_lookup(), pdg_mass(), and quark
    constituent mass constants
  • Adds unit tests for proton, pion, and lookup failure cases
+221/-0 
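A PDG database of this shape typically boils down to a static table plus an Option-returning lookup. A hypothetical sketch following the field list in the summary (masses in GeV are standard PDG figures but should be checked against the crate's actual data; function and field names are assumptions):

```rust
// Hypothetical shape of a PDG lookup table; illustrative, not the crate's code.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ParticleData {
    pdg_id: i32,
    mass: f64,   // GeV
    charge: f64, // units of e
    name: &'static str,
}

const PARTICLES: &[ParticleData] = &[
    ParticleData { pdg_id: 2212, mass: 0.938_272, charge: 1.0, name: "proton" },
    ParticleData { pdg_id: 211, mass: 0.139_570, charge: 1.0, name: "pi+" },
    ParticleData { pdg_id: 111, mass: 0.134_977, charge: 0.0, name: "pi0" },
];

// Lookup failure is a normal outcome (unknown ID), hence Option, not panic.
fn pdg_lookup(id: i32) -> Option<&'static ParticleData> {
    PARTICLES.iter().find(|p| p.pdg_id == id)
}

fn main() {
    let proton = pdg_lookup(2212).expect("proton should exist");
    println!("{} mass = {} GeV", proton.name, proton.mass);
    assert!(pdg_lookup(999_999).is_none()); // the lookup-failure test case
}
```

The unit tests listed above (proton, pion, lookup failure) map directly onto this structure.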
mod.rs
Regge Geometry Metric Signature Computation                           

deep_causality_topology/src/types/regge_geometry/mod.rs

  • Refactored metric_at() to compute metric signature from Cayley-Menger
    Gram matrix using edge lengths
  • Implements compute_signature() function that builds Gram matrix and
    computes eigenvalues via Jacobi algorithm
  • Adds euclidean_metric_at() fast-path for known Euclidean geometries
  • Includes helper compute_eigenvalues() using Jacobi rotation method for
    small symmetric matrices
  • Adds unit tests for equilateral triangle and regular tetrahedron
    signatures
+243/-16
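compute_eigenvalues() is said to use Jacobi rotations, which for small symmetric matrices is the textbook method: repeatedly rotate away the largest off-diagonal entry until the matrix is (numerically) diagonal, then read eigenvalues off the diagonal. A sketch of that scheme, not the crate's exact code:

```rust
// Jacobi eigenvalue iteration for small symmetric matrices (a sketch).
// The signature counting in compute_signature() would then just tally the
// signs of the returned eigenvalues.
fn jacobi_eigenvalues(mut a: Vec<Vec<f64>>) -> Vec<f64> {
    let n = a.len();
    for _ in 0..100 {
        // The largest off-diagonal element decides the next rotation.
        let (mut p, mut q, mut max) = (0, 1, 0.0f64);
        for i in 0..n {
            for j in (i + 1)..n {
                if a[i][j].abs() > max {
                    max = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max < 1e-12 {
            break; // off-diagonal mass is gone: the diagonal holds eigenvalues
        }
        // Rotation angle that zeroes a[p][q]: tan(2*theta) = 2*a_pq / (a_pp - a_qq).
        let theta = 0.5 * (2.0 * a[p][q]).atan2(a[p][p] - a[q][q]);
        let (s, c) = theta.sin_cos();
        for k in 0..n {
            let (akp, akq) = (a[k][p], a[k][q]);
            a[k][p] = c * akp + s * akq; // A <- A * J (column update)
            a[k][q] = -s * akp + c * akq;
        }
        for k in 0..n {
            let (apk, aqk) = (a[p][k], a[q][k]);
            a[p][k] = c * apk + s * aqk; // A <- J^T * A (row update)
            a[q][k] = -s * apk + c * aqk;
        }
    }
    (0..n).map(|i| a[i][i]).collect()
}

fn main() {
    // Symmetric 2x2 with known eigenvalues 3 and 1.
    let eig = jacobi_eigenvalues(vec![vec![2.0, 1.0], vec![1.0, 2.0]]);
    println!("{:?}", eig);
}
```

For the Gram matrices of a simplex this converges in a handful of rotations, which is why Jacobi is a good fit for small fixed-size geometry problems.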
fragmentation.rs
Lund String Fragmentation Kernel Implementation                   

deep_causality_physics/src/nuclear/lund/fragmentation.rs

  • Implements lund_string_fragmentation_kernel() for QCD string
    fragmentation into hadrons
  • Iterative fragmentation loop: samples z from Lund function, generates
    transverse momentum, selects quark flavor, forms mesons
  • Enforces conservation laws: 4-momentum, electric charge, baryon
    number, strangeness
  • Includes comprehensive unit tests: basic fragmentation, momentum
    conservation, multiplicity scaling with energy
+266/-0 
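The "samples z from Lund function" step refers to the Lund symmetric fragmentation function f(z) proportional to (1/z)(1-z)^a exp(-b mT^2/z). A sketch of sampling it by rejection, with illustrative a, b values and a tiny built-in PRNG so the example has no dependencies (the crate's kernel may sample differently):

```rust
// Sketch of sampling the longitudinal fraction z from the Lund symmetric
// fragmentation function. Parameters and the flat envelope are illustrative.
struct Lcg(u64);

impl Lcg {
    // Tiny deterministic PRNG to keep the sketch dependency-free.
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 11) as f64) / ((1u64 << 53) as f64)
    }
}

// f(z) ~ (1/z) * (1-z)^a * exp(-b * mT^2 / z), defined on (0, 1).
fn lund_f(z: f64, a: f64, b: f64, mt2: f64) -> f64 {
    (1.0 / z) * (1.0 - z).powf(a) * (-b * mt2 / z).exp()
}

fn sample_z(rng: &mut Lcg, a: f64, b: f64, mt2: f64) -> f64 {
    // Rejection sampling against a flat envelope; f_max found by a coarse scan.
    let f_max = (1..1000)
        .map(|i| lund_f(i as f64 / 1000.0, a, b, mt2))
        .fold(0.0f64, f64::max);
    loop {
        let z = rng.next_f64().max(1e-6);
        if rng.next_f64() * f_max <= lund_f(z, a, b, mt2) {
            return z;
        }
    }
}

fn main() {
    let mut rng = Lcg(42);
    let z = sample_z(&mut rng, 0.68, 0.98, 0.1); // typical a, b; mT^2 in GeV^2
    println!("sampled z = {z}");
}
```

Each fragmentation step takes a fraction z of the remaining light-cone momentum for the new hadron, which is why the loop in the kernel terminates once the string's invariant mass drops below a hadron threshold.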
Documentation
1 files
main.rs
Add SCUBA decompression planner example with physics simulation

examples/medicine_examples/diving_decompression/main.rs

  • New comprehensive SCUBA diving decompression planner implementing
    Bühlmann ZH-L16C algorithm
  • Demonstrates CausalTensor for tissue compartment tracking and
    CausalEffectPropagationProcess for monadic dive phase chaining
  • Includes physics calculations for nitrogen loading, CNS oxygen
    toxicity, and decompression ceiling computation
  • Provides detailed simulation output with dive tables, bubble expansion
    risk analysis, and safety recommendations
+608/-0 
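The decompression example rests on two equations: exponential (Haldane) tissue gas loading and the Buhlmann tolerated-pressure line with per-compartment a/b coefficients. A minimal sketch, assuming illustrative ZH-L16 values for a single fast compartment (function names are mine, not the example's):

```rust
// Core Buhlmann/Haldane equations behind the decompression example (sketch).
const LN2: f64 = std::f64::consts::LN_2;

/// Exponential nitrogen uptake/offgassing for one tissue compartment.
/// p0 and p_alv in bar; half_time and t in minutes.
fn tissue_pressure(p0: f64, p_alv: f64, half_time: f64, t: f64) -> f64 {
    p_alv + (p0 - p_alv) * (-LN2 * t / half_time).exp()
}

/// Tolerated ambient pressure (the decompression "ceiling") from
/// Buhlmann a/b coefficients: P_amb_tol = (P_tissue - a) * b.
fn tolerated_ambient(p_tissue: f64, a: f64, b: f64) -> f64 {
    (p_tissue - a) * b
}

fn main() {
    // Surface-saturated tissue (0.79 bar N2) taken to ~30 m depth
    // (~4 bar ambient, ~3.16 bar inspired N2) for 20 min, 5 min half-time.
    let p = tissue_pressure(0.79, 3.16, 5.0, 20.0);
    println!("tissue N2 after 20 min: {p:.3} bar");

    // a/b values here resemble a fast ZH-L16 compartment; check the
    // published tables before using them for anything real.
    let ceiling = tolerated_ambient(p, 1.1696, 0.5578);
    println!("tolerated ambient pressure: {ceiling:.3} bar");
}
```

The full ZH-L16C model runs sixteen such compartments in parallel, which is where a CausalTensor holding all compartment pressures fits naturally.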
Tests
11 files
ein_sum_impl_tests.rs
Update EinSum tests to use InternalCpuTensor API                 

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl_tests.rs

  • Updated all test calls from CausalTensor:: methods to
    InternalCpuTensor:: methods
  • Added .into_inner() conversions on test utility functions to extract
    internal tensor representation
  • Updated imports to include EinSumAST and EinSumOp from cpu_tensor
    module
  • Adjusted expected tensor shapes in contraction tests to match new
    implementation behavior
+82/-78 
mod.rs
Add CausalMultiField algebra operation tests                         

deep_causality_multivector/tests/types/multifield/algebra/mod.rs

  • New comprehensive test suite for CausalMultiField algebra operations
    (scale, normalize, inverse, reversion, commutators)
  • Tests cover scalar multiplication, unit normalization, multiplicative
    inversion, and grade reversion
  • Includes Lie bracket and geometric commutator tests with antisymmetry
    and relationship verification
  • Validates squared magnitude computation and metric consistency across
    field operations
+572/-0 
cpu_tests.rs
Add CpuGammaLoader Clifford algebra matrix tests                 

deep_causality_multivector/tests/types/multifield/gamma/cpu_tests.rs

  • New test suite for CpuGammaLoader implementing BackendGamma trait
  • Tests verify Clifford algebra identities: γ_i² = ±1 and
    anticommutation relations {γ_i, γ_j} = 0
  • Validates basis blade and dual basis blade matrix representations with
    orthogonality checks
  • Confirms identity blade is identity matrix and blade relationships
    across different signatures
+312/-0 
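The identities these tests check can be verified in a few lines for a small concrete representation. Here is a standalone check for a 2x2 real representation of Cl(2,0), illustrating exactly the gamma_i^2 = +1 and anticommutation properties listed above (this is not the crate's loader, just the mathematical fact):

```rust
// Tiny standalone check of the Clifford identities the gamma tests verify:
// e_i^2 = +1 and {e_i, e_j} = 0 for i != j, in a 2x2 real rep of Cl(2,0).
type M = [[f64; 2]; 2];

fn mul(a: &M, b: &M) -> M {
    let mut r = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                r[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    r
}

fn add(a: &M, b: &M) -> M {
    let mut r = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            r[i][j] = a[i][j] + b[i][j];
        }
    }
    r
}

fn main() {
    let e1: M = [[1.0, 0.0], [0.0, -1.0]];
    let e2: M = [[0.0, 1.0], [1.0, 0.0]];
    let id: M = [[1.0, 0.0], [0.0, 1.0]];

    assert_eq!(mul(&e1, &e1), id); // e1^2 = +1
    assert_eq!(mul(&e2, &e2), id); // e2^2 = +1
    // Anticommutator {e1, e2} = e1*e2 + e2*e1 = 0
    assert_eq!(add(&mul(&e1, &e2), &mul(&e2, &e1)), [[0.0; 2]; 2]);
    println!("Cl(2,0) identities hold");
}
```

For mixed signatures such as Lorentzian metrics, the square flips to -1 on timelike generators, which is what the "different signatures" blade tests cover.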
grmhd_tests.rs
Remove deprecated relativistic current kernel tests           

deep_causality_physics/tests/mhd/grmhd_tests.rs

  • Removed two tests for relativistic_current_kernel that are now
    incompatible with updated API
  • Added explanatory comment indicating tests moved to wrappers_tests.rs
    with Manifold and LorentzianMetric requirements
  • Retained test_energy_momentum_tensor test for electromagnetic tensor
    validation
+3/-14   
arithmetic_tests.rs
Comprehensive Arithmetic Tests for CausalMultiField           

deep_causality_multivector/tests/types/multifield/arithmetic/arithmetic_tests.rs

  • Adds comprehensive test suite for CausalMultiField arithmetic
    operations (446 lines)
  • Tests Zero trait, Add/Sub operations with identity and mismatch cases
  • Tests Neg (negation) and Mul (geometric product) operations
  • Tests Scale operation with various scalar values (zero, one, two,
    negative)
+446/-0 
products_tests.rs
Clifford Algebra Products and Commutator Tests                     

deep_causality_multivector/tests/types/multifield/ops/products_tests.rs

  • Adds 393 lines of tests for Clifford algebra products: inner_product,
    outer_product, cross
  • Tests commutator operations: commutator_lie and commutator_geometric
  • Tests hodge_dual operation and its properties (double dual, grade
    changes)
  • Includes metric mismatch and shape mismatch validation tests
+393/-0 
grades_tests.rs
Grade Projection and Part Extraction Tests                             

deep_causality_multivector/tests/types/multifield/ops/grades_tests.rs

  • Adds 404 lines of tests for grade projection operations on
    CausalMultiField
  • Tests grade_project for extracting specific grades (0-3)
  • Tests convenience methods: scalar_part, vector_part, bivector_part,
    trivector_part, pseudoscalar_part
  • Validates grade extraction preserves field properties and handles edge
    cases
+404/-0 
mod.rs
Add Gamma Tests Module Structure                                                 

deep_causality_multivector/tests/types/multifield/gamma/mod.rs

  • Creates new module file for gamma-related tests
  • Declares submodules: cpu_tests, mlx_tests, mod_tests, providers_tests
+8/-0     
metric_tests.rs
Comprehensive Metric Type Test Suite                                         

deep_causality_metric/tests/types/metric_tests.rs

  • Comprehensive test suite for Metric type covering dimension,
    signature, and sign operations
  • Tests for metric conversions: to_generic(), from_signature(),
    from_signs(), to_signs()
  • Tests for metric operations: flip_time_space(), tensor_product(),
    is_compatible()
  • Tests for Display, Hash, and Eq trait implementations
+410/-0 
conversions_tests.rs
CausalMultiField Conversion Operations Test Suite               

deep_causality_multivector/tests/types/multifield/ops/conversions_tests.rs

  • Tests for CausalMultiField conversion operations: zeros(), ones(),
    from_coefficients(), to_coefficients()
  • Roundtrip tests verifying preservation of scalar and vector values
    through conversion cycles
  • Tests for compute_matrix_dim() helper function with various input
    dimensions
  • Validates shape preservation, metric consistency, and spatial
    discretization parameters
+340/-0 
differential_tests.rs
Differential Operators Test Suite for CausalMultiField     

deep_causality_multivector/tests/types/multifield/ops/differential_tests.rs

  • Tests for differential operators on CausalMultiField:
    partial_derivative(), gradient(), curl(), divergence()
  • Tests for Axis enum (X, Y, Z) indexing and equality
  • Validates operator behavior on constant, linear, and zero fields
  • Tests tensor shape preservation and grade projections (curl as
    grade-2, divergence as grade-0)
+332/-0 
Additional files
101 files
.bazelignore +0/-1     
.bazelversion +1/-1     
audit.toml +1/-4     
AGENTS.md +52/-16 
Cargo.toml +1/-0     
check.sh +22/-1   
fix.sh +1/-1     
format.sh +6/-1     
sbom.sh +1/-0     
test.sh +2/-3     
surd_algo.rs +6/-7     
surd_algo_cdl.rs +7/-7     
BUILD.bazel +0/-1     
BUILD.bazel +1/-1     
BUILD.bazel +41/-11 
BUILD.bazel +36/-10 
BUILD.bazel +41/-0   
Cargo.toml +33/-0   
LICENSE +23/-0   
README.md +153/-0 
deep_causality_metric_sbom.spdx.json +52/-0   
deep_causality_metric_sbom.spdx.json.sha +1/-0     
aliases.rs +39/-0   
east_coast.rs +154/-0 
lorentzian_metric.rs +54/-0   
mod.rs +17/-0   
west_coast.rs +141/-0 
metric_error.rs +66/-0   
mod.rs +8/-0     
lib.rs +86/-0   
convert.rs +61/-0   
mod.rs +8/-0     
display.rs +20/-0   
hash.rs +33/-0   
mod.rs +320/-0 
mod.rs +8/-0     
east_coast_tests.rs +165/-0 
lorentzian_metric_tests.rs +123/-0 
mod.rs +8/-0     
west_coast_tests.rs +168/-0 
metric_error_tests.rs +62/-0   
mod.rs +2/-2     
mod.rs +8/-0     
mod.rs +6/-0     
BUILD.bazel +1/-0     
Cargo.toml +22/-3   
README.md +91/-83 
README_BENCHMARKS.md +86/-0   
multifield_bench.rs +101/-0 
multivector_bench.rs +104/-10
README.md +3/-0     
dixon_multivector.rs +3/-6     
pga3d_multivector.rs +16/-15 
alias_complex.rs +6/-0     
alias_real.rs +0/-4     
mod.rs +2/-0     
multifield_aliases.rs +35/-0   
mod.rs [link]   
mod.rs +2/-1     
lib.rs +60/-4   
matrix_rep.rs +51/-0   
mod.rs +1/-0     
multi_vector.rs +86/-0   
mod.rs +0/-173 
mod.rs +1/-1     
mod.rs +201/-0 
mod.rs +222/-0 
cpu.rs +140/-0 
mlx.rs +140/-0 
mod.rs +176/-0 
provider.rs +38/-0   
mod.rs +122/-0 
batched_matmul.rs +78/-0   
conversions.rs +202/-0 
differential.rs +274/-0 
grades.rs +89/-0   
mod.rs +9/-0     
products.rs +124/-0 
mod.rs +68/-3   
mod.rs +268/-3 
multivector.rs +0/-100 
multivector_l2_norm.rs +0/-49   
scalar_eval.rs +0/-54   
mod.rs +50/-0   
mod.rs +1/-1     
mod.rs +1/-0     
ops_matrix_rep.rs +107/-0 
ops_misc_impl.rs +78/-2   
ops_product_impl.rs +79/-2   
BUILD.bazel +4/-2     
alias_complex_tests.rs +6/-6     
hkt_multifield_tests.rs +287/-0 
mod.rs +5/-0     
hkt_extensions_tests.rs [link]   
hkt_tests.rs [link]   
mod.rs [link]   
mod.rs +2/-1     
metric_tests.rs +0/-103 
mod.rs +1/-1     
mod.rs +5/-0     
Additional files not shown

Signed-off-by: Marvin Hansen <marvin.hansen@gmail.com>
@qodo-code-review (Contributor bot) commented Dec 31, 2025

PR Compliance Guide 🔍

Below is a summary of compliance checks for this PR:

Security Compliance
Panic-based DoS

Description: A reachable .unwrap() on tensor.get(&current_full_index) can panic and crash consumers
(denial-of-service) if internal indexing invariants are violated by unexpected
shapes/strides.
ein_sum_impl.rs [509-521]

Referred Code
for _ in 0..num_batch_elements {
    for i in 0..diag_len {
        let mut current_full_index = vec![0; tensor.ndim()];
        current_full_index[axis1] = i;
        current_full_index[axis2] = i;

        // Fill in batch indices
        for (j, &batch_axis) in batch_axes.iter().enumerate() {
            current_full_index[batch_axis] = current_batch_indices[j];
        }
        result_data.push(*tensor.get(&current_full_index).unwrap());
    }
Ticket Compliance
🟡 🎫 #432: Add an MLX feature flag to enable MLX acceleration for the topology crate. Ensure MLX acceleration is disabled by default to keep compile times in check.
🟡 🎫 #428: Add MLX acceleration for the tensor crate, gated behind a feature flag.
Codebase Duplication Compliance
Codebase context is not defined

Follow the guide to enable codebase context checks.

Custom Compliance
🟢
Generic: Comprehensive Audit Trails

Objective: To create a detailed and reliable record of critical system actions for security analysis
and compliance.

Status: Passed

Learn more about managing compliance generic rules or creating your own custom rules

Generic: Meaningful Naming and Self-Documenting Code

Objective: Ensure all identifiers clearly express their purpose and intent, making code
self-documenting

Status: Passed


Generic: Secure Error Handling

Objective: To prevent the leakage of sensitive system information through error messages while
providing sufficient detail for internal debugging.

Status: Passed


Generic: Secure Logging Practices

Objective: To ensure logs are useful for debugging and auditing without exposing sensitive
information like PII, PHI, or cardholder data.

Status: Passed


Generic: Security-First Input Validation and Data Handling

Objective: Ensure all data inputs are validated, sanitized, and handled securely to prevent
vulnerabilities

Status: Passed


🔴
Generic: Robust Error Handling and Edge Case Management

Objective: Ensure comprehensive error handling that provides meaningful context and graceful
degradation

Status:
Unwrap may panic: New tensor operations use unwrap() on potentially fallible indexing (get() /
get_flat_index()), which can panic instead of returning contextual errors for edge cases.


Compliance status legend:
🟢 - Fully Compliant
🟡 - Partially Compliant
🔴 - Not Compliant
⚪ - Requires Further Human Verification
🏷️ - Compliance label

@qodo-code-review (Contributor bot) commented Dec 31, 2025

PR Code Suggestions ✨

Latest suggestions up to be92279

Category: Incremental
Respect closure in broadcast

Remove the incorrect fast-path in broadcast_op that hardcodes multiplication for
scalar-like tensors. This optimization ignores the provided closure _f, causing
silent bugs for any non-multiplication broadcast operations.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [147-158]

-// For the common case of scalar multiplication/scaling (used in Christoffel),
-// use native MLX multiply which supports broadcasting.
+// Even if `rhs` is a scalar, `broadcast_op` must respect the provided closure `_f`.
+// Fallback to CPU for generic closures.
 let rhs_shape = Self::shape(rhs);
-if rhs_shape.iter().all(|&s| s == 1) {
-    // Scalar broadcast - use native multiply
-    let res = mlx_rs::ops::multiply(lhs.as_array(), rhs.as_array()).map_err(|e| {
-        crate::CausalTensorError::MlxOperationFailed(format!(
-            "MlxBackend::broadcast_op scalar multiply failed: {e}"
-        ))
-    })?;
-    return Ok(MlxTensor::new(res));
-}
+let _ = rhs_shape; // keep for potential future optimizations

[To ensure code accuracy, apply this suggestion manually]

Suggestion importance[1-10]: 9


Why: This suggestion correctly identifies a critical bug where a generic broadcast_op function incorrectly hardcodes a multiplication operation, ignoring the provided closure _f and leading to silent incorrect results for other operations.

Impact: High
Fix dtype-safe epsilon handling

To ensure type safety and robust error handling, replace the hard-coded f32
epsilon with a type-inferred value using full_like, and propagate errors using ?
instead of panicking with .unwrap().

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs [159-162]

-let eps_scalar = mlx_rs::Array::from_slice(&[1e-12f32], &[1]);
-let det_safe = add_(&det, &eps_scalar);
+let eps = mlx_rs::ops::full_like(&det, 1e-12).map_err(|e| e.to_string())?;
+let det_safe = add_(&det, &eps);
 
-let ones = mlx_rs::ops::ones_like(&det_safe).unwrap();
+let ones = mlx_rs::ops::ones_like(&det_safe).map_err(|e| e.to_string())?;


Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies a potential dtype mismatch by using a hard-coded f32 for eps_scalar, which could fail if the input tensor is f64. It also improves error handling by replacing .unwrap() with proper Result propagation, enhancing the robustness of the explicit_inverse_4x4 helper function.

Impact: Medium
Fix stride type conversion

In the strides function, explicitly map and cast each stride element to usize to
match the return type and prevent a compilation error, as
tensor.as_array().strides() likely returns a slice of signed integers.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [127-129]

 fn strides<T>(tensor: &Self::Tensor<T>) -> Vec<usize> {
-    tensor.as_array().strides().to_vec()
+    tensor
+        .as_array()
+        .strides()
+        .iter()
+        .map(|&s| s as usize)
+        .collect()
 }


Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies that strides() returns a slice of signed integers, and the direct to_vec() would cause a type mismatch with the function's Vec<usize> return type, fixing a compilation error.

Impact: Medium
Use true scalar tensor shape

In the scalar assignment arithmetic implementations (AddAssign, SubAssign,
etc.), create scalar tensors with an empty shape &[] instead of &[1]. This
ensures correct broadcasting semantics for true 0-D scalars.

deep_causality_tensor/src/types/backend_tensor/ops.rs [221-251]

 impl<B: TensorBackend> AddAssign<$t> for BackendTensor<$t, B> {
     fn add_assign(&mut self, rhs: $t) {
         // Create scalar tensor on same device, use broadcast add
-        let scalar_tensor = B::create(&[rhs], &[1]);
+        let scalar_tensor = B::create(&[rhs], &[]);
         self.inner = B::add(&self.inner, &scalar_tensor);
     }
 }
 
 impl<B: TensorBackend> SubAssign<$t> for BackendTensor<$t, B> {
     fn sub_assign(&mut self, rhs: $t) {
         // Create scalar tensor on same device, use broadcast sub
-        let scalar_tensor = B::create(&[rhs], &[1]);
+        let scalar_tensor = B::create(&[rhs], &[]);
         self.inner = B::sub(&self.inner, &scalar_tensor);
     }
 }
 
 impl<B: TensorBackend> MulAssign<$t> for BackendTensor<$t, B> {
     fn mul_assign(&mut self, rhs: $t) {
         // Create scalar tensor on same device, use broadcast mul
-        let scalar_tensor = B::create(&[rhs], &[1]);
+        let scalar_tensor = B::create(&[rhs], &[]);
         self.inner = B::mul(&self.inner, &scalar_tensor);
     }
 }
 
 impl<B: TensorBackend> DivAssign<$t> for BackendTensor<$t, B> {
     fn div_assign(&mut self, rhs: $t) {
         // Create scalar tensor on same device, use broadcast div
-        let scalar_tensor = B::create(&[rhs], &[1]);
+        let scalar_tensor = B::create(&[rhs], &[]);
         self.inner = B::div(&self.inner, &scalar_tensor);
     }
 }


Suggestion importance[1-10]: 6


Why: The suggestion correctly points out that using shape [] for a scalar is more precise than [1], improving the robustness and correctness of broadcasting across different backends.

Impact: Low
Category: Possible issue
Allow nested operand ASTs

Restore recursive evaluation for EinSumAST operands in get_binary_operands to
support nested expressions, fixing a functional regression.

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs [25-45]

 let lhs = match children[0].value() {
     EinSumOp::TensorSource { tensor } => tensor.clone(),
-    _ => {
-        return Err(CausalTensorError::EinSumError(
-            EinSumValidationError::InvalidASTStructure {
-                message: "Expected TensorSource node for binary operand".to_string(),
-            },
-        ));
-    }
+    _ => Self::execute_ein_sum(&children[0])?,
 };
 
 let rhs = match children[1].value() {
     EinSumOp::TensorSource { tensor } => tensor.clone(),
-    _ => {
-        return Err(CausalTensorError::EinSumError(
-            EinSumValidationError::InvalidASTStructure {
-                message: "Expected TensorSource node for binary operand".to_string(),
-            },
-        ));
-    }
+    _ => Self::execute_ein_sum(&children[1])?,
 };


Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies a significant functional regression where nested EinSumAST subtrees are no longer evaluated, breaking complex expressions.

Impact: High
Make type handling memory-safe

Replace unsound type_name() string comparisons with TypeId checks before using
unsafe transmutes to ensure memory safety. This prevents potential undefined
behavior from incorrect type assumptions.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [16-64]

-fn create_impl<T>(data: &[T], shape: &[usize]) -> MlxTensor<T> {
+fn create_impl<T: 'static>(data: &[T], shape: &[usize]) -> MlxTensor<T> {
+    use core::any::TypeId;
+
     let shape_i32: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
-    let name = std::any::type_name::<T>();
 
-    if name == "f32" {
-        let slice = unsafe { std::mem::transmute::<&[T], &[f32]>(data) };
+    if TypeId::of::<T>() == TypeId::of::<f32>() {
+        let slice = unsafe {
+            core::slice::from_raw_parts(data.as_ptr() as *const f32, data.len())
+        };
         let array = mlx_rs::Array::from_slice(slice, &shape_i32);
         return MlxTensor::new(array);
     }
-    if name == "i32" {
-        let slice = unsafe { std::mem::transmute::<&[T], &[i32]>(data) };
+
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
+        let slice = unsafe {
+            core::slice::from_raw_parts(data.as_ptr() as *const i32, data.len())
+        };
         let array = mlx_rs::Array::from_slice(slice, &shape_i32);
         return MlxTensor::new(array);
     }
-    if name == "f64" {
-        let slice = unsafe { std::mem::transmute::<&[T], &[f64]>(data) };
-        // Convert to f32
+
+    if TypeId::of::<T>() == TypeId::of::<f64>() {
+        let slice = unsafe {
+            core::slice::from_raw_parts(data.as_ptr() as *const f64, data.len())
+        };
         let data_f32: Vec<f32> = slice.iter().map(|&x| x as f32).collect();
         let array = mlx_rs::Array::from_slice(&data_f32, &shape_i32);
         return MlxTensor::new(array);
     }
 
-    panic!("MlxBackend::create: Unsupported type {}", name);
+    panic!("MlxBackend::create: Unsupported type");
 }
 
-fn to_vec_impl<T: Clone>(tensor: &MlxTensor<T>) -> Vec<T> {
+fn to_vec_impl<T: Clone + 'static>(tensor: &MlxTensor<T>) -> Vec<T> {
+    use core::any::TypeId;
+
     let array = tensor.as_array();
-    let name = std::any::type_name::<T>();
 
-    if name == "f32" {
+    if TypeId::of::<T>() == TypeId::of::<f32>() {
         let vals = array.as_slice::<f32>();
-        let vals_t = unsafe { std::mem::transmute::<&[f32], &[T]>(vals) };
+        let vals_t = unsafe { core::slice::from_raw_parts(vals.as_ptr() as *const T, vals.len()) };
         return vals_t.to_vec();
     }
 
-    if name == "f64" {
+    if TypeId::of::<T>() == TypeId::of::<f64>() {
         let vals = array.as_slice::<f32>();
         let vals_f64: Vec<f64> = vals.iter().map(|&x| x as f64).collect();
-        return unsafe { std::mem::transmute::<Vec<f64>, Vec<T>>(vals_f64) };
+        let out = unsafe {
+            // Sound because we just proved T == f64
+            core::mem::transmute::<Vec<f64>, Vec<T>>(vals_f64)
+        };
+        return out;
     }
 
-    if name == "i32" {
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
         let vals = array.as_slice::<i32>();
-        let vals_t = unsafe { std::mem::transmute::<&[i32], &[T]>(vals) };
+        let vals_t = unsafe { core::slice::from_raw_parts(vals.as_ptr() as *const T, vals.len()) };
         return vals_t.to_vec();
     }
 
-    panic!("MlxBackend::to_vec: Unsupported type {}", name);
+    panic!("MlxBackend::to_vec: Unsupported type");
 }


Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies that using type_name for transmute is unsound and can cause undefined behavior, which is a critical memory safety issue. Replacing it with TypeId checks is the proper and robust way to handle this in Rust.

High
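The pattern this suggestion recommends can be shown in isolation. A minimal sketch (function and label names are illustrative, not the crate's API): unlike `type_name()`, whose output format is explicitly documented as unstable, `TypeId` equality is a sound proof that `T` is exactly the checked type, which justifies a subsequent cast.

```rust
use core::any::TypeId;

// Illustrative sketch of TypeId-based dispatch on a generic `T`.
// A matching TypeId guarantees `T` is exactly the checked type,
// so an unsafe cast guarded by it cannot mis-reinterpret memory.
fn dtype_label<T: 'static>() -> &'static str {
    if TypeId::of::<T>() == TypeId::of::<f32>() {
        "f32"
    } else if TypeId::of::<T>() == TypeId::of::<i32>() {
        "i32"
    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
        "f64"
    } else {
        "unsupported"
    }
}

fn main() {
    assert_eq!(dtype_label::<f32>(), "f32");
    assert_eq!(dtype_label::<i32>(), "i32");
    assert_eq!(dtype_label::<u8>(), "unsupported");
    println!("ok");
}
```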
Align zeros/ones dtype with T

Fix the zeros and ones functions to create tensors with the correct data type T
instead of always defaulting to f32. This prevents a type mismatch that could
lead to undefined behavior.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [169-181]

-fn zeros<T: TensorData>(shape: &[usize]) -> Self::Tensor<T> {
+fn zeros<T: TensorData + 'static>(shape: &[usize]) -> Self::Tensor<T> {
+    use core::any::TypeId;
+
     let mlx_shape: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
-    // Ignoring dtype for now to fix compat
-    let array = mlx_rs::ops::zeros::<f32>(&mlx_shape).expect("MlxBackend::zeros: failed");
-    MlxTensor::new(array)
+
+    if TypeId::of::<T>() == TypeId::of::<f32>() {
+        return MlxTensor::new(mlx_rs::ops::zeros::<f32>(&mlx_shape).expect("MlxBackend::zeros: failed"));
+    }
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
+        return MlxTensor::new(mlx_rs::ops::zeros::<i32>(&mlx_shape).expect("MlxBackend::zeros: failed"));
+    }
+    if TypeId::of::<T>() == TypeId::of::<f64>() {
+        // Stored as f32 for MLX; keep consistent with create/to_vec f64 path.
+        return MlxTensor::new(mlx_rs::ops::zeros::<f32>(&mlx_shape).expect("MlxBackend::zeros: failed"));
+    }
+
+    panic!("MlxBackend::zeros: unsupported dtype");
 }
 
-fn ones<T: TensorData>(shape: &[usize]) -> Self::Tensor<T> {
+fn ones<T: TensorData + 'static>(shape: &[usize]) -> Self::Tensor<T> {
+    use core::any::TypeId;
+
     let mlx_shape: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
-    // Ignoring dtype
-    let array = mlx_rs::ops::ones::<f32>(&mlx_shape).expect("MlxBackend::ones: failed");
-    MlxTensor::new(array)
+
+    if TypeId::of::<T>() == TypeId::of::<f32>() {
+        return MlxTensor::new(mlx_rs::ops::ones::<f32>(&mlx_shape).expect("MlxBackend::ones: failed"));
+    }
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
+        return MlxTensor::new(mlx_rs::ops::ones::<i32>(&mlx_shape).expect("MlxBackend::ones: failed"));
+    }
+    if TypeId::of::<T>() == TypeId::of::<f64>() {
+        // Stored as f32 for MLX; keep consistent with create/to_vec f64 path.
+        return MlxTensor::new(mlx_rs::ops::ones::<f32>(&mlx_shape).expect("MlxBackend::ones: failed"));
+    }
+
+    panic!("MlxBackend::ones: unsupported dtype");
 }


Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies a critical bug where zeros and ones create an f32 tensor but return it as MlxTensor<T>, leading to a type mismatch and likely undefined behavior later. The fix ensures type correctness for supported types and panics for unsupported ones.

High
Prevent reshape size overflows

Prevent potential integer overflow when calculating tensor sizes by computing
the product as usize and then using a checked conversion to i32 with an explicit
error message.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs [91-92]

-let size_a: i32 = shape_a.iter().product();
-let size_b: i32 = shape_b.iter().product();
+let size_a_usize: usize = shape_a.iter().product();
+let size_b_usize: usize = shape_b.iter().product();
+let size_a: i32 = size_a_usize
+    .try_into()
+    .expect("MlxBackend::tensor_product: size_a overflows i32");
+let size_b: i32 = size_b_usize
+    .try_into()
+    .expect("MlxBackend::tensor_product: size_b overflows i32");


Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies a potential integer overflow when calculating tensor sizes, which could lead to a panic, and provides a robust fix using checked conversion.

Medium
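The checked-conversion pattern can be sketched standalone (the helper name is hypothetical): accumulate the element count as `usize`, then convert to the `i32` the MLX FFI expects, surfacing overflow as an error instead of wrapping.

```rust
// Hypothetical helper showing the checked usize -> i32 conversion.
// On 64-bit targets a shape product can exceed i32::MAX, so the
// conversion must be fallible rather than an `as` cast.
fn checked_len_i32(shape: &[usize]) -> Result<i32, String> {
    let n: usize = shape.iter().product();
    i32::try_from(n).map_err(|_| format!("tensor size {n} overflows i32"))
}

fn main() {
    assert_eq!(checked_len_i32(&[2, 3, 4]), Ok(24));
    // 2^40 elements fits in a 64-bit usize but not in i32.
    assert!(checked_len_i32(&[1 << 20, 1 << 20]).is_err());
    println!("ok");
}
```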
Remove panics from CPU fallback

Replace unwrap() with proper error handling in the CPU fallback path for
broadcast_op. This prevents potential panics and allows errors to be propagated
to the caller.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [161-167]

 // Fallback to CPU for general closures
-let cpu_lhs = crate::InternalCpuTensor::new(Self::to_vec(lhs), Self::shape(lhs)).unwrap();
-let cpu_rhs = crate::InternalCpuTensor::new(Self::to_vec(rhs), Self::shape(rhs)).unwrap();
+let cpu_lhs = crate::InternalCpuTensor::new(Self::to_vec(lhs), Self::shape(lhs))
+    .map_err(|e| crate::CausalTensorError::MlxOperationFailed(format!("MlxBackend::broadcast_op cpu lhs build failed: {e}")))?;
+let cpu_rhs = crate::InternalCpuTensor::new(Self::to_vec(rhs), Self::shape(rhs))
+    .map_err(|e| crate::CausalTensorError::MlxOperationFailed(format!("MlxBackend::broadcast_op cpu rhs build failed: {e}")))?;
 
 use crate::CpuBackend;
 let result = CpuBackend::broadcast_op(&cpu_lhs, &cpu_rhs, _f)?;
 Ok(Self::create(result.as_slice(), result.shape()))


Suggestion importance[1-10]: 7


Why: The suggestion correctly points out that using unwrap() can lead to unrecoverable panics, which is poor practice for library code. Replacing it with proper error propagation using ? improves the robustness and usability of the function.

Medium
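A minimal sketch of the propagation pattern this suggestion recommends, with stand-in types rather than the crate's real `InternalCpuTensor` and `CausalTensorError`: wrap the constructor's error into the caller's error type with `map_err` and bubble it up with `?`, leaving no panic path.

```rust
// Stand-in error type; the real code uses CausalTensorError.
#[derive(Debug, PartialEq)]
enum TensorError {
    OperationFailed(String),
}

// Stand-in fallible constructor: data length must match the shape.
fn build(data: Vec<f32>, shape: &[usize]) -> Result<Vec<f32>, String> {
    let n: usize = shape.iter().product();
    if data.len() != n {
        return Err(format!("length {} != shape product {}", data.len(), n));
    }
    Ok(data)
}

// `?` returns early with the mapped error instead of panicking.
fn fallback(data: Vec<f32>, shape: &[usize]) -> Result<Vec<f32>, TensorError> {
    let t = build(data, shape)
        .map_err(|e| TensorError::OperationFailed(format!("cpu build failed: {e}")))?;
    Ok(t)
}

fn main() {
    assert!(fallback(vec![1.0, 2.0], &[2]).is_ok());
    assert!(fallback(vec![1.0], &[2]).is_err());
    println!("ok");
}
```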
Fix spacing mapping compilation

Explicitly clone the elements of dx_a before passing them to the mapping
function f to support non-Copy types as required by the A: Clone trait bound.

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs [101-102]

-let dx_new = [f(dx_a[0]), f(dx_a[1]), f(dx_a[2])];
+let dx_new = [
+    f(dx_a[0].clone()),
+    f(dx_a[1].clone()),
+    f(dx_a[2].clone()),
+];
 CausalMultiField::from_coefficients(&new_mvs, shape, dx_new)


Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies a compile error for non-Copy types due to moving from a borrowed context, and the proposed fix using .clone() is correct and necessary.

Medium
Return a cloned extracted value

Clone the extracted value from mvs[0].data[0] before returning it to ensure the
function compiles for non-Copy types that only implement Clone.

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs [275-276]

 let mvs = fa.to_coefficients();
-mvs[0].data[0]
+mvs[0].data[0].clone()


Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies a compile error where a value is moved from a borrowed context, which fails for non-Copy types, and the fix using .clone() is appropriate.

Medium
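Both clone-related suggestions address the same borrow rule, which a tiny repro makes concrete (the function is illustrative; `String` stands in for a non-`Copy` coefficient type): indexing a slice yields a place behind a reference, so returning `v[0]` moves out of the borrow and only compiles for `Copy` types, while `v[0].clone()` works for any `T: Clone`.

```rust
// Returning `v[0]` directly would fail to compile for non-Copy `T`
// (E0507: cannot move out of index); cloning resolves it.
fn first_cloned<T: Clone>(v: &[T]) -> T {
    v[0].clone()
}

fn main() {
    let coeffs = vec![String::from("a"), String::from("b")];
    assert_eq!(first_cloned(&coeffs), "a");
    // The original vector is untouched after the clone.
    assert_eq!(coeffs.len(), 2);
    println!("ok");
}
```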
General
Avoid panics on invalid indexing

Replace .unwrap() with ok_or_else to propagate an error instead of panicking on
invalid tensor indexing.

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs [522]

-result_data.push(*tensor.get(&current_full_index).unwrap());
+result_data.push(
+    *tensor.get(&current_full_index).ok_or_else(|| {
+        CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
+            message: "Internal error: diagonal index out of bounds".to_string(),
+        })
+    })?,
+);


Suggestion importance[1-10]: 7


Why: The suggestion correctly proposes replacing .unwrap() with proper error handling to prevent potential panics, which aligns with the function's return type and improves robustness.

Medium
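The `ok_or_else` conversion can be shown in miniature (the error enum is a stand-in for `CausalTensorError`): a missing index becomes a typed error the caller can handle, rather than a panic inside the library.

```rust
// Stand-in error type for illustration only.
#[derive(Debug)]
enum IndexError {
    OutOfBounds(String),
}

// `Option::ok_or_else` maps `None` into an error lazily, so the
// message is only formatted on the failure path.
fn get_checked(data: &[f64], idx: usize) -> Result<f64, IndexError> {
    data.get(idx)
        .copied()
        .ok_or_else(|| IndexError::OutOfBounds(format!("index {idx} out of bounds")))
}

fn main() {
    assert_eq!(get_checked(&[1.0, 2.0], 1).unwrap(), 2.0);
    assert!(get_checked(&[1.0, 2.0], 5).is_err());
    println!("ok");
}
```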

Previous suggestions

Suggestions up to commit be92279
Category | Suggestion | Impact
Possible issue
Remove brittle unsafe type dispatch

Replace the brittle type_name string comparison with TypeId for runtime type
dispatch to improve safety and prevent potential undefined behavior.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [16-64]

-fn create_impl<T>(data: &[T], shape: &[usize]) -> MlxTensor<T> {
+fn create_impl<T: 'static>(data: &[T], shape: &[usize]) -> MlxTensor<T> {
+    use core::any::TypeId;
+
     let shape_i32: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
-    let name = std::any::type_name::<T>();
 
-    if name == "f32" {
-        let slice = unsafe { std::mem::transmute::<&[T], &[f32]>(data) };
+    if TypeId::of::<T>() == TypeId::of::<f32>() {
+        let slice = unsafe { core::slice::from_raw_parts(data.as_ptr() as *const f32, data.len()) };
         let array = mlx_rs::Array::from_slice(slice, &shape_i32);
         return MlxTensor::new(array);
     }
-    if name == "i32" {
-        let slice = unsafe { std::mem::transmute::<&[T], &[i32]>(data) };
+
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
+        let slice = unsafe { core::slice::from_raw_parts(data.as_ptr() as *const i32, data.len()) };
         let array = mlx_rs::Array::from_slice(slice, &shape_i32);
         return MlxTensor::new(array);
     }
-    if name == "f64" {
-        let slice = unsafe { std::mem::transmute::<&[T], &[f64]>(data) };
-        // Convert to f32
+
+    if TypeId::of::<T>() == TypeId::of::<f64>() {
+        let slice = unsafe { core::slice::from_raw_parts(data.as_ptr() as *const f64, data.len()) };
         let data_f32: Vec<f32> = slice.iter().map(|&x| x as f32).collect();
         let array = mlx_rs::Array::from_slice(&data_f32, &shape_i32);
         return MlxTensor::new(array);
     }
 
-    panic!("MlxBackend::create: Unsupported type {}", name);
+    panic!("MlxBackend::create: Unsupported type {}", core::any::type_name::<T>());
 }
 
-fn to_vec_impl<T: Clone>(tensor: &MlxTensor<T>) -> Vec<T> {
+fn to_vec_impl<T: Clone + 'static>(tensor: &MlxTensor<T>) -> Vec<T> {
+    use core::any::TypeId;
+
     let array = tensor.as_array();
-    let name = std::any::type_name::<T>();
 
-    if name == "f32" {
+    if TypeId::of::<T>() == TypeId::of::<f32>() {
         let vals = array.as_slice::<f32>();
-        let vals_t = unsafe { std::mem::transmute::<&[f32], &[T]>(vals) };
-        return vals_t.to_vec();
+        let out = vals.to_vec();
+        return unsafe { core::mem::transmute::<Vec<f32>, Vec<T>>(out) };
     }
 
-    if name == "f64" {
+    if TypeId::of::<T>() == TypeId::of::<f64>() {
         let vals = array.as_slice::<f32>();
-        let vals_f64: Vec<f64> = vals.iter().map(|&x| x as f64).collect();
-        return unsafe { std::mem::transmute::<Vec<f64>, Vec<T>>(vals_f64) };
+        let out: Vec<f64> = vals.iter().map(|&x| x as f64).collect();
+        return unsafe { core::mem::transmute::<Vec<f64>, Vec<T>>(out) };
     }
 
-    if name == "i32" {
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
         let vals = array.as_slice::<i32>();
-        let vals_t = unsafe { std::mem::transmute::<&[i32], &[T]>(vals) };
-        return vals_t.to_vec();
+        let out = vals.to_vec();
+        return unsafe { core::mem::transmute::<Vec<i32>, Vec<T>>(out) };
     }
 
-    panic!("MlxBackend::to_vec: Unsupported type {}", name);
+    panic!("MlxBackend::to_vec: Unsupported type {}", core::any::type_name::<T>());
 }
Suggestion importance[1-10]: 9


Why: The suggestion correctly identifies that using type_name for type dispatch is brittle and replaces it with the more robust TypeId, preventing potential undefined behavior.

High
Ensure correct dtype allocation

Fix a critical bug in zeros and ones by allocating arrays with the correct data
type T instead of always using f32, preventing undefined behavior.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [169-181]

-fn zeros<T: TensorData>(shape: &[usize]) -> Self::Tensor<T> {
+fn zeros<T: TensorData + 'static>(shape: &[usize]) -> Self::Tensor<T> {
+    use core::any::TypeId;
+
     let mlx_shape: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
-    // Ignoring dtype for now to fix compat
-    let array = mlx_rs::ops::zeros::<f32>(&mlx_shape).expect("MlxBackend::zeros: failed");
-    MlxTensor::new(array)
+
+    if TypeId::of::<T>() == TypeId::of::<f32>() || TypeId::of::<T>() == TypeId::of::<f64>() {
+        let array = mlx_rs::ops::zeros::<f32>(&mlx_shape).expect("MlxBackend::zeros: failed");
+        return MlxTensor::new(array);
+    }
+
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
+        let array = mlx_rs::ops::zeros::<i32>(&mlx_shape).expect("MlxBackend::zeros: failed");
+        return MlxTensor::new(array);
+    }
+
+    panic!("MlxBackend::zeros: Unsupported type {}", core::any::type_name::<T>());
 }
 
-fn ones<T: TensorData>(shape: &[usize]) -> Self::Tensor<T> {
+fn ones<T: TensorData + 'static>(shape: &[usize]) -> Self::Tensor<T> {
+    use core::any::TypeId;
+
     let mlx_shape: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
-    // Ignoring dtype
-    let array = mlx_rs::ops::ones::<f32>(&mlx_shape).expect("MlxBackend::ones: failed");
-    MlxTensor::new(array)
+
+    if TypeId::of::<T>() == TypeId::of::<f32>() || TypeId::of::<T>() == TypeId::of::<f64>() {
+        let array = mlx_rs::ops::ones::<f32>(&mlx_shape).expect("MlxBackend::ones: failed");
+        return MlxTensor::new(array);
+    }
+
+    if TypeId::of::<T>() == TypeId::of::<i32>() {
+        let array = mlx_rs::ops::ones::<i32>(&mlx_shape).expect("MlxBackend::ones: failed");
+        return MlxTensor::new(array);
+    }
+
+    panic!("MlxBackend::ones: Unsupported type {}", core::any::type_name::<T>());
 }
Suggestion importance[1-10]: 9


Why: This is a critical correctness fix, as the original code creates a tensor of a fixed type (f32) regardless of the generic type T, which would lead to undefined behavior.

High
Reject duplicate contraction axes

Add validation to the contract function to ensure that lhs_axes and rhs_axes do
not contain duplicate axes, preventing incorrect calculations.

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs [200-211]

-for (&lhs_ax, &rhs_ax) in lhs_axes.iter().zip(rhs_axes.iter()) {
-    if lhs.shape[lhs_ax] != rhs.shape[rhs_ax] {
-        return Err(CausalTensorError::EinSumError(
-            EinSumValidationError::ShapeMismatch {
-                message: format!(
-                    "Contraction dimension mismatch: lhs axis {} (dim {}) vs rhs axis {} (dim {})",
-                    lhs_ax, lhs.shape[lhs_ax], rhs_ax, rhs.shape[rhs_ax]
-                ),
-            },
-        ));
+// Validate contracted dimensions match and axes are unique
+{
+    let mut seen_lhs = std::collections::HashSet::with_capacity(lhs_axes.len());
+    let mut seen_rhs = std::collections::HashSet::with_capacity(rhs_axes.len());
+
+    for (&lhs_ax, &rhs_ax) in lhs_axes.iter().zip(rhs_axes.iter()) {
+        if !seen_lhs.insert(lhs_ax) {
+            return Err(CausalTensorError::EinSumError(
+                EinSumValidationError::InvalidAxesSpecification {
+                    message: format!("Duplicate LHS contraction axis: {lhs_ax}"),
+                },
+            ));
+        }
+        if !seen_rhs.insert(rhs_ax) {
+            return Err(CausalTensorError::EinSumError(
+                EinSumValidationError::InvalidAxesSpecification {
+                    message: format!("Duplicate RHS contraction axis: {rhs_ax}"),
+                },
+            ));
+        }
+
+        if lhs.shape[lhs_ax] != rhs.shape[rhs_ax] {
+            return Err(CausalTensorError::EinSumError(
+                EinSumValidationError::ShapeMismatch {
+                    message: format!(
+                        "Contraction dimension mismatch: lhs axis {} (dim {}) vs rhs axis {} (dim {})",
+                        lhs_ax, lhs.shape[lhs_ax], rhs_ax, rhs.shape[rhs_ax]
+                    ),
+                },
+            ));
+        }
     }
 }
Suggestion importance[1-10]: 8


Why: This suggestion addresses a critical correctness issue by adding validation for duplicate contraction axes, which would otherwise lead to incorrect results.

Medium
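The duplicate-axis guard relies on a standard `HashSet` idiom, sketched here with an illustrative helper: `HashSet::insert` returns `false` when the value was already present, so the insertion itself doubles as the duplicate check.

```rust
use std::collections::HashSet;

// Illustrative validation helper: reject duplicate contraction axes.
fn validate_axes(axes: &[usize]) -> Result<(), String> {
    let mut seen = HashSet::with_capacity(axes.len());
    for &ax in axes {
        if !seen.insert(ax) {
            return Err(format!("Duplicate contraction axis: {ax}"));
        }
    }
    Ok(())
}

fn main() {
    assert!(validate_axes(&[0, 2, 1]).is_ok());
    assert!(validate_axes(&[0, 1, 0]).is_err());
    println!("ok");
}
```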
Propagate CPU fallback failures

Replace unwrap() with the ? operator in broadcast_op to properly propagate
errors from InternalCpuTensor::new instead of panicking.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [161-166]

 // Fallback to CPU for general closures
-let cpu_lhs = crate::InternalCpuTensor::new(Self::to_vec(lhs), Self::shape(lhs)).unwrap();
-let cpu_rhs = crate::InternalCpuTensor::new(Self::to_vec(rhs), Self::shape(rhs)).unwrap();
+let cpu_lhs = crate::InternalCpuTensor::new(Self::to_vec(lhs), Self::shape(lhs))?;
+let cpu_rhs = crate::InternalCpuTensor::new(Self::to_vec(rhs), Self::shape(rhs))?;
 
 use crate::CpuBackend;
 let result = CpuBackend::broadcast_op(&cpu_lhs, &cpu_rhs, _f)?;
 Ok(Self::create(result.as_slice(), result.shape()))
Suggestion importance[1-10]: 8


Why: The suggestion correctly replaces unwrap() with error propagation, preventing a potential panic and allowing for graceful error handling, which is a significant improvement in robustness.

Medium
Match epsilon dtype to input

To prevent type mismatches in the generic explicit_inverse_4x4 function, create
the epsilon tensor using mlx_rs::ops::full_like to ensure it matches the data
type of the input tensor det.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs [159-160]

-let eps_scalar = mlx_rs::Array::from_slice(&[1e-12f32], &[1]);
-let det_safe = add_(&det, &eps_scalar);
+let eps = mlx_rs::ops::full_like(&det, 1e-12).map_err(|e| e.to_string())?;
+let det_safe = add_(&det, &eps).map_err(|e| e.to_string())?;
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies that hardcoding f32 for eps_scalar will cause a type mismatch when the generic function is used with other floating-point types like f64, leading to a panic. The proposed fix using full_like makes the implementation robust and generic.

Medium
Remove panic from diagonal indexing

In the diagonal function, replace the .unwrap() call with proper error handling
to prevent potential panics on out-of-bounds indexing.

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs [522]

-result_data.push(*tensor.get(&current_full_index).unwrap());
+let v = tensor.get(&current_full_index).ok_or_else(|| {
+    CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
+        message: "Internal error: diagonal index out of bounds".to_string(),
+    })
+})?;
+result_data.push(*v);
Suggestion importance[1-10]: 7


Why: The suggestion correctly proposes replacing a potential panic with proper error handling, which improves the robustness and reliability of the diagonal function.

Medium
Avoid moving out of borrow

To prevent a compile error when the generic type A is not Copy, explicitly clone
the elements of dx_a before passing them to the function f.

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs [101-102]

-let dx_new = [f(dx_a[0]), f(dx_a[1]), f(dx_a[2])];
+let dx_new = [f(dx_a[0].clone()), f(dx_a[1].clone()), f(dx_a[2].clone())];
 CausalMultiField::from_coefficients(&new_mvs, shape, dx_new)
Suggestion importance[1-10]: 7


Why: The suggestion correctly identifies a move out of a borrowed context that would cause a compilation error for non-Copy types, and the proposed fix using clone() is accurate.

Medium
Clone instead of moving value

To fix a compile error, clone the value from mvs[0].data[0] instead of
attempting to move it, as indexing a Vec returns a reference.

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs [275-276]

 let mvs = fa.to_coefficients();
-mvs[0].data[0]
+mvs[0].data[0].clone()
Suggestion importance[1-10]: 7


Why: The suggestion correctly points out that indexing mvs[0].data returns a reference, and moving the value is not allowed. The proposed fix using clone() is necessary for the function to compile with non-Copy types.

Medium
Use consistent scalar shape

Remove the special handling for scalar results in the contract function to
ensure scalars are consistently represented with an empty shape [] instead of
[1].

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs [225-227]

-// Handle scalar result case
-if result_shape.is_empty() {
-    result_shape.push(1); // Scalar represented as 1-element tensor
-}
+// If all axes are contracted, the result is a scalar (rank-0 tensor).
+// Represent scalars with an empty shape `[]` (with data length 1).
+// (No special-casing needed here; `product()` over an empty iterator yields 1.)
Suggestion importance[1-10]: 6


Why: The suggestion correctly identifies an inconsistency in scalar representation between the contract and trace functions, which could cause downstream logic errors.

Low
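The claim that no special-casing is needed for rank-0 results follows from how `product()` behaves on an empty iterator: it yields the multiplicative identity, so a scalar with shape `[]` naturally has exactly one element. A two-line sketch:

```rust
// Element count of a tensor shape; the empty (rank-0) shape yields 1
// because `product()` over an empty iterator is the identity.
fn num_elements(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    assert_eq!(num_elements(&[]), 1); // rank-0 scalar: one element
    assert_eq!(num_elements(&[3, 4]), 12);
    println!("ok");
}
```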
✅ Suggestions up to commit 3a2a365
Category | Suggestion | Impact
Possible issue
Return errors instead of panics
Suggestion Impact: The commit replaced the `.expect("... multiply failed")` on `mlx_rs::ops::multiply(...)` with `.map_err(...)` that converts the MLX error into `CausalTensorError::MlxOperationFailed` and propagates it via `?`, preventing a panic. The commit also applied similar non-panicking error handling to another MLX op (`transpose_axes`).

code diff:

-            let res = mlx_rs::ops::multiply(lhs.as_array(), rhs.as_array())
-                .expect("MlxBackend::broadcast_op: multiply failed");
+            let res = mlx_rs::ops::multiply(lhs.as_array(), rhs.as_array()).map_err(|e| {
+                crate::CausalTensorError::MlxOperationFailed(format!(
+                    "MlxBackend::broadcast_op scalar multiply failed: {e}"
+                ))
+            })?;
             return Ok(MlxTensor::new(res));

Replace .expect() with proper error handling using .map_err() to prevent panics
and return a CausalTensorError on failure.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [157-159]

-let res = mlx_rs::ops::multiply(lhs.as_array(), rhs.as_array())
-    .expect("MlxBackend::broadcast_op: multiply failed");
+let res = mlx_rs::ops::multiply(lhs.as_array(), rhs.as_array()).map_err(|e| {
+    crate::CausalTensorError::MlxOperationFailed(format!(
+        "MlxBackend::broadcast_op scalar multiply failed: {e}"
+    ))
+})?;
 return Ok(MlxTensor::new(res));
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies that using .expect() can cause a panic, which is a bug in a library function that is designed to return a Result. The proposed change correctly handles the error and propagates it, significantly improving the robustness of the code.

Medium
Validate axis and propagate errors
Suggestion Impact: The commit updated MlxBackend::stack to cast ndim to usize, added an axis out-of-bounds check, and replaced transpose_axes(...).expect(...) with map_err + ? error propagation (matching the suggestion's intent). It also applied a similar expect->map_err change in broadcast_op's scalar multiply path.

code diff:

@@ -262,12 +260,20 @@
         }
 
         // Move axis 0 to the target axis
-        let ndim = stacked.ndim();
+        let ndim = stacked.ndim() as usize;
+        if axis >= ndim {
+            return Err(crate::CausalTensorError::MlxOperationFailed(format!(
+                "MlxBackend::stack: axis {axis} out of bounds for stacked ndim {ndim}"
+            )));
+        }
         let mut axes: Vec<i32> = (1..ndim as i32).collect();
         axes.insert(axis, 0);
 
-        let transposed = mlx_rs::ops::transpose_axes(&stacked, &axes)
-            .expect("MlxBackend::stack: transpose failed");
+        let transposed = mlx_rs::ops::transpose_axes(&stacked, &axes).map_err(|e| {
+            crate::CausalTensorError::MlxOperationFailed(format!(
+                "MlxBackend::stack transpose failed: {e}"
+            ))
+        })?;
         Ok(MlxTensor::new(transposed))

Add a bounds check for the axis parameter and replace .expect() with error
propagation to prevent potential panics in the stack function.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [257-271]

 let stacked = mlx_rs::ops::stack(&arrays)
     .map_err(|e| crate::CausalTensorError::MlxOperationFailed(e.to_string()))?;
 
 if axis == 0 {
     return Ok(MlxTensor::new(stacked));
 }
 
 // Move axis 0 to the target axis
-let ndim = stacked.ndim();
+let ndim = stacked.ndim() as usize;
+if axis >= ndim {
+    return Err(crate::CausalTensorError::MlxOperationFailed(format!(
+        "MlxBackend::stack: axis {axis} out of bounds for stacked ndim {ndim}"
+    )));
+}
+
 let mut axes: Vec<i32> = (1..ndim as i32).collect();
 axes.insert(axis, 0);
 
-let transposed = mlx_rs::ops::transpose_axes(&stacked, &axes)
-    .expect("MlxBackend::stack: transpose failed");
+let transposed = mlx_rs::ops::transpose_axes(&stacked, &axes).map_err(|e| {
+    crate::CausalTensorError::MlxOperationFailed(format!(
+        "MlxBackend::stack: transpose failed: {e}"
+    ))
+})?;
 Ok(MlxTensor::new(transposed))
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies two potential panics: one from an out-of-bounds axis and another from using .expect(). The proposed changes add necessary bounds checking and proper error propagation, which prevents crashes and makes the function more robust.

Medium
Decouple spacing from payload value
Suggestion Impact: The commit added the deep_causality_num::One bound to T and changed pure() to pass [T::one(), T::one(), T::one()] as spacing instead of [value, value, value], including a clarifying comment about preventing division-by-zero.

code diff:

@@ -115,6 +115,7 @@
     where
         T: TensorData
             + Clone
+            + deep_causality_num::One
             + std::ops::AddAssign
             + std::ops::SubAssign
             + std::ops::Neg<Output = T>
@@ -125,7 +126,8 @@
             data: vec![value],
             metric: Metric::Euclidean(0),
         };
-        CausalMultiField::from_coefficients(&[mv], [1, 1, 1], [value, value, value])
+        // Use T::one() for spacing instead of value to prevent division-by-zero
+        CausalMultiField::from_coefficients(&[mv], [1, 1, 1], [T::one(), T::one(), T::one()])
     }

In the pure function, set the spatial spacing dx to a default of [T::one(),
T::one(), T::one()] instead of the input value to prevent potential
division-by-zero errors.

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs [114-129]

 pub fn pure<T>(value: T) -> CausalMultiField<B, T>
 where
     T: TensorData
         + Clone
+        + deep_causality_num::One
         + std::ops::AddAssign
         + std::ops::SubAssign
         + std::ops::Neg<Output = T>
         + std::ops::Div<Output = T>,
     B: crate::types::multifield::gamma::GammaProvider<T>,
 {
     let mv = CausalMultiVector {
         data: vec![value],
         metric: Metric::Euclidean(0),
     };
-    CausalMultiField::from_coefficients(&[mv], [1, 1, 1], [value, value, value])
+    CausalMultiField::from_coefficients(&[mv], [1, 1, 1], [T::one(), T::one(), T::one()])
 }
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies a critical issue where using the input value for dx can lead to division-by-zero errors in downstream differential operators, and proposes a robust fix.

Medium
Preserve grid spacing through apply

In the apply function, preserve the spatial spacing dx from the input field f_a
instead of resetting it to [NewT::default(); 3] to maintain grid integrity.

deep_causality_multivector/src/extensions/hkt_multifield/mod.rs [137-190]

 pub fn apply<A, NewT, Func>(
     f_ab: CausalMultiField<B, Func>,
     f_a: CausalMultiField<B, A>,
 ) -> CausalMultiField<B, NewT>
 where
     Func: FnMut(A) -> NewT
         + TensorData
         + Clone
         + std::ops::AddAssign
         + std::ops::SubAssign
         + std::ops::Neg<Output = Func>
         + std::ops::Div<Output = Func>,
     A: Clone
         + TensorData
         + std::ops::AddAssign
         + std::ops::SubAssign
         + std::ops::Neg<Output = A>
         + std::ops::Div<Output = A>,
     NewT: TensorData
         + Clone
+        + From<A>
         + std::ops::AddAssign
         + std::ops::SubAssign
         + std::ops::Neg<Output = NewT>
         + std::ops::Div<Output = NewT>,
     B: crate::types::multifield::gamma::GammaProvider<Func>
         + crate::types::multifield::gamma::GammaProvider<A>
         + crate::types::multifield::gamma::GammaProvider<NewT>,
 {
     assert_eq!(f_ab.shape(), f_a.shape(), "Shape mismatch in apply");
+    assert_eq!(f_ab.dx(), f_a.dx(), "dx mismatch in apply");
 
     let funcs = f_ab.to_coefficients();
     let values = f_a.to_coefficients();
     let shape = *f_a.shape();
+    let dx_a = *f_a.dx();
 
     let new_mvs: Vec<_> = funcs
         .into_iter()
         .zip(values)
         .map(|(func_mv, val_mv)| {
             let new_data: Vec<NewT> = func_mv
                 .data
                 .into_iter()
                 .zip(val_mv.data)
                 .map(|(mut func, val)| func(val))
                 .collect();
             CausalMultiVector {
                 data: new_data,
                 metric: val_mv.metric,
             }
         })
         .collect();
 
-    let dx_new = [NewT::default(); 3];
+    let dx_new = [NewT::from(dx_a[0].clone()), NewT::from(dx_a[1].clone()), NewT::from(dx_a[2].clone())];
     CausalMultiField::from_coefficients(&new_mvs, shape, dx_new)
 }
Suggestion importance[1-10]: 8


Why: The suggestion correctly identifies that dx is reset to its default, which is likely zero, potentially causing division-by-zero errors. Preserving the grid spacing is crucial for correctness.

Medium
Stabilize inversion near singularities
Suggestion Impact: The commit modified inv_2x2 to add an epsilon (1e-12) to the determinant (det_safe) and then invert det_safe instead of det, preventing inf/NaN on near-singular blocks. The implementation uses a scalar Array broadcast rather than full_like, but the intent matches the suggestion.

code diff:

@@ -153,8 +153,14 @@
     // Inverse 2x2 Helper: (a,b,c,d) -> (oa, ob, oc, od)
     let inv_2x2 = |a: &Array, b: &Array, c: &Array, d: &Array| -> (Array, Array, Array, Array) {
         let det = sub_(&mul_(a, d), &mul_(b, c));
-        let ones = mlx_rs::ops::ones_like(&det).unwrap();
-        let inv_det = div_(&ones, &det);
+
+        // Prevent inf/NaN on near-singular blocks with small epsilon
+        // Create epsilon scalar and add via broadcast
+        let eps_scalar = mlx_rs::Array::from_slice(&[1e-12f32], &[1]);
+        let det_safe = add_(&det, &eps_scalar);
+
+        let ones = mlx_rs::ops::ones_like(&det_safe).unwrap();
+        let inv_det = div_(&ones, &det_safe);
 

In the inv_2x2 helper, add a small epsilon to the determinant before inversion
to prevent division by zero and improve numerical stability for near-singular
matrices.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs [154-168]

 let inv_2x2 = |a: &Array, b: &Array, c: &Array, d: &Array| -> (Array, Array, Array, Array) {
     let det = sub_(&mul_(a, d), &mul_(b, c));
-    let ones = mlx_rs::ops::ones_like(&det).unwrap();
-    let inv_det = div_(&ones, &det);
+
+    // Prevent inf/NaN on near-singular blocks
+    let eps = mlx_rs::ops::full_like(&det, 1e-12).unwrap();
+    let det_safe = add_(&det, &eps);
+
+    let ones = mlx_rs::ops::ones_like(&det_safe).unwrap();
+    let inv_det = div_(&ones, &det_safe);
 
     let zero = sub_(&det, &det);
     let neg_b = sub_(&zero, b);
     let neg_c = sub_(&zero, c);
 
     let oa = mul_(&inv_det, d);
     let ob = mul_(&inv_det, &neg_b);
     let oc = mul_(&inv_det, &neg_c);
     let od = mul_(&inv_det, a);
     (oa, ob, oc, od)
 };
Suggestion importance[1-10]: 8

__

Why: The suggestion correctly identifies a potential division-by-zero issue when inverting a singular or near-singular matrix, which would produce NaN/inf values, and proposes a standard numerical stability fix.

Medium
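The epsilon trick can be demonstrated on plain scalars, independent of MLX. A sketch (values chosen to mirror the f32 path in the backend): the block `[[1, 2], [2, 4]]` is exactly singular, so the naive reciprocal of its determinant is infinite, while the stabilized reciprocal stays finite.

```rust
// Epsilon-stabilized reciprocal determinant of a 2x2 block (a, b, c, d).
fn inv_det_safe(a: f32, b: f32, c: f32, d: f32) -> f32 {
    let det = a * d - b * c;
    // The epsilon prevents inf/NaN when det is (near) zero.
    1.0 / (det + 1e-12)
}

fn main() {
    // [[1, 2], [2, 4]] is singular: det = 0, so 1/det is inf.
    let det = 1.0f32 * 4.0 - 2.0 * 2.0;
    assert!((1.0 / det).is_infinite());
    // The stabilized inversion stays finite (though large).
    assert!(inv_det_safe(1.0, 2.0, 2.0, 4.0).is_finite());
    println!("inv_det_safe = {}", inv_det_safe(1.0, 2.0, 2.0, 4.0));
}
```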
Incremental [*]
Replace panics with error returns

In the contract function, replace expect() calls with Result error propagation
to avoid panics on potential index out-of-bounds errors.

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs [286-292]

-let lhs_val = lhs
-    .get(&lhs_index)
-    .expect("Internal error: lhs index out of bounds in contraction");
-let rhs_val = rhs
-    .get(&rhs_index)
-    .expect("Internal error: rhs index out of bounds in contraction");
+let lhs_val = lhs.get(&lhs_index).ok_or_else(|| {
+    CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
+        message: "Internal error: lhs index out of bounds in contraction".to_string(),
+    })
+})?;
+let rhs_val = rhs.get(&rhs_index).ok_or_else(|| {
+    CausalTensorError::EinSumError(EinSumValidationError::InvalidAxesSpecification {
+        message: "Internal error: rhs index out of bounds in contraction".to_string(),
+    })
+})?;
 sum = sum + *lhs_val * *rhs_val;
Suggestion importance[1-10]: 7

__

Why: The suggestion correctly identifies expect() calls that can cause a panic and proposes replacing them with proper error handling, which improves the robustness of this library function.

Medium
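A stripped-down sketch of the panic-free lookup pattern the suggestion proposes: `Option::ok_or_else` converts a missing index into a typed error that the caller can propagate with `?`, instead of aborting via `expect()`. `LookupError` here is a stand-in for the crate's `CausalTensorError`.

```rust
// Stand-in error type for the crate's CausalTensorError.
#[derive(Debug, PartialEq)]
struct LookupError(String);

// Checked element access: out-of-bounds becomes an Err, never a panic.
fn get_checked(data: &[f64], idx: usize) -> Result<f64, LookupError> {
    data.get(idx)
        .copied()
        .ok_or_else(|| LookupError(format!("index {idx} out of bounds")))
}

fn main() {
    let data = [1.0, 2.0, 3.0];
    assert_eq!(get_checked(&data, 1), Ok(2.0));
    assert!(get_checked(&data, 9).is_err()); // error value, no panic
}
```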
Propagate negation failures safely
Suggestion Impact: The commit replaced the four .unwrap() calls on mlx_rs::ops::negative(&is*) with .map_err(|e| e.to_string())? to propagate errors instead of panicking.

code diff:

     // Negate InvS for use in other blocks
-    let n_is0 = mlx_rs::ops::negative(&is0).unwrap();
-    let n_is1 = mlx_rs::ops::negative(&is1).unwrap();
-    let n_is2 = mlx_rs::ops::negative(&is2).unwrap();
-    let n_is3 = mlx_rs::ops::negative(&is3).unwrap();
+    let n_is0 = mlx_rs::ops::negative(&is0).map_err(|e| e.to_string())?;
+    let n_is1 = mlx_rs::ops::negative(&is1).map_err(|e| e.to_string())?;
+    let n_is2 = mlx_rs::ops::negative(&is2).map_err(|e| e.to_string())?;
+    let n_is3 = mlx_rs::ops::negative(&is3).map_err(|e| e.to_string())?;

In the explicit_inverse_4x4 function, replace .unwrap() calls with error
propagation using .map_err(|e| e.to_string())? to prevent panics and allow the
CPU fallback mechanism to work correctly.

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs [218-221]

-let n_is0 = mlx_rs::ops::negative(&is0).unwrap();
-let n_is1 = mlx_rs::ops::negative(&is1).unwrap();
-let n_is2 = mlx_rs::ops::negative(&is2).unwrap();
-let n_is3 = mlx_rs::ops::negative(&is3).unwrap();
+let n_is0 = mlx_rs::ops::negative(&is0).map_err(|e| e.to_string())?;
+let n_is1 = mlx_rs::ops::negative(&is1).map_err(|e| e.to_string())?;
+let n_is2 = mlx_rs::ops::negative(&is2).map_err(|e| e.to_string())?;
+let n_is3 = mlx_rs::ops::negative(&is3).map_err(|e| e.to_string())?;
Suggestion importance[1-10]: 7

__

Why: The suggestion correctly identifies unwrap() calls that would panic on failure, preventing the intended fallback to CPU inversion; replacing them with error propagation makes the function robust.

Medium
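The same principle in isolation, with a hypothetical `gpu_negate` standing in for `mlx_rs::ops::negative`: converting the backend error into a `String` with `map_err` lets `?` return it, so the caller can detect the failure and take the CPU fallback path instead of panicking mid-function.

```rust
// Stand-in for a fallible backend op such as mlx_rs::ops::negative.
fn gpu_negate(x: f64) -> Result<f64, &'static str> {
    if x.is_nan() { Err("gpu op failed") } else { Ok(-x) }
}

// The suggested pattern: convert the error and propagate with `?`.
fn try_gpu_negate(x: f64) -> Result<f64, String> {
    let y = gpu_negate(x).map_err(|e| e.to_string())?;
    Ok(y)
}

fn main() {
    // Success path stays on the "GPU".
    assert_eq!(try_gpu_negate(2.0), Ok(-2.0));
    // Failure path returns an error the caller can use to fall back to CPU.
    assert!(try_gpu_negate(f64::NAN).is_err());
}
```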
✅ Suggestions up to commit 8848b13
Category | Suggestion | Impact
High-level
Refactor the GPU backend to avoid inefficient CPU fallbacks
Suggestion Impact: The commit updates BackendTensor scalar assignment operators (AddAssign/SubAssign/MulAssign/DivAssign) to create a scalar tensor and use backend add/sub/mul/div, avoiding the previous CPU roundtrip via to_vec/create_from_vec and thus keeping MLX computations on-device. It also makes minor MLX backend improvements (error handling and bounds checks) but does not remove existing CPU fallbacks for reductions.

code diff:

# File: deep_causality_tensor/src/types/backend_tensor/ops.rs
@@ -214,34 +214,39 @@
 }
 
 // --- Scalar Assignment Arithmetic ---
+// Use backend broadcast operations to stay on device (GPU for MlxBackend)
 macro_rules! impl_scalar_assign_arithmetic {
     ($($t:ty),*) => {
         $(
             impl<B: TensorBackend> AddAssign<$t> for BackendTensor<$t, B> {
                 fn add_assign(&mut self, rhs: $t) {
-                    let data: Vec<$t> = B::to_vec(&self.inner).into_iter().map(|x| x + rhs).collect();
-                    self.inner = B::create_from_vec(data, &B::shape(&self.inner));
+                    // Create scalar tensor on same device, use broadcast add
+                    let scalar_tensor = B::create(&[rhs], &[1]);
+                    self.inner = B::add(&self.inner, &scalar_tensor);
                 }
             }
 
             impl<B: TensorBackend> SubAssign<$t> for BackendTensor<$t, B> {
                 fn sub_assign(&mut self, rhs: $t) {
-                    let data: Vec<$t> = B::to_vec(&self.inner).into_iter().map(|x| x - rhs).collect();
-                    self.inner = B::create_from_vec(data, &B::shape(&self.inner));
+                    // Create scalar tensor on same device, use broadcast sub
+                    let scalar_tensor = B::create(&[rhs], &[1]);
+                    self.inner = B::sub(&self.inner, &scalar_tensor);
                 }
             }
 
             impl<B: TensorBackend> MulAssign<$t> for BackendTensor<$t, B> {
                 fn mul_assign(&mut self, rhs: $t) {
-                    let data: Vec<$t> = B::to_vec(&self.inner).into_iter().map(|x| x * rhs).collect();
-                    self.inner = B::create_from_vec(data, &B::shape(&self.inner));
+                    // Create scalar tensor on same device, use broadcast mul
+                    let scalar_tensor = B::create(&[rhs], &[1]);
+                    self.inner = B::mul(&self.inner, &scalar_tensor);
                 }
             }
 
             impl<B: TensorBackend> DivAssign<$t> for BackendTensor<$t, B> {
                 fn div_assign(&mut self, rhs: $t) {
-                    let data: Vec<$t> = B::to_vec(&self.inner).into_iter().map(|x| x / rhs).collect();
-                    self.inner = B::create_from_vec(data, &B::shape(&self.inner));
+                    // Create scalar tensor on same device, use broadcast div
+                    let scalar_tensor = B::create(&[rhs], &[1]);
+                    self.inner = B::div(&self.inner, &scalar_tensor);
                 }
             }
         )*

The MLX backend implementation frequently falls back to the CPU for many tensor
operations, which is inefficient. This should be refactored to perform
computations on the GPU to avoid costly data transfers.

Examples:

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs [238-250]
deep_causality_tensor/src/types/backend_tensor/ops.rs [217-251]

Solution Walkthrough:

Before:

// In MlxBackend implementation
fn slice<T: Clone>(tensor: &MlxTensor<T>, ranges: &[Range<usize>]) -> MlxTensor<T> {
    // Fallback to CPU
    // 1. Transfer data from GPU to CPU
    let cpu_data: Vec<T> = MlxBackend::to_vec(tensor);
    let cpu_shape = MlxBackend::shape(tensor);
    
    // 2. Create a CPU tensor and perform the operation
    let cpu_tensor = CausalTensor::new(cpu_data, cpu_shape).unwrap();
    let result_cpu_tensor = CpuBackend::slice(&cpu_tensor, ranges);
    
    // 3. Transfer result from CPU back to GPU
    MlxBackend::create(result_cpu_tensor.as_slice(), result_cpu_tensor.shape())
}

// In BackendTensor scalar operations
impl<B: TensorBackend> AddAssign<f64> for BackendTensor<f64, B> {
    fn add_assign(&mut self, rhs: f64) {
        let data: Vec<f64> = B::to_vec(&self.inner).into_iter().map(|x| x + rhs).collect();
        self.inner = B::create_from_vec(data, &B::shape(&self.inner));
    }
}

After:

// In MlxBackend implementation
fn slice<T: Clone>(tensor: &MlxTensor<T>, ranges: &[Range<usize>]) -> MlxTensor<T> {
    // Perform the operation directly on the GPU device
    let starts: Vec<i32> = ranges.iter().map(|r| r.start as i32).collect();
    let stops: Vec<i32> = ranges.iter().map(|r| r.end as i32).collect();
    
    // Use the underlying MLX library's native slice operation
    let sliced_array = mlx_rs::ops::slice(tensor.as_array(), &starts, &stops)
        .expect("On-device slice failed");
        
    MlxTensor::new(sliced_array)
}

// In BackendTensor scalar operations
impl<B: TensorBackend> AddAssign<f64> for BackendTensor<f64, B> {
    fn add_assign(&mut self, rhs: f64) {
        // Create a scalar tensor on the same device
        let scalar_tensor = B::create(&[rhs], &[]);
        // Perform the operation on-device
        self.inner = B::add(&self.inner, &scalar_tensor);
    }
}
Suggestion importance[1-10]: 9

__

Why: The suggestion correctly identifies a critical performance flaw where many MlxBackend operations fall back to CPU, undermining the purpose of GPU acceleration by incurring expensive data transfer overhead.

High
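The broadcast semantics the refactor relies on can be sketched in plain Rust: adding a length-1 "scalar tensor" to every element of a larger tensor is exactly what `B::add(&self.inner, &scalar_tensor)` performs on-device, with no host roundtrip (`broadcast_add` below is an illustrative stand-in, not the crate's API).

```rust
// Illustrative stand-in for an on-device broadcast add: a length-1 rhs is
// applied elementwise to the whole lhs tensor.
fn broadcast_add(tensor: &[f32], scalar: &[f32]) -> Vec<f32> {
    assert_eq!(scalar.len(), 1, "rhs must be a scalar tensor");
    tensor.iter().map(|x| x + scalar[0]).collect()
}

fn main() {
    let t = [1.0f32, 2.0, 3.0];
    let s = [10.0f32]; // analogous to B::create(&[rhs], &[1])
    assert_eq!(broadcast_add(&t, &s), vec![11.0, 12.0, 13.0]);
}
```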
Possible issue
Replace unsafe transmute with s...

Signed-off-by: Marvin Hansen <marvin.hansen@gmail.com>
@codecov

codecov Bot commented Dec 31, 2025

Codecov Report

❌ Patch coverage is 89.75028% with 275 lines in your changes missing coverage. Please review.
✅ Project coverage is 86.32%. Comparing base (e09173f) to head (be92279).
⚠️ Report is 59 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...y_multivector/src/extensions/hkt_multifield/mod.rs | 74.09% | 50 Missing ⚠️ |
| ...ty_tensor/src/types/backend/cpu/cpu_tensor_impl.rs | 82.45% | 20 Missing ⚠️ |
| deep_causality_physics/src/mhd/grmhd.rs | 76.81% | 16 Missing ⚠️ |
| deep_causality_physics/src/nuclear/quantities.rs | 90.80% | 16 Missing ⚠️ |
| ...tensor/src/types/backend/cpu/cpu_backend_tensor.rs | 87.93% | 14 Missing ⚠️ |
| deep_causality_metric/src/types/metric/mod.rs | 91.33% | 13 Missing ⚠️ |
| ...ltivector/src/types/multifield/ops/differential.rs | 90.69% | 12 Missing ⚠️ |
| deep_causality_physics/src/nuclear/pdg.rs | 75.00% | 12 Missing ⚠️ |
| ...p_causality_tensor/src/types/backend_tensor/ops.rs | 93.81% | 12 Missing ⚠️ |
| ...ep_causality_num/src/complex/complex_number/mod.rs | 0.00% | 10 Missing ⚠️ |
| ... and 23 more | | |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #433      +/-   ##
==========================================
- Coverage   94.79%   86.32%   -8.48%     
==========================================
  Files         723      777      +54     
  Lines       28710    31698    +2988     
==========================================
+ Hits        27217    27362     +145     
- Misses       1493     4336    +2843     

☔ View full report in Codecov by Sentry.

@marvin-hansen marvin-hansen merged commit 18fda5a into deepcausality-rs:main Dec 31, 2025
10 of 13 checks passed
@qodo-code-review
Contributor

Persistent suggestions updated to latest commit be92279



Development

Successfully merging this pull request may close these issues.

- feat(deep_causality_topology): Add MLX acceleration to topology crate
- Add MLX acceleration to Tensor Crate