Skip to content

Commit 4b72e0b

Browse files
committed
Add SetConstantCmd and DispatchIndirectCmd for compute
1 parent dcdb4f5 commit 4b72e0b

12 files changed

Lines changed: 411 additions & 24 deletions

File tree

src/Aardvark.Rendering.GL/Instructions/OpenGL.fs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ module OpenGl =
456456
let BindProgram = getProcAddress "glUseProgram"
457457

458458
let DispatchCompute = getProcAddress "glDispatchCompute"
459+
let DispatchComputeIndirect = getProcAddress "glDispatchComputeIndirect"
459460
let GetInteger = getProcAddress "glGetIntegerv"
460461
let GetFloat = getProcAddress "glGetFloatv"
461462
let GetDouble = getProcAddress "glGetDoublev"

src/Aardvark.Rendering.GL/Management/AssemblerExtensions.fs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ type ICommandStream =
5555
abstract member UseProgram : m : nativeptr<int> -> unit
5656
abstract member DispatchCompute : gx : int * gy : int * gz : int -> unit
5757
abstract member DispatchCompute : groups : nativeptr<V3i> -> unit
58+
abstract member DispatchComputeIndirect : indirect: nativeint -> unit
5859

5960
abstract member MemoryBarrier : MemoryBarrierFlags -> unit
6061

@@ -499,6 +500,10 @@ module GLAssemblerExtensions =
499500
s.PushIntArg (NativePtr.toNativeInt groups + 0n)
500501
s.Call OpenGl.Pointers.DispatchCompute
501502

503+
member this.DispatchComputeIndirect(indirect : nativeint) =
504+
s.BeginCall(1)
505+
s.PushArg indirect
506+
s.Call OpenGl.Pointers.DispatchComputeIndirect
502507

503508
member this.Enable(v : int) =
504509
s.BeginCall(1)
@@ -1012,6 +1017,7 @@ module GLAssemblerExtensions =
10121017
member this.Disable(v: nativeptr<int>) = this.Disable(v)
10131018
member this.DispatchCompute(gx: int, gy: int, gz: int) = this.DispatchCompute(gx, gy, gz)
10141019
member this.DispatchCompute(groups: nativeptr<V3i>) = this.DispatchCompute(groups)
1020+
member this.DispatchComputeIndirect(indirect: nativeint) = this.DispatchComputeIndirect(indirect)
10151021
member this.DrawArrays(stats: nativeptr<V2i>, isActive: nativeptr<int>, beginMode: nativeptr<GLBeginMode>, calls: nativeptr<DrawCallInfoList>) = this.DrawArrays(stats, isActive, beginMode, calls)
10161022
member this.DrawArraysIndirect(stats: nativeptr<V2i>, isActive: nativeptr<int>, beginMode: nativeptr<GLBeginMode>, indirect: nativeptr<IndirectDrawArgs>) = this.DrawArraysIndirect(stats, isActive, beginMode, indirect)
10171023
member this.DrawElements(stats: nativeptr<V2i>, isActive: nativeptr<int>, beginMode: nativeptr<GLBeginMode>, indexType: int, calls: nativeptr<DrawCallInfoList>) = this.DrawElements(stats, isActive, beginMode, indexType, calls)
@@ -1135,6 +1141,7 @@ module GLAssemblerExtensions =
11351141
member x.Disable(v: nativeptr<int>) = inner.Disable(v); x.Append("Disable", v)
11361142
member x.DispatchCompute(gx: int, gy: int, gz: int) = inner.DispatchCompute(gx, gy, gz); x.Append("DispatchCompute", gx, gy, gz)
11371143
member x.DispatchCompute(groups: nativeptr<V3i>) = inner.DispatchCompute(groups); x.Append("DispatchCompute", groups)
1144+
member x.DispatchComputeIndirect(indirect: nativeint) = inner.DispatchComputeIndirect(indirect); x.Append("DispatchComputeIndirect")
11381145
member x.DrawArrays(stats: nativeptr<V2i>, isActive: nativeptr<int>, beginMode: nativeptr<GLBeginMode>, calls: nativeptr<DrawCallInfoList>) = inner.DrawArrays(stats, isActive, beginMode, calls); x.Append("DrawArrays", stats, isActive, beginMode, calls)
11391146
member x.DrawArraysIndirect(stats: nativeptr<V2i>, isActive: nativeptr<int>, beginMode: nativeptr<GLBeginMode>, indirect: nativeptr<IndirectDrawArgs>) = inner.DrawArraysIndirect(stats, isActive, beginMode, indirect); x.Append("DrawArraysIndirect", stats, isActive, beginMode, indirect)
11401147
member x.DrawElements(stats: nativeptr<V2i>, isActive: nativeptr<int>, beginMode: nativeptr<GLBeginMode>, indexType: int, calls: nativeptr<DrawCallInfoList>) = inner.DrawElements(stats, isActive, beginMode, indexType, calls); x.Append("DrawElements", stats, isActive, beginMode, indexType, calls)

