Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions debug_gates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove all debug files and keep all necessary figure in outputs/

import tensorcircuit as tc
import logging

class GateManager:
def __init__(self, layers, rows, cols):
self.layers = layers
self.rows = rows
self.cols = cols

def get_bond_gate(self, bond_idx, layer_idx):
if bond_idx < 0 or bond_idx >= self.cols - 1:
return None
layer = self.layers[layer_idx]
for g in layer:
cols = [x // self.rows for x in g["index"]]
if min(cols) == bond_idx and max(cols) == bond_idx + 1:
return g
return None

def generate_sycamore_circuit(rows, cols, depth):
layers = []
l0 = []
for i in range(rows * cols):
l0.append({"gatef": tc.gates.h, "index": [i], "parameters": {}})
layers.append(l0)

import random
random.seed(42)
for _ in range(depth):
l = []
for i in range(rows * cols):
l.append({"gatef": tc.gates.rz, "index": [i], "parameters": {"theta": random.random()}})
for r in range(rows):
for c_ in range(cols):
idx = c_ * rows + r
if c_ < cols - 1:
idx_next = (c_ + 1) * rows + r
l.append({"gatef": tc.gates.cz, "index": [idx, idx_next], "parameters": {}})
if r < rows - 1:
idx_next = c_ * rows + (r + 1)
l.append({"gatef": tc.gates.cz, "index": [idx, idx_next], "parameters": {}})
layers.append(l)
return layers

def main():
rows = 4
cols = 4
depth = 6
layers = generate_sycamore_circuit(rows, cols, depth)
gm = GateManager(layers, rows, cols)

# Check Layer 0, Bond 0
g = gm.get_bond_gate(0, 0)
print(f"Layer 0, Bond 0: {g}")

# Check Layer 1, Bond 0
g = gm.get_bond_gate(0, 1)
print(f"Layer 1, Bond 0: {g}")

if __name__ == "__main__":
main()
109 changes: 109 additions & 0 deletions debug_tensor_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@

import jax.numpy as jnp
import numpy as np

def main():
# Simulate update_R for i=3 (Boundary)
rows = 4
K = 2

# R initialized to 1s.
# Shape: (b_old, b_new) + (g0_in, g0_out) + (g1_in, g1_out)
shape_R = (1, 1, 1, 1, 1, 1)
T = jnp.ones(shape_R)

# update_R logic simulation
# Contract with A (dummy) -> adds phys
# T becomes (b, b, p_new, p_old, g...)
# p dims are 2.
shape_T = (1, 1) + (2,)*4 + (2,)*4 + (1, 1, 1, 1)
T = jnp.ones(shape_T)

print(f"T init shape: {T.shape}")

g_start = 10 # 2 + 8

# Loop k=1. gate_L exists.
k = 1
idx_ri = g_start + 2*k # 12
idx_ro = g_start + 2*k + 1 # 13

r = 0 # gate_L index 1 % rows
phys_old_start = 6

# Gate L logic
# T = tensordot(T, U, axes=[[phys_old+r], [3]])
# U is (2, 2, 2, 2). axes=[[6], [3]].
# T rank 14. U rank 4. Result 16.
# Removes T[6], U[3].
# T indices: 0..5, 7..13, U[0], U[1], U[2].
# idx_ri (12) becomes 11. idx_ro (13) becomes 12.
# U[0, 1, 2] are appended (indices 13, 14, 15).
# New phys is U[2] (15).
# New legs are U[0, 1] (13, 14).

# Simulating shape
new_shape = list(T.shape)
new_shape.pop(6) # phys
new_shape.extend([2, 2, 2]) # lo, ro, li
# Note: U indices lo, ro, li.
# logic: lo, ro, li.

print(f"After dot shape: {new_shape}")

# moveaxis(T, -2, phys_old_start+r)
# Move -2 (ro?) to 6.
# Wait, my code logic:
# T = tensordot(T, U, axes=[[phys_old_start+r], [3]])
# T: (..., lo, ro, li).
# moveaxis -2 (ro) to phys.
# ro is new phys?
# U(lo, ro, li, ri).
# Contract ri (3).
# Remaining: lo(0), ro(1), li(2).
# ro(1) is new phys?
# Logic in code: "ro is new phys_old".
# So yes.
# move -2 (ro) to 6.

# New legs: lo(0), li(2).
# lo is at -3. li is at -1.
# lo matches idx_ro. li matches idx_ri.
# squeeze idx_ri, idx_ro.
# move -1 to idx_ri. move -2 to idx_ro.

# Let's perform using JAX
T = jnp.ones(shape_T)
U = jnp.ones((2, 2, 2, 2))

T = jnp.tensordot(T, U, axes=[[6], [3]])
# T shape: ..., lo, ro, li
print(f"After dot: {T.shape}")

T = jnp.moveaxis(T, -2, 6)
# T shape: ..., lo, li. (ro moved to 6).
print(f"After move phys: {T.shape}")

# squeeze 12, 13
T = jnp.squeeze(T, axis=(12, 13))
print(f"After squeeze: {T.shape}")

# move -1 (li) to 12. -2 (lo) to 13.
T = jnp.moveaxis(T, [-1, -2], [12, 13])
print(f"After move legs: {T.shape}")

# Check dims at 12, 13
print(f"Dims at 12, 13: {T.shape[12]}, {T.shape[13]}")

# k=0. pass.

# Final trace
for r in range(4):
# trace 2, 6
T = jnp.trace(T, axis1=2, axis2=6)
print(f"Trace {r}: {T.shape}")

print(f"Final shape: {T.shape}")

if __name__ == "__main__":
main()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading