The idea of this PR is to collect all the changes necessary to make the expectation functions traceable. We could then jit part of the `__call__` method of the expectation function, which would improve memory usage, since JAX also optimizes the computation under jit, and it would allow us to checkpoint the computation of the expectation value, which is very memory intensive because of the RDM. I know there is also `checkpoint_ncon`, but this approach is more flexible.
The problematic part, which is included in all expectation functions, was this:
This construct performs a Python-level loop and uses JAX arrays in a dynamic control-flow context, which prevents JAX from tracing or staging out the function properly.
Despite multiple workarounds, none of them integrated cleanly with JAX's tracing model or yielded good memory behavior.
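The original snippet is not reproduced here, but a minimal sketch of the same failure mode (a hypothetical function, not from the varipeps code) looks like this: a Python-level `if` on a traced value cannot be staged out by jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def expectation_like(x):
    # Under jit, `jnp.all(x.imag == 0)` is an abstract tracer; converting
    # it to a Python bool for the `if` raises a concretization error.
    if jnp.all(x.imag == 0):
        return x.real
    return x


try:
    expectation_like(jnp.array(1.0 + 0.0j))
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print(failed)  # → True
```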
My proposal would be to remove all of these checks and let the user handle the dtype in the model they define. Since the result array is also small (on the order of the number of gates), there should be no memory problem.
An example from my code:
```python
@staticmethod
@partial(jax.jit, static_argnames=("only_unique", "is_real"))
def _accumulate_over_unitcell(
    peps_tensors,
    unitcell,
    working_up_gates,
    working_down_gates,
    *,
    only_unique: bool,
    is_real: bool,
):
    # initialize accumulator as a tuple to keep it purely functional
    zero = jnp.array(0.0) if is_real else jnp.array(0.0 + 0.0j)
    result = tuple(zero for _ in range(len(working_up_gates)))

    for x, iter_rows in unitcell.iter_all_rows(only_unique=only_unique):
        for y, view in iter_rows:
            # Get all 4 tensors in the 2x2 view
            tensors_i = view.get_indices((slice(0, 2, None), slice(0, 2, None)))
            tensors = [peps_tensors[i] for j in tensors_i for i in j]
            tensor_objs = [t for tl in view[:2, :2] for t in tl]

            # remat/checkpoint for memory; gates treated static via static_argnums
            step_result_down = jax.checkpoint(
                calc_three_sites_triangle_without_bottom_left_multiple_gates
            )(tensors, tensor_objs, working_down_gates)
            step_result_up = jax.checkpoint(
                calc_three_sites_triangle_without_top_right_multiple_gates
            )(tensors, tensor_objs, working_up_gates)

            # functional accumulation; ensure real dtype if requested
            incr = tuple(
                ((sd.real if is_real else sd) + (su.real if is_real else su))
                for sd, su in zip(step_result_down, step_result_up, strict=True)
            )
            result = tuple(r + a for r, a in zip(result, incr))

    return result

def __call__(
    self,
    peps_tensors: Sequence[jnp.ndarray],
    unitcell: PEPS_Unit_Cell,
    spiral_vectors: Optional[Union[jnp.ndarray, Sequence[jnp.ndarray]]] = None,
    *,
    normalize_by_size: bool = True,
    only_unique: bool = True,
    return_single_gate_results: bool = False,
) -> Union[jnp.ndarray, List[jnp.ndarray]]:
    if self.is_spiral_peps:
        if spiral_vectors is None:
            raise ValueError(
                "When using spiral iPEPS, spiral_vectors must be provided."
            )
        # if not isinstance(spiral_vectors, collections.abc.Sequence):
        #     spiral_vectors = (spiral_vectors,) * 3

        # [top-left, top-right, bottom-right]
        working_down_gates = tuple(
            apply_unitary(
                h,
                tuple(jnp.array(ri) for ri in ((0, 0), (0, 1), (1, 1))),
                spiral_vectors,
                self._spiral_D,
                self._spiral_sigma,
                self.real_d,
                3,
                (0, 1, 2),
                varipeps_config.spiral_wavevector_type,
            )
            for h in self.up_gates
        )

        # [top-left, bottom-left, bottom-right]
        working_up_gates = tuple(
            apply_unitary(
                h,
                tuple(jnp.array(ri) for ri in ((0, 0), (1, 0), (1, 1))),
                spiral_vectors,
                self._spiral_D,
                self._spiral_sigma,
                self.real_d,
                3,
                (0, 1, 2),
                varipeps_config.spiral_wavevector_type,
            )
            for h in self.down_gates
        )
    else:
        working_up_gates = self.up_gates
        working_down_gates = self.down_gates

    # Use jitted static method to perform the accumulation over the unit cell
    result = self._accumulate_over_unitcell(
        peps_tensors,
        unitcell,
        working_up_gates,
        working_down_gates,
        only_unique=only_unique,
        is_real=(self._result_type == jnp.float64),
    )

    if normalize_by_size:
        size = (
            unitcell.get_len_unique_tensors()
            if only_unique
            else (unitcell.get_size()[0] * unitcell.get_size()[1])
        )
        size = size * self.normalization_factor
        result = [r / size for r in result]

    if len(result) == 1:
        return result[0]
    else:
        return result
```
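As a standalone illustration of the `jax.checkpoint` (a.k.a. `jax.remat`) pattern used in the accumulation above: it recomputes intermediates during the backward pass instead of storing them, trading compute for memory. The toy function below is hypothetical, just a stand-in for an expensive contraction.

```python
import jax
import jax.numpy as jnp


def heavy_step(x):
    # stand-in for an expensive contraction like the RDM computation
    for _ in range(4):
        x = jnp.sin(x) * 2.0
    return jnp.sum(x)


x = jnp.arange(8.0)

# Gradients agree; the checkpointed version does not store the
# intermediate activations of heavy_step during the forward pass.
g_plain = jax.grad(heavy_step)(x)
g_remat = jax.grad(jax.checkpoint(heavy_step))(x)
print(bool(jnp.allclose(g_plain, g_remat)))  # → True
```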
AI SUMMARY
This pull request simplifies the three_sites.py module by removing the real_result logic from the _three_site_triangle_workhorse function and its associated callers. This streamlines the computation of expectation values for three-site triangles, ensuring consistent output regardless of whether the gates are Hermitian.
Refactoring and code simplification:
Removed the real_result argument from the _three_site_triangle_workhorse function signature and its usage in all calling functions, eliminating conditional logic based on gate Hermiticity. [1][2]
Deleted the calculation of the real_result variable (which checked for Hermitian gates) from all relevant calc_three_sites_triangle_without_*_multiple_gates functions. [1][2][3][4]
Updated all calls to _three_site_triangle_workhorse to remove the real_result parameter, further simplifying the function interfaces. [1][2][3][4]
MP changed the title from "Make the expectation functions tracable" to "Make the expectation functions traceable" on Oct 20, 2025.
Yeah, you're right that this is the problem that prevents us from tracing the expectation functions, but why not just use jax.lax.cond? I can take a look into it, but it should be manageable to use that to make this traceable.
I also tried that, but true_func and false_func need to return the same dtype, which is obviously not the case here.
The true_computation must take in a single argument of type and will be invoked with true_operand which must be of the same type. The false_computation must take in a single argument of type and will be invoked with false_operand which must be of the same type. The type of the returned value of true_computation and false_computation must be the same.
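This constraint is easy to demonstrate: `jax.lax.cond` requires both branches to return identical shapes and dtypes, so a real-valued branch paired with a complex-valued branch (as the Hermitian/non-Hermitian split would produce) is rejected. A minimal sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.array(1.0 + 2.0j)

# Matching output dtypes (both real scalars): fine.
out = jax.lax.cond(True, lambda v: v.real, lambda v: jnp.abs(v), x)

# Mismatched output dtypes (real vs. complex): rejected by lax.cond,
# since both branches are traced and their output types must agree.
try:
    jax.lax.cond(True, lambda v: v.real, lambda v: v, x)
    mismatch_ok = True
except TypeError:
    mismatch_ok = False

print(float(out), mismatch_ok)  # → 1.0 False
```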