src/Aardvark.Rendering.GL/Runtime/Compute.fs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,11 +375,22 @@ module internal ComputeTaskInternals =
375375
s.BindImageTexture(slot, TextureAccess.ReadWrite, binding.Pointer)
376376
)
377377

378+
| ComputeCommand.SetConstantCmd _ ->
379+
raise <| NotSupportedException("Constants are not supported.")
380+
378381
| ComputeCommand.DispatchCmd groups ->
379382
do! CompilerState.assemble (fun _ s ->
380383
s.DispatchCompute(groups.X, groups.Y, groups.Z)
381384
)
382385

386+
| ComputeCommand.DispatchIndirectCmd (indirectBuffer, offset) ->
387+
let indirectBuffer = unbox<GL.Buffer> indirectBuffer
388+
389+
do! CompilerState.assemble (fun _ s ->
390+
s.BindBuffer(BufferTarget.DispatchIndirectBuffer, indirectBuffer.Handle)
391+
s.DispatchComputeIndirect(nativeint offset)
392+
)
393+
383394
| ComputeCommand.ExecuteCmd other ->
384395
do! CompilerState.execute other
385396

src/Aardvark.Rendering.GL/Runtime/Runtime.fs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ type Runtime(debug : IDebugConfig) =
193193
member x.CreateInputBinding(shader, inputs) =
194194
x.CreateInputBinding(shader, inputs)
195195

196+
member x.GetComputeConstant<'T>(_, _) : IComputeConstant<'T> =
197+
raise <| NotSupportedException("Constants are not supported.")
198+
196199
member x.CompileCompute (commands) =
197200
x.CompileCompute commands
198201

src/Aardvark.Rendering.Vulkan/Runtime/ComputeTask.fs

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,42 @@ module internal ComputeTaskInternals =
1313
{
1414
Shader : IComputeShader
1515
DescriptorSets : INativeResourceLocation<DescriptorSetBinding>
16-
PushConstants : IConstantResourceLocation<PushConstants> voption
1716
}
1817

1918
interface IComputeInputBinding with
2019
member x.Shader = x.Shader
2120

21+
type PushConstant(shader: IComputeShader, name: Symbol, inputType: Type) =
22+
let layout =
23+
shader.PipelineLayout.PushConstants
24+
|> ValueOption.defaultWith (fun _ -> failf "Compute shader does not use any push constants.")
25+
26+
let field =
27+
layout.Buffer.ubFields
28+
|> List.tryFindV (fun f -> f.ufName = string name)
29+
|> ValueOption.defaultWith (fun _ -> failf $"Compute shader does not use push constant '{name}'.")
30+
31+
let writer =
32+
match UniformWriters.tryGetWriter 0 field.ufType inputType with
33+
| Result.Ok writer -> writer
34+
| Result.Error msg -> failf $"Cannot get writer for compute constant '{name}': {msg}"
35+
36+
let size = GLSLType.sizeof field.ufType
37+
38+
member _.Size = size
39+
40+
member _.Write(stream: VKVM.CommandStream, value: obj, buffer: nativeint) =
41+
writer.WriteUnsafeValue(value, buffer)
42+
stream.PushConstants(shader.PipelineLayout.Handle, layout.StageFlags, uint32 field.ufOffset, uint32 size, buffer) |> ignore
43+
44+
interface IComputeConstant with
45+
member _.Shader = shader
46+
member _.Name = name
47+
48+
type PushConstant<'T>(shader: IComputeShader, name: Symbol) =
49+
inherit PushConstant(shader, name, typeof<'T>)
50+
interface IComputeConstant<'T>
51+
2252
type ResourceManager with
2353

2454
member x.CreateComputeInputBinding(shader : IComputeShader, inputs : IUniformProvider) =
@@ -28,15 +58,8 @@ module internal ComputeTaskInternals =
2858
let sets = x.CreateDescriptorSets(shader.PipelineLayout, provider)
2959
x.CreateDescriptorSetBinding(VkPipelineBindPoint.Compute, shader.PipelineLayout, sets)
3060

31-
let pushConstants =
32-
shader.PipelineLayout.PushConstants |> ValueOption.map (fun pc ->
33-
x.CreatePushConstants(pc, provider)
34-
)
35-
3661
{ Shader = shader
37-
DescriptorSets = descriptorSets
38-
PushConstants = pushConstants }
39-
62+
DescriptorSets = descriptorSets }
4063

