Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 2 additions & 63 deletions ext/TensorKitMooncakeExt/factorizations.jl
Comment thread
kshyatt marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,63 +1,2 @@
# Generate reverse-mode Mooncake rules for the non-truncating SVD variants. The rule body
# is the same for every function in the list, so it is stamped out with `@eval`.
for svdfun in (:svd_compact, :svd_full)
    pullback_name = Symbol(svdfun, :_pullback)
    @eval begin
        @is_primitive DefaultCtx ReverseMode Tuple{typeof($svdfun), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(
                ::CoDual{typeof($svdfun)},
                A_dA::CoDual{<:AbstractTensorMap},
                alg_dalg::CoDual
            )
            A, dA = arrayify(A_dA)
            alg = primal(alg_dalg)

            # Forward pass: run the factorization and wrap the result in a CoDual whose
            # tangent is produced by `zero_fcodual`.
            factors = $svdfun(A, alg)
            factors_codual = Mooncake.zero_fcodual(factors)
            # `arrayify` is applied per factor; only the last entry of each result (the
            # tangent part, as in `A, dA = arrayify(A_dA)` above) is kept.
            factor_pairs = arrayify.(factors, tangent(factors_codual))
            dfactors = last.(factor_pairs)

            function $pullback_name(::NoRData)
                # Push the output cotangents back into `dA`, then reset them.
                MatrixAlgebraKit.svd_pullback!(dA, A, factors, dfactors)
                MatrixAlgebraKit.zero!.(dfactors)
                # One rdata entry per argument: the function, `A` and the algorithm.
                return (NoRData(), NoRData(), NoRData())
            end

            return factors_codual, $pullback_name
        end
    end

    # The mutating version is not guaranteed to actually mutate its arguments, so the rule
    # simply forwards to the non-mutating version. This avoids having to store copies and
    # restore state.
    svdfun! = Symbol(svdfun, :!)
    @eval begin
        @is_primitive DefaultCtx ReverseMode Tuple{typeof($svdfun!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(
                ::CoDual{typeof($svdfun!)},
                A_dA::CoDual{<:AbstractTensorMap},
                USVᴴ_dUSVᴴ::CoDual,
                alg_dalg::CoDual
            )
            # The provided output `USVᴴ_dUSVᴴ` is intentionally ignored.
            return Mooncake.rrule!!(Mooncake.zero_fcodual($svdfun), A_dA, alg_dalg)
        end
    end
end

# Reverse-mode rule for the truncated SVD. The output is the tuple of truncated factors
# with the truncation error `ϵ` appended as a fourth element.
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc)},
        A_dA::CoDual{<:AbstractTensorMap},
        alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
    )
    A, dA = arrayify(A_dA)
    truncalg = primal(alg_dalg)

    # Compute the untruncated compact SVD with the wrapped algorithm, then truncate it.
    # The untruncated factors and the index selection `ind` are kept for the pullback.
    full_factors = svd_compact(A, truncalg.alg)
    kept_factors, ind = MatrixAlgebraKit.truncate(svd_trunc!, full_factors, truncalg.trunc)
    ϵ = MatrixAlgebraKit.truncation_error(diagview(full_factors[2]), ind)

    output_codual = Mooncake.zero_fcodual((kept_factors..., ϵ))
    # Drop the tangent slot belonging to `ϵ`; only the factors carry array tangents.
    factor_tangents = Base.front(tangent(output_codual))
    dkept_factors = last.(arrayify.(kept_factors, factor_tangents))

    function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
        # The truncation error is not differentiated; warn when its cotangent is non-zero.
        if !(abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ))
            @warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
        end
        MatrixAlgebraKit.svd_pullback!(dA, A, full_factors, dkept_factors, ind)
        # One rdata entry per argument: the function, `A` and the algorithm.
        return (NoRData(), NoRData(), NoRData())
    end

    return output_codual, svd_trunc_pullback
end

