Skip to content

Commit 3a2a365

Browse files
committed
feat(deep_causality_tensor): Applied multiple lints, fixes, and code improvements.
Signed-off-by: Marvin Hansen <marvin.hansen@gmail.com>
1 parent 583f779 commit 3a2a365

4 files changed

Lines changed: 34 additions & 16 deletions

File tree

deep_causality_multivector/tests/types/multifield/algebra/mod.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,13 @@ fn test_inverse_product_is_identity() {
361361
let metric = Metric::from_signature(2, 0, 0);
362362
let num_blades = 4;
363363

364-
// Create field with scalar = 3
364+
// Create field with a proper versor (unit vector e1)
365+
// In Euclidean space, e1 is its own inverse: e1 * e1 = 1
365366
let mut mvs = Vec::with_capacity(8);
366367
for _ in 0..8 {
367368
let mut data = vec![0.0f32; num_blades];
368-
data[0] = 3.0;
369+
// Use unit vector e1 which is a versor (invertible)
370+
data[1] = 1.0; // e1 component
369371
mvs.push(CausalMultiVector::unchecked(data, metric));
370372
}
371373

@@ -381,9 +383,19 @@ fn test_inverse_product_is_identity() {
381383
let scalar = mv.data()[0];
382384
assert!(
383385
(scalar - 1.0).abs() < 1e-3,
384-
"A * A^-1 should be 1, got {}",
386+
"A * A^-1 scalar part should be 1, got {}",
385387
scalar
386388
);
389+
390+
// Check that other components are near zero
391+
for (i, &val) in mv.data().iter().enumerate().skip(1) {
392+
assert!(
393+
val.abs() < 1e-3,
394+
"Non-scalar part of identity should be zero at index {}, got {}",
395+
i,
396+
val
397+
);
398+
}
387399
}
388400
}
389401

deep_causality_tensor/src/types/backend/mlx/mlx_backend_linear_algebra.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,11 @@ fn explicit_inverse_4x4(input: &mlx_rs::Array) -> Result<mlx_rs::Array, String>
214214
// 5. InvS = S^-1
215215
let (is0, is1, is2, is3) = inv_2x2(&s0, &s1, &s2, &s3);
216216

217-
// F22 = InvS
218-
let zero_is = sub_(&is0, &is0);
219-
// Use & refs for sub_blocks args (owned arrays)
220-
let (n_is0, n_is1, n_is2, n_is3) = sub_blocks(
221-
&zero_is, &zero_is, &zero_is, &zero_is, &is0, &is1, &is2, &is3,
222-
);
217+
// Negate InvS for use in other blocks
218+
let n_is0 = mlx_rs::ops::negative(&is0).unwrap();
219+
let n_is1 = mlx_rs::ops::negative(&is1).unwrap();
220+
let n_is2 = mlx_rs::ops::negative(&is2).unwrap();
221+
let n_is3 = mlx_rs::ops::negative(&is3).unwrap();
223222
// F21 = -InvS * T1.
224223
let (f21_0, f21_1, f21_2, f21_3) =
225224
mul_2x2(&n_is0, &n_is1, &n_is2, &n_is3, &t1_0, &t1_1, &t1_2, &t1_3);

deep_causality_tensor/src/types/backend/mlx/mlx_backend_tensor.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,12 @@ impl TensorBackend for MlxBackend {
125125
}
126126

127127
fn strides<T>(tensor: &Self::Tensor<T>) -> Vec<usize> {
128-
tensor.as_array().strides().to_vec()
128+
tensor
129+
.as_array()
130+
.strides()
131+
.iter()
132+
.map(|&s| s as usize)
133+
.collect()
129134
}
130135

131136
fn get<T: Clone>(tensor: &Self::Tensor<T>, index: &[usize]) -> Option<T> {

deep_causality_tensor/src/types/cpu_tensor/ops/tensor_ein_sum/ein_sum_impl.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,14 @@ where
282282
}
283283
}
284284

285-
// Get values and accumulate
286-
if let (Some(lhs_val), Some(rhs_val)) =
287-
(lhs.get(&lhs_index), rhs.get(&rhs_index))
288-
{
289-
sum = sum + *lhs_val * *rhs_val;
290-
}
285+
// Get values and accumulate - use expect to fail fast on invalid indices
286+
let lhs_val = lhs
287+
.get(&lhs_index)
288+
.expect("Internal error: lhs index out of bounds in contraction");
289+
let rhs_val = rhs
290+
.get(&rhs_index)
291+
.expect("Internal error: rhs index out of bounds in contraction");
292+
sum = sum + *lhs_val * *rhs_val;
291293
}
292294

293295
// Store result

0 commit comments

Comments
 (0)