4164
[<RequireQualifiedAccess>]
4265
type private HostCommand =
@@ -59,11 +82,16 @@ module internal ComputeTaskInternals =
5982

6083
type private CompilerState =
6184
{
62-
Commands : CompiledCommand list
63-
UsedImages : HashSet<Image>
64-
ImageLayouts : HashMap<Image, VkImageLayout>
85+
Commands : CompiledCommand list
86+
UsedImages : HashSet<Image>
87+
ImageLayouts : HashMap<Image, VkImageLayout>
88+
ConstantBuffers : nativeptr<uint8> list
6589
}
6690

91+
member this.Free() =
92+
for cmd in this.Commands do cmd.Dispose()
93+
for buf in this.ConstantBuffers do NativePtr.free buf
94+
6795
type private ICompiledTask =
6896
abstract member State : CompilerState
6997

@@ -93,9 +121,10 @@ module internal ComputeTaskInternals =
93121
module private CompilerState =
94122

95123
let empty =
96-
{ Commands = []
97-
UsedImages = HashSet.empty
98-
ImageLayouts = HashMap.empty }
124+
{ Commands = []
125+
UsedImages = HashSet.empty
126+
ImageLayouts = HashMap.empty
127+
ConstantBuffers = [] }
99128

100129
let stream =
101130
State.custom (fun s ->
@@ -127,7 +156,7 @@ module internal ComputeTaskInternals =
127156
State.modify (fun s -> { s with UsedImages = s.UsedImages |> HashSet.add image })
128157

129158
let usedImages =
130-
State.get |> State.map (fun s -> s.UsedImages)
159+
State.get |> State.map _.UsedImages
131160

132161
let inline layout (image : Image) =
133162
State.get |> State.map (fun s ->
@@ -150,6 +179,12 @@ module internal ComputeTaskInternals =
150179
return oldLayout
151180
}
152181

182+
let inline constantBuffer (constant: PushConstant) =
183+
State.custom (fun s ->
184+
let buffer = NativePtr.alloc constant.Size
185+
{ s with ConstantBuffers = buffer :: s.ConstantBuffers }, buffer
186+
)
187+
153188
[<AutoOpen>]
154189
module private CommandStreamExtensions =
155190

@@ -262,14 +297,22 @@ module internal ComputeTaskInternals =
262297
let input = unbox<ComputeInputBinding> input
263298
let! stream = CompilerState.stream
264299
stream.IndirectBindDescriptorSets(input.DescriptorSets.Pointer) |> ignore
265-
match input.PushConstants with
266-
| ValueSome pc -> stream.PushConstants(input.Shader.PipelineLayout.Handle, pc.Handle) |> ignore
267-
| _ -> ()
300+
301+
| ComputeCommand.SetConstantCmd (constant, value) ->
302+
let constant = unbox<PushConstant> constant
303+
let! stream = CompilerState.stream
304+
let! buffer = CompilerState.constantBuffer constant
305+
constant.Write(stream, value, buffer.Address)
268306

269307
| ComputeCommand.DispatchCmd groups ->
270308
let! stream = CompilerState.stream
271309
stream.Dispatch(uint32 groups.X, uint32 groups.Y, uint32 groups.Z) |> ignore
272310

311+
| ComputeCommand.DispatchIndirectCmd (indirectBuffer, offset) ->
312+
let! stream = CompilerState.stream
313+
let indirectBuffer = indirectBuffer |> unbox<Buffer>
314+
stream.DispatchIndirect(indirectBuffer.Handle, offset) |> ignore
315+
273316
| ComputeCommand.ExecuteCmd other ->
274317
let compiled = unbox<ICompiledTask> other
275318
do! restoreLayouts compiled.State.UsedImages
@@ -415,7 +458,6 @@ module internal ComputeTaskInternals =
415458
failf "unknown input binding type %A" (input.GetType())
416459

417460
resources.Add input.DescriptorSets
418-
input.PushConstants |> ValueOption.iter resources.Add
419461
inputs.[index] <- input
420462

421463
ComputeCommand.SetInputCmd input
@@ -469,7 +511,6 @@ module internal ComputeTaskInternals =
469511
// This way nothing will be released if the input just moved in the command list
470512
for input in removedInputs do
471513
resources.Remove input.DescriptorSets
472-
input.PushConstants |> ValueOption.iter resources.Remove
473514

474515
// Update all hooked compute programs
475516
let mutable changed = deltas.Count > 0
@@ -479,7 +520,7 @@ module internal ComputeTaskInternals =
479520

480521
// Compile updated command list
481522
if changed then
482-
for c in compiled.Commands do c.Dispose()
523+
compiled.Free()
483524
compiled <- ComputeCommand.compile queueFlags commands
484525
true
485526
else
@@ -488,7 +529,6 @@ module internal ComputeTaskInternals =
488529
member x.Dispose() =
489530
for KeyValue(_, input) in inputs do
490531
resources.Remove input.DescriptorSets
491-
input.PushConstants |> ValueOption.iter resources.Remove
492532
inputs.Clear()
493533

494534
for KeyValue(_, task) in nested do
@@ -497,7 +537,7 @@ module internal ComputeTaskInternals =
497537

498538
hooked.Clear()
499539
commands <- IndexList.empty
500-
for c in compiled.Commands do c.Dispose()
540+
compiled.Free()
501541
compiled <- CompilerState.empty
502542

503543
interface IDisposable with

src/Aardvark.Rendering.Vulkan/Runtime/Runtime.fs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,9 @@ type Runtime(device : Device) as this =
451451
member x.CreateInputBinding(shader : IComputeShader, inputs : IUniformProvider) : IComputeInputBinding =
452452
manager.CreateComputeInputBinding(shader, inputs)
453453

454+
member x.GetComputeConstant<'T>(shader : IComputeShader, name : Symbol) : IComputeConstant<'T> =
455+
PushConstant<'T>(shader, name)
456+
454457
member x.CompileCompute (commands : alist<ComputeCommand>) =
455458
new ComputeTask(manager, commands) :> IComputeTask
456459

@@ -561,6 +564,9 @@ type Runtime(device : Device) as this =
561564
member x.CreateInputBinding(shader : IComputeShader, inputs : IUniformProvider) =
562565
x.CreateInputBinding(shader, inputs)
563566

567+
member x.GetComputeConstant<'T>(shader : IComputeShader, name : Symbol) =
568+
x.GetComputeConstant<'T>(shader, name)
569+
564570
member x.CompileCompute (commands : alist<ComputeCommand>) =
565571
x.CompileCompute commands
566572

src/Aardvark.Rendering/Runtime/Compute/Compute.fs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,21 @@ and IComputeShader =
2626
and IComputeInputBinding =
2727
abstract member Shader : IComputeShader
2828

29+
and IComputeConstant =
30+
abstract member Shader : IComputeShader
31+
abstract member Name : Symbol
32+
33+
and IComputeConstant<'T> =
34+
inherit IComputeConstant
35+
2936
and IComputeRuntime =
3037
inherit IBufferRuntime
3138
inherit ITextureRuntime
3239
abstract member ContextLock : IDisposable
3340
abstract member MaxLocalSize : V3i
3441
abstract member CreateComputeShader : shader: FShade.ComputeShader -> IComputeShader
3542
abstract member CreateInputBinding : shader: IComputeShader * inputs: IUniformProvider -> IComputeInputBinding
43+
abstract member GetComputeConstant<'T> : shader: IComputeShader * name: Symbol -> IComputeConstant<'T>
3644
abstract member CompileCompute : commands: alist<ComputeCommand> -> IComputeTask
3745

3846
and [<RequireQualifiedAccess>]
@@ -44,7 +52,9 @@ and [<RequireQualifiedAccess>]
4452
ComputeCommand =
4553
| BindCmd of shader: IComputeShader
4654
| SetInputCmd of input: IComputeInputBinding
55+
| SetConstantCmd of constant: IComputeConstant * data: obj
4756
| DispatchCmd of groups: V3i
57+
| DispatchIndirectCmd of indirectBuffer: IBackendBuffer * offset: uint64
4858
| ExecuteCmd of task: IComputeTask
4959
| CopyBufferCmd of src: IBufferRange * dst: IBufferRange
5060
| DownloadBufferCmd of src: IBufferRange * dst: HostMemory
@@ -66,6 +76,9 @@ and [<RequireQualifiedAccess>]
6676
static member SetInput(input : IComputeInputBinding) =
6777
ComputeCommand.SetInputCmd input
6878

79+
static member SetConstant<'T>(constant : IComputeConstant<'T>, value : 'T) =
80+
ComputeCommand.SetConstantCmd(constant, value)
81+
6982
static member Dispatch(groups : V3i) =
7083
ComputeCommand.DispatchCmd groups
7184

@@ -75,6 +88,9 @@ and [<RequireQualifiedAccess>]
7588
static member Dispatch(groups : int) =
7689
ComputeCommand.DispatchCmd (V3i(groups, 1, 1))
7790

91+
static member DispatchIndirect(indirectBuffer : IBackendBuffer, [<Optional; DefaultParameterValue(0UL)>] offset : uint64) =
92+
ComputeCommand.DispatchIndirectCmd(indirectBuffer, offset)
93+
7894
static member Execute(task : IComputeTask) =
7995
ComputeCommand.ExecuteCmd task
8096

src/Aardvark.Rendering/Runtime/Compute/ComputeExtensions.fs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ type IComputeRuntimeExtensions private() =
6363
static member Run(runtime : IComputeRuntime, commands : list<ComputeCommand>) =
6464
runtime.Run(commands, RenderToken.Empty)
6565

66+
[<Extension>]
67+
static member GetComputeConstant<'T>(runtime : IComputeRuntime, shader : IComputeShader, name : string) =
68+
runtime.GetComputeConstant<'T>(shader, Sym.ofString name)
69+
6670

6771
[<AbstractClass; Sealed; Extension>]
6872
type IComputeShaderExtensions private() =
@@ -71,6 +75,14 @@ type IComputeShaderExtensions private() =
7175
static member CreateInputBinding(shader : IComputeShader, inputs : IUniformProvider) =
7276
shader.Runtime.CreateInputBinding(shader, inputs)
7377

78+
[<Extension>]
79+
static member GetConstant<'T>(shader : IComputeShader, name : Symbol) =
80+
shader.Runtime.GetComputeConstant<'T>(shader, name)
81+
82+
[<Extension>]
83+
static member GetConstant<'T>(shader : IComputeShader, name : string) =
84+
shader.Runtime.GetComputeConstant<'T>(shader, name)
85+
7486
[<Extension>]
7587
static member Invoke(shader : IComputeShader, groupCount : V3i, input : IComputeInputBinding, renderToken : RenderToken) =
7688
shader.Runtime.Run([

src/Tests/Aardvark.Rendering.Tests/Aardvark.Rendering.Tests.fsproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777
<Compile Include="Tests\Compute\Sorting.fs" />
7878
<Compile Include="Tests\Compute\Jpeg.fs" />
7979
<Compile Include="Tests\Compute\MutableInputBinding.fs" />
80+
<Compile Include="Tests\Compute\PushConstants.fs" />
81+
<Compile Include="Tests\Compute\DispatchIndirect.fs" />
8082
<Compile Include="Tests\Compute\ComputeTests.fs" />
8183
<Compile Include="Tests\Other\Camera.fs" />
8284
<Compile Include="Tests\Other\IndexedGeometryTests.fs" />

src/Tests/Aardvark.Rendering.Tests/Tests/Compute/ComputeTests.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ module ``Compute Tests`` =
1313
ComputeSorting.tests
1414
ComputeJpeg.tests
1515
MutableInputBinding.tests
16+
PushConstants.tests
17+
DispatchIndirect.tests
1618
]
1719

1820
[<Tests>]

0 commit comments

Comments
 (0)