# Rule for the mutating `svd_trunc!`: forwards to the rule of the non-mutating `svd_trunc`,
# ignoring the provided output argument `USVᴴ_dUSVᴴ`.
# NOTE(review): the forwarded pullback returns three rdata entries while this method takes
# four arguments (function, A, output, algorithm) — confirm this is acceptable to Mooncake.
@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(
        ::CoDual{typeof(svd_trunc!)},
        A_dA::CoDual{<:AbstractTensorMap},
        USVᴴ_dUSVᴴ::CoDual,
        alg_dalg::CoDual
    )
    nonmutating = Mooncake.zero_fcodual(svd_trunc)
    return Mooncake.rrule!!(nonmutating, A_dA, alg_dalg)
end
# Declare `MatrixAlgebraKit.initialize_output` (called with any function, a tensor map and
# an algorithm) as having zero derivative.
# This is needed for the Ising bimodule case.
@zero_derivative DefaultCtx Tuple{typeof(MatrixAlgebraKit.initialize_output), Any, AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
17 changes: 17 additions & 0 deletions ext/TensorKitMooncakeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,20 @@ function trace_permute_pullback_ΔA!(
)
return NoRData()
end

@is_primitive(
    DefaultCtx,
    Tuple{
        typeof(TensorKit.scalar),
        AbstractTensorMap,
    }
)
# Rule for `TensorKit.scalar(t)`.
#
# Forward pass: evaluates `scalar(t)` and returns it wrapped by `Mooncake.zero_fcodual`.
# Pullback: adds the output cotangent `Δval` to the first entry of the first block of the
# tangent `dt`, and returns `NoRData()` for both the function and the tensor argument.
function Mooncake.rrule!!(::CoDual{typeof(TensorKit.scalar)}, t_dt::CoDual{<:AbstractTensorMap})
    t, dt = arrayify(t_dt)
    val = scalar(t)
    function scalar_pullback(Δval)
        # `blocks(dt)` yields `sector => block` pairs; take the block of the first pair.
        dblock = first(blocks(dt))[2]
        # Accumulate instead of overwriting: `dt` may already hold gradient contributions
        # from other uses of `t`, and those must not be discarded.
        dblock[1] += Δval
        return NoRData(), NoRData()
    end
    return Mooncake.zero_fcodual(val), scalar_pullback
end
6 changes: 6 additions & 0 deletions ext/TensorKitMooncakeExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent
# Sector/degeneracy structure queries: declared as having zero derivative.
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any}

# TensorOperations structure queries on tensor maps: declared as having zero derivative.
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap, Int, Bool}

# Structure of the result of a tensor contraction: declared as having zero derivative.
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple}

# Index-manipulation helpers acting on tensor maps and `HomSpace`s: declared as having
# zero derivative.
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.has_shared_permute), AbstractTensorMap, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple}
Expand Down
31 changes: 20 additions & 11 deletions test/mooncake/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!
using Mooncake
using Random

# Test helper: apply `f!` to `A` with algorithm `alg`, then zero out `A` and return the
# result of `f!`.
# NOTE(review): presumably `A` is zeroed so the test does not depend on whatever the
# in-place routine leaves behind in its input — confirm.
function call_and_zero!(f!, A, alg)
    result = f!(A, alg)
    MatrixAlgebraKit.zero!(A)
    return result
end

mode = Mooncake.ReverseMode
rng = Random.default_rng()
Expand All @@ -18,7 +23,6 @@ eltypes = (Float64, ComplexF64)
@timedtestset "Mooncake - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)

@timedtestset "QR" begin
A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])

Expand All @@ -29,8 +33,7 @@ eltypes = (Float64, ComplexF64)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)

A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])')

Expand All @@ -41,34 +44,31 @@ eltypes = (Float64, ComplexF64)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
end

@timedtestset "LQ" begin
A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])

Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)

# qr_full/qr_null requires being careful with gauges
# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)

A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')

Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)

# qr_full/qr_null requires being careful with gauges
# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
#Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
end

@timedtestset "Eigenvalue decomposition" begin
Expand Down Expand Up @@ -105,6 +105,15 @@ eltypes = (Float64, ComplexF64)
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)

V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
USVᴴ = svd_compact(t)
alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
USVᴴtrunc = svd_trunc(t, alg)
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
Mooncake.TestUtils.test_rule(rng, call_and_zero!, svd_trunc!, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode, is_primitive = false)
end
end
end
Loading