Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,40 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tempo = "c33777b2-e695-4fae-9135-aeae8855dd81"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Ephemerides = "6a9c3322-c8fe-4c26-8ad6-14a6f8acd2a0"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
ChainRulesCoreExt = ["ChainRulesCore"]
EphemeridesExt = ["Ephemerides"]
MooncakeExt = ["Mooncake", "ChainRulesCore"]

[compat]
ChainRulesCore = "1"
Ephemerides = "1"
ForwardDiff = "0.10, 1"
FunctionWrappers = "1"
FunctionWrappersWrappers = "0.1"
IERSConventions = "1"
JSMDInterfaces = "1.6"
JSMDUtils = "1"
Mooncake = "0.4"
PrecompileTools = "1"
Zygote = "0.6, 0.7"
ReferenceFrameRotations = "3"
SMDGraphs = "0.2"
StaticArrays = "1"
Tempo = "1.2"
julia = "1.9"
julia = "1.10, 1.11"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets"]
test = ["Test", "SafeTestsets", "ChainRulesCore", "FiniteDiff", "Mooncake", "Zygote"]
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ extensible axes/point graph models for mission analysis and space mission design

- Convert between different time scales and representations (via [Tempo.jl](https://github.com/JuliaSpaceMissionDesign/Tempo.jl));
- Read binary ephemeris files (via [Ephemerides.jl](https://github.com/JuliaSpaceMissionDesign/Ephemerides.jl) or [CalcephEphemeris.jl](https://github.com/JuliaSpaceMissionDesign/CalcephEphemeris.jl))
- Create custom reference frame systems with both standard and user-defined points and axes.
- Transform states and their higher-order derivatives between different frames (up to jerk)
- Create custom reference frame systems with both standard and user-defined points, axes and directions.
- Transform states and their higher-order derivatives between different frames (up to jerk).
- Compile transformations into zero-overhead, AD-transparent callables for hot loops via `compile_rotation`, `compile_translation` and `compile_direction`.

All of this seamlessly integrated with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
Automatic differentiation is supported through [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), with reverse-mode AD support via [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) and [Mooncake.jl](https://github.com/compintell/Mooncake.jl) package extensions.

## Installation

Expand Down
23 changes: 23 additions & 0 deletions docs/src/API/frames_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,26 @@ direction6
direction9
direction12
```

## [Compiled Fast-Path](@id compiled_api)

The compiled fast-path provides zero-overhead, AD-transparent callables that bypass the
`FunctionWrapper` type-erasure barrier used internally by `FrameSystem`. Use these when you
need full inlining, custom AD-backend support (e.g., Mooncake, Zygote), or maximum
performance in hot loops such as ODE right-hand sides.

### Types

```@docs
CompiledRotation
CompiledTranslation
CompiledDirection
```

### Constructors

```@docs
compile_rotation
compile_translation
compile_direction
```
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ extensible axes/point graph models for mission analysis and space mission design
- Read binary ephemeris files (via [Ephemerides.jl](https://github.com/JuliaSpaceMissionDesign/Ephemerides.jl) or [CalcephEphemeris.jl](https://github.com/JuliaSpaceMissionDesign/CalcephEphemeris.jl) extensions).
- Create custom reference frame systems with both standard and user-defined points, axes and directions.
- Transform states and their higher-order derivatives between different frames (up to jerk).
- Compile transformations into zero-overhead, AD-transparent callables for hot loops via [`compile_rotation`](@ref), [`compile_translation`](@ref) and [`compile_direction`](@ref).

All of this integrated with [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl).
Automatic differentiation is supported through [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl), with reverse-mode AD support via [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) and [Mooncake.jl](https://github.com/compintell/Mooncake.jl) package extensions.

## Installation

Expand Down
153 changes: 153 additions & 0 deletions ext/ChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
module ChainRulesCoreExt

using FrameTransformations
using ChainRulesCore
using StaticArrays: SVector, SMatrix

using FrameTransformations: Rotation
using ReferenceFrameRotations: DCM
using JSMDUtils.Autodiff: derivative

# ==========================================================================================
# Helper: extract numeric values from cotangents
# ==========================================================================================

# Cotangents may arrive as plain arrays (Zygote) or as ChainRulesCore.Tangent structs
# (Mooncake via @from_chainrules). We need to handle both transparently.

# For vectors: convert cotangent Δ to a concrete SVector for dot product computation.
_to_svec(x::SVector{N}) where {N} = x
_to_svec(x::AbstractVector) = SVector(x...)
function _to_svec(x::ChainRulesCore.Tangent{<:SVector{N}}) where {N}
return SVector(x.data...)
end
function _to_svec(x::ChainRulesCore.Tangent{<:Any, <:NamedTuple{(:data,)}})
d = x.data
if d isa ChainRulesCore.Tangent
return SVector(d.backing...)
elseif d isa Tuple
return SVector(d...)
else
return SVector(d...)
end
end
_to_svec(x) = SVector(collect(x)...)

# Scalar dot product between derivative J and cotangent Δ
function _vec_pullback_dt(J, Δ)
Δv = _to_svec(unthunk(Δ))
return sum(J .* Δv)
end

# ==========================================================================================
# Helper: Rotation pullback — Frobenius inner product between J and Δ
# ==========================================================================================

# derivative() on rotation3/6/9/12 returns a Vector{DCM{T}} (S elements for order S).
# The cotangent Δ may be a Rotation, a Tangent{Rotation}, or similar.
# We compute ∂L/∂t = Σᵢ ⟨Jᵢ, Δᵢ⟩_F (Frobenius inner product per DCM)

# Extract the flat NTuple{9} data from a DCM or its tangent representation
_dcm_flat(x::DCM) = x.data # NTuple{9,T}
_dcm_flat(x::SMatrix{3,3}) = Tuple(x)
_dcm_flat(x::AbstractMatrix) = Tuple(x)
# Mooncake tangent for DCM: Tangent{DCM, @NamedTuple{data::NTuple{9,T}}}
function _dcm_flat(x::ChainRulesCore.Tangent)
d = x.data
if d isa Tuple
return d
elseif d isa ChainRulesCore.Tangent
return Tuple(d.backing)
else
return Tuple(d)
end
end

# Frobenius inner product between two flat tuples
_frob_dot(a::NTuple{N}, b::NTuple{N}) where {N} = sum(a .* b)
_frob_dot(a::NTuple{N}, b) where {N} = sum(a .* Tuple(b))
_frob_dot(a, b) = sum(Tuple(a) .* Tuple(b))

# Extract the tuple of DCMs from the rotation cotangent
_get_rot_dcm_tuple(R::Rotation) = R.m
function _get_rot_dcm_tuple(Δ::ChainRulesCore.Tangent)
m = Δ.m
if m isa Tuple
return m
elseif m isa ChainRulesCore.Tangent
return Tuple(m.backing)
else
return Tuple(m)
end
end
_get_rot_dcm_tuple(Δ) = Δ.m

function _rotation_pullback_dt(J::AbstractVector, Δ)
Δu = unthunk(Δ)
Δm = _get_rot_dcm_tuple(Δu)
dt = 0.0
for i in eachindex(J)
dt += _frob_dot(_dcm_flat(J[i]), _dcm_flat(unthunk(Δm[i])))
end
return dt
end

# ==========================================================================================
# rrules for vector functions: vector3, vector6, vector9, vector12
# ==========================================================================================

for (vfun, N) in (
(:vector3, 3), (:vector6, 6), (:vector9, 9), (:vector12, 12)
)
@eval begin
function ChainRulesCore.rrule(::typeof($vfun), fr::FrameSystem, from, to, ax, t::Number)
val = $vfun(fr, from, to, ax, t)
function $(Symbol(vfun, :_pullback))(Δ)
J = derivative(τ -> $vfun(fr, from, to, ax, τ), t)
dt = _vec_pullback_dt(J, Δ)
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dt
end
return val, $(Symbol(vfun, :_pullback))
end
end
end

# ==========================================================================================
# rrules for rotation functions: rotation3, rotation6, rotation9, rotation12
# ==========================================================================================

for rfun in (:rotation3, :rotation6, :rotation9, :rotation12)
@eval begin
function ChainRulesCore.rrule(::typeof($rfun), fr::FrameSystem, from, to, t::Number)
val = $rfun(fr, from, to, t)
function $(Symbol(rfun, :_pullback))(Δ)
J = derivative(τ -> $rfun(fr, from, to, τ), t)
dt = _rotation_pullback_dt(J, Δ)
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), dt
end
return val, $(Symbol(rfun, :_pullback))
end
end
end

# ==========================================================================================
# rrules for direction functions: direction3, direction6, direction9, direction12
# ==========================================================================================

for (dfun, N) in (
(:direction3, 3), (:direction6, 6), (:direction9, 9), (:direction12, 12)
)
@eval begin
function ChainRulesCore.rrule(::typeof($dfun), fr::FrameSystem, name::Symbol, ax, t::Number)
val = $dfun(fr, name, ax, t)
function $(Symbol(dfun, :_pullback))(Δ)
J = derivative(τ -> $dfun(fr, name, ax, τ), t)
dt = _vec_pullback_dt(J, Δ)
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), dt
end
return val, $(Symbol(dfun, :_pullback))
end
end
end

end # module
50 changes: 50 additions & 0 deletions ext/MooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module MooncakeExt

using FrameTransformations: FrameSystem,
vector3, vector6, vector9, vector12,
rotation3, rotation6, rotation9, rotation12,
direction3, direction6, direction9, direction12

using FunctionWrappers: FunctionWrapper
using FunctionWrappersWrappers: FunctionWrappersWrapper

import Mooncake
using Mooncake: @from_chainrules, DefaultCtx

# ==========================================================================================
# Declare FrameSystem and its internal opaque types as non-differentiable
# ==========================================================================================

# FunctionWrapper and FunctionWrappersWrapper are compiled C function pointers — opaque
# to reverse-mode AD. Declaring them as NoTangent prevents Mooncake from trying to
# construct tangent types for them (which would fail).
Mooncake.tangent_type(::Type{<:FunctionWrapper}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:FunctionWrappersWrapper}) = Mooncake.NoTangent

# FrameSystem is a structural container (graph of axes/points with function wrappers).
# It is never differentiated through — only `t` is differentiable.
Mooncake.tangent_type(::Type{<:FrameSystem}) = Mooncake.NoTangent

# ==========================================================================================
# Register ChainRulesCore rrules with Mooncake
# ==========================================================================================

# Vector functions: vector3, vector6, vector9, vector12
@from_chainrules DefaultCtx Tuple{typeof(vector3), FrameSystem, Any, Any, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(vector6), FrameSystem, Any, Any, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(vector9), FrameSystem, Any, Any, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(vector12), FrameSystem, Any, Any, Any, T} where {T<:Number}

# Rotation functions: rotation3, rotation6, rotation9, rotation12
@from_chainrules DefaultCtx Tuple{typeof(rotation3), FrameSystem, Any, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(rotation6), FrameSystem, Any, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(rotation9), FrameSystem, Any, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(rotation12), FrameSystem, Any, Any, T} where {T<:Number}

# Direction functions: direction3, direction6, direction9, direction12
@from_chainrules DefaultCtx Tuple{typeof(direction3), FrameSystem, Symbol, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(direction6), FrameSystem, Symbol, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(direction9), FrameSystem, Symbol, Any, T} where {T<:Number}
@from_chainrules DefaultCtx Tuple{typeof(direction12), FrameSystem, Symbol, Any, T} where {T<:Number}

end # module
7 changes: 4 additions & 3 deletions src/Core/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,16 @@ function add_axes!(
end


# Create point
# Create node and insert into the graph
node = FrameAxesNode{O,T}(name, id, parentid, funs)

# Insert new point in the graph
add_axes!(frames, node)

# Connect the new axes to the parent axes in the graph
!isnothing(parentid) && add_edge!(axes_graph(frames), parentid, id)
# Connect the new axes to the parent axes in the graph (skip for root)
parentid != id && add_edge!(axes_graph(frames), parentid, id)

empty!(frames._axes_nodes)
return nothing
end

Expand Down
Loading