Bitmap topk#3009
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>, breaking the JAX build in router.cpp. Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
…g the routing map type enum Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile SummaryThis PR adds BITMAP_U8 as a new output format for the routing map in the fused topk router kernels, alongside the existing BYTEMAP format. A new
Confidence Score: 5/5The change is additive and backward-compatible: the V1 API is preserved as a BYTEMAP-delegating wrapper, and BYTEMAP remains the default everywhere. Both CUDA kernel paths use correct memory stride arithmetic, proper __syncwarp synchronization before reading the shmem bitmap accumulator, and the little-endian uint32 to uint8 reinterpret_cast is safe on all CUDA devices. The Python layer correctly saves the flat routing_map for backward and reshapes it for callers. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_topk_with_score_function(logits, ..., routing_map_format)"] --> B{routing_map_format?}
B -->|BYTEMAP| C["allocate bool[T, E]"]
B -->|BITMAP_U8| D["allocate uint8[T, ceil(E/8)]"]
C --> E["V2 forward (BYTEMAP)"]
D --> F["V2 forward (BITMAP_U8)"]
E --> G["CUDA kernel: write 0/1 bytes to global routing_map"]
F --> H["CUDA kernel: atomicOr into shmem uint32, byte-copy to global uint8"]
G --> I["Return probs[T,E], routing_map bool[T,E]"]
H --> J["Return probs[T,E], routing_map uint8[T,ceil(E/8)]"]
I --> K["Backward: read routing_map[pos+i] != 0"]
J --> L["Backward: read (bitmap_row[i/8] >> i%8) & 1"]
Reviews (3): Last reviewed commit: "Merge branch 'main' into bitmap_topk" | Re-trigger Greptile |
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM, thanks! I reviewed core and JAX changes but not PyTorch
|
/te_ci |
| } | ||
|
|
||
|
|
||
| def _validate_routing_map_format( |
There was a problem hiding this comment.
Could we use regular ints instead of the enums and only do this just before calling the tex function? This function screams CPU overhead so can we simplify it, to not e.g. do lower on the string every time etc.?
| scaling_factor: Optional[float], | ||
| score_function: str, | ||
| expert_bias: Optional[torch.Tensor], | ||
| routing_map_format: "RoutingMapFormat", |
There was a problem hiding this comment.
Why do you put the string here as the type?
| # Save the flat 2D routing_map for backward (kernel indexes by | ||
| # num_tokens x trailing_dim), then restore the leading dims of the | ||
| # input on the returned outputs. The trailing dim of routing_map | ||
| # depends on the format: num_experts for BYTEMAP, ceil(num_experts/8) | ||
| # for BITMAP_U8. |
There was a problem hiding this comment.
Please don't do that and instead use flat_first_dim and flat_last_dim on the C++ side to avoid Python overheads here - so the routing map should be created with the proper shape and we should not do the view afterwards.
| scores = scores.view(tensor_shape) | ||
| routing_map = routing_map.view(*tensor_shape[:-1], routing_map.shape[-1]) |
There was a problem hiding this comment.
As before - please don't do that, let's avoid the views on the Python side.
Description
Add a new path to our topk kernel to output the routing map in bitmap format instead of bytemap alone. The default still stay at bytemap so no regression for existing consumers downstream of this op. However, since the op now requires an additional arg to specify the routing map type (bytemap or bitmap), we introduce a V2 of the API to accomplish this, while keeping the original API the same not to break customers.
This helps NCCL EP not have to do the token_indices (sparse format) conversion to bitmap format for comms later.
Fixes #2999
Type of change
Changes
Checklist: