aboutsummaryrefslogtreecommitdiff
path: root/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl
diff options
context:
space:
mode:
author3gg <3gg@shellblade.net>2025-12-02 16:39:36 -0800
committer3gg <3gg@shellblade.net>2025-12-02 16:39:36 -0800
commit6c8ae19be66cee247980a48e736a4e05d14de179 (patch)
treed860767907bf0cbe17ec66422e11bea700cf56d9 /contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl
parent8f594c8ebd11f0e5f8a0c6369c3fe7383d250cbe (diff)
Immediate-mode renderer, triangle demo, shader compilation in cmake, Agility SDKHEADmain
Diffstat (limited to 'contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl')
-rw-r--r--contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl377
1 files changed, 377 insertions, 0 deletions
diff --git a/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl
new file mode 100644
index 0000000..2acae8e
--- /dev/null
+++ b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl
@@ -0,0 +1,377 @@
1// Copyright (c) 2024 Google LLC
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7#include "vk/opcode_selector.h"
8
9template <typename ResultType, typename ComponentType>
10[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType
11__builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b);
12
13template <typename ComponentType, vk::Scope scope, uint rows, uint columns,
14 vk::CooperativeMatrixUse use>
15[[vk::ext_instruction(/* OpCompositeExtract */ 81)]] ComponentType
16__builtin_spv_ExtractFromCooperativeMatrix(
17 typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
18 use>::SpirvMatrixType matrix,
19 uint32_t index);
20
21template <typename CoopMatrixType, typename ComponentType>
22[[vk::ext_instruction(/* OpCompositeConstruct */ 80)]] CoopMatrixType
23__builtin_spv_ConstructCooperativeMatrix(ComponentType value);
24
25template <class ResultPointerType, class BaseType>
26[[vk::ext_instruction(/* OpAccessChain */ 65)]] ResultPointerType
27__builtin_spv_AccessChain([[vk::ext_reference]] BaseType base, uint32_t index);
28
29template <class ObjectType, class PointerType>
30[[vk::ext_instruction(/* OpLoad */ 61)]] ObjectType
31__builtin_spv_LoadPointer(PointerType base);
32
33template <class PointerType, class ObjectType>
34[[vk::ext_instruction(/* OpLoad */ 62)]] void
35__builtin_spv_StorePointer(PointerType base, ObjectType object);
36
37template <typename ComponentType, vk::Scope scope, uint rows, uint columns,
38 vk::CooperativeMatrixUse use>
39[[vk::ext_instruction(/* OpCompositeInsert */ 82)]]
40typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
41 use>::SpirvMatrixType
42__builtin_spv_InsertIntoCooperativeMatrix(
43 ComponentType value,
44 typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
45 use>::SpirvMatrixType matrix,
46 uint32_t index);
47
48// Define the load and store instructions
49template <typename ResultType, typename PointerType>
50[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
51__builtin_spv_CooperativeMatrixLoadKHR(
52 [[vk::ext_reference]] PointerType pointer,
53 vk::CooperativeMatrixLayout memory_layout, uint stride,
54 [[vk::ext_literal]] uint32_t memory_operand);
55
56template <typename ResultType, typename PointerType>
57[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
58__builtin_spv_CooperativeMatrixLoadKHR(
59 [[vk::ext_reference]] PointerType pointer,
60 vk::CooperativeMatrixLayout memory_layout, uint stride,
61 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
62
63template <typename ResultType, typename PointerType>
64[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
65__builtin_spv_CooperativeMatrixWorkgroupLoadKHR(
66 vk::WorkgroupSpirvPointer<PointerType> pointer,
67 vk::CooperativeMatrixLayout memory_layout, uint stride,
68 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
69
70template <typename ObjectType, typename PointerType>
71[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
72__builtin_spv_CooperativeMatrixStoreKHR(
73 [[vk::ext_reference]] PointerType pointer, ObjectType object,
74 vk::CooperativeMatrixLayout memory_layout, uint stride,
75 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
76
77template <typename ObjectType, typename PointerType>
78[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
79__builtin_spv_CooperativeMatrixStoreKHR(
80 [[vk::ext_reference]] PointerType pointer, ObjectType object,
81 vk::CooperativeMatrixLayout memory_layout, uint stride,
82 [[vk::ext_literal]] uint32_t memory_operand);
83
84template <typename ObjectType, typename PointerType>
85[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
86__builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
87 vk::WorkgroupSpirvPointer<PointerType> pointer, ObjectType object,
88 vk::CooperativeMatrixLayout memory_layout, uint stride,
89 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
90
91// We cannot define `OpCooperativeMatrixLengthKHR` using ext_instruction because
92// one of the operands is a type id. This builtin will have specific code in the
93// compiler to expand it.
94template <class MatrixType> uint __builtin_spv_CooperativeMatrixLengthKHR();
95
96// Arithmetic Instructions
97template <typename ResultType, typename MatrixTypeA, typename MatrixTypeB,
98 typename MatrixTypeC>
99[[vk::ext_instruction(/* OpCooperativeMatrixMulAddKHR */ 4459)]] ResultType
100__builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b,
101 MatrixTypeC c,
102 [[vk::ext_literal]] int operands);
103namespace vk {
104namespace khr {
105
106template <class ComponentType, Scope scope, uint rows, uint columns,
107 CooperativeMatrixUse use>
108template <class NewComponentType>
109CooperativeMatrix<NewComponentType, scope, rows, columns, use>
110CooperativeMatrix<ComponentType, scope, rows, columns, use>::cast() {
111 using ResultType =
112 CooperativeMatrix<NewComponentType, scope, rows, columns, use>;
113 ResultType result;
114 result._matrix = util::ConversionSelector<ComponentType, NewComponentType>::
115 template Convert<typename ResultType::SpirvMatrixType>(_matrix);
116 return result;
117}
118
119template <class ComponentType, Scope scope, uint rows, uint columns,
120 CooperativeMatrixUse use>
121CooperativeMatrix<ComponentType, scope, rows, columns, use>
122CooperativeMatrix<ComponentType, scope, rows, columns, use>::negate() {
123 CooperativeMatrix result;
124 result._matrix = util::ArithmeticSelector<ComponentType>::Negate(_matrix);
125 return result;
126}
127
128template <class ComponentType, Scope scope, uint rows, uint columns,
129 CooperativeMatrixUse use>
130CooperativeMatrix<ComponentType, scope, rows, columns, use>
131CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator+(
132 CooperativeMatrix other) {
133 CooperativeMatrix result;
134 result._matrix =
135 util::ArithmeticSelector<ComponentType>::Add(_matrix, other._matrix);
136 return result;
137}
138
139template <class ComponentType, Scope scope, uint rows, uint columns,
140 CooperativeMatrixUse use>
141CooperativeMatrix<ComponentType, scope, rows, columns, use>
142CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator-(
143 CooperativeMatrix other) {
144 CooperativeMatrix result;
145 result._matrix =
146 util::ArithmeticSelector<ComponentType>::Sub(_matrix, other._matrix);
147 return result;
148}
149
150template <class ComponentType, Scope scope, uint rows, uint columns,
151 CooperativeMatrixUse use>
152CooperativeMatrix<ComponentType, scope, rows, columns, use>
153CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
154 CooperativeMatrix other) {
155 CooperativeMatrix result;
156 result._matrix =
157 util::ArithmeticSelector<ComponentType>::Mul(_matrix, other._matrix);
158 return result;
159}
160
161template <class ComponentType, Scope scope, uint rows, uint columns,
162 CooperativeMatrixUse use>
163CooperativeMatrix<ComponentType, scope, rows, columns, use>
164CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator/(
165 CooperativeMatrix other) {
166 CooperativeMatrix result;
167 result._matrix =
168 util::ArithmeticSelector<ComponentType>::Div(_matrix, other._matrix);
169 return result;
170}
171
172template <class ComponentType, Scope scope, uint rows, uint columns,
173 CooperativeMatrixUse use>
174CooperativeMatrix<ComponentType, scope, rows, columns, use>
175CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
176 ComponentType scalar) {
177 CooperativeMatrix result;
178 result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar);
179 return result;
180}
181
182template <class ComponentType, Scope scope, uint rows, uint columns,
183 CooperativeMatrixUse use>
184template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
185 class Type>
186void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Store(
187 WorkgroupSpirvPointer<Type> data, uint32_t stride) {
188 __builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
189 data, _matrix, layout, stride,
190 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
191 MemoryAccessMakePointerAvailableMask,
192 ScopeWorkgroup);
193}
194
195template <class ComponentType, Scope scope, uint rows, uint columns,
196 CooperativeMatrixUse use>
197template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
198 class Type>
199void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Store(
200 RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
201 __builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix, layout, stride,
202 memoryAccessOperands);
203}
204
205template <class ComponentType, Scope scope, uint rows, uint columns,
206 CooperativeMatrixUse use>
207template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
208 class Type>
209void CooperativeMatrix<ComponentType, scope, rows, columns, use>::CoherentStore(
210 globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
211 uint32_t stride) {
212 __builtin_spv_CooperativeMatrixStoreKHR(
213 data[index], _matrix, layout, stride,
214 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
215 MemoryAccessMakePointerAvailableMask,
216 ScopeQueueFamily);
217}
218
219template <class ComponentType, Scope scope, uint rows, uint columns,
220 CooperativeMatrixUse use>
221template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
222 class Type>
223CooperativeMatrix<ComponentType, scope, rows, columns, use>
224CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
225 vk::WorkgroupSpirvPointer<Type> buffer, uint32_t stride) {
226 CooperativeMatrix result;
227 result._matrix =
228 __builtin_spv_CooperativeMatrixWorkgroupLoadKHR<SpirvMatrixType>(
229 buffer, layout, stride,
230 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
231 MemoryAccessMakePointerVisibleMask,
232 ScopeWorkgroup);
233 return result;
234}
235
236template <class ComponentType, Scope scope, uint rows, uint columns,
237 CooperativeMatrixUse use>
238template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
239 class Type>
240CooperativeMatrix<ComponentType, scope, rows, columns, use>
241CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
242 RWStructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
243 CooperativeMatrix result;
244 result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
245 buffer[index], layout, stride, memoryAccessOperands);
246 return result;
247}
248
249template <class ComponentType, Scope scope, uint rows, uint columns,
250 CooperativeMatrixUse use>
251template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
252 class Type>
253CooperativeMatrix<ComponentType, scope, rows, columns, use>
254CooperativeMatrix<ComponentType, scope, rows, columns, use>::CoherentLoad(
255 RWStructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
256 CooperativeMatrix result;
257 result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
258 buffer[index], layout, stride,
259 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
260 MemoryAccessMakePointerVisibleMask,
261 ScopeQueueFamily);
262 return result;
263}
264
265template <class ComponentType, Scope scope, uint rows, uint columns,
266 CooperativeMatrixUse use>
267template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
268 class Type>
269CooperativeMatrix<ComponentType, scope, rows, columns, use>
270CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
271 StructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
272 CooperativeMatrix result;
273 result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
274 buffer[index], layout, stride, MemoryAccessMaskNone);
275 return result;
276}
277
278template <class ComponentType, Scope scope, uint rows, uint columns,
279 CooperativeMatrixUse use>
280CooperativeMatrix<ComponentType, scope, rows, columns, use>
281CooperativeMatrix<ComponentType, scope, rows, columns, use>::Splat(
282 ComponentType v) {
283 CooperativeMatrix result;
284 result._matrix = __builtin_spv_ConstructCooperativeMatrix<SpirvMatrixType>(v);
285 return result;
286}
287
288template <class ComponentType, Scope scope, uint rows, uint columns,
289 CooperativeMatrixUse use>
290uint CooperativeMatrix<ComponentType, scope, rows, columns, use>::GetLength() {
291 return __builtin_spv_CooperativeMatrixLengthKHR<SpirvMatrixType>();
292}
293
294template <class ComponentType, Scope scope, uint rows, uint columns,
295 CooperativeMatrixUse use>
296ComponentType CooperativeMatrix<ComponentType, scope, rows, columns, use>::Get(
297 uint32_t index) {
298 // clang-format off
299 using ComponentPtr = vk::SpirvOpaqueType<
300 /* OpTypePointer */ 32,
301 /* function storage class */ vk::Literal<vk::integral_constant<uint, 7> >,
302 ComponentType>;
303 // clang-format on
304 ComponentPtr ptr = __builtin_spv_AccessChain<ComponentPtr>(_matrix, index);
305 return __builtin_spv_LoadPointer<ComponentType>(ptr);
306}
307
308template <class ComponentType, Scope scope, uint rows, uint columns,
309 CooperativeMatrixUse use>
310void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Set(
311 ComponentType value, uint32_t index) {
312 // clang-format off
313 using ComponentPtr = vk::SpirvOpaqueType<
314 /* OpTypePointer */ 32,
315 /* function storage class */ vk::Literal<vk::integral_constant<uint, 7> >,
316 ComponentType>;
317 // clang-format on
318 ComponentPtr ptr = __builtin_spv_AccessChain<ComponentPtr>(_matrix, index);
319 return __builtin_spv_StorePointer(ptr, value);
320}
321
322template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
323CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
324cooperativeMatrixMultiplyAdd(
325 CooperativeMatrixA<ComponentType, scope, rows, K> a,
326 CooperativeMatrixB<ComponentType, scope, K, columns> b,
327 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c) {
328
329 const vk::CooperativeMatrixOperandsMask allSignedComponents =
330 vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
331 vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
332 vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
333 vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask;
334
335 const vk::CooperativeMatrixOperandsMask operands =
336 (vk::CooperativeMatrixOperandsMask)(
337 a.hasSignedIntegerComponentType
338 ? allSignedComponents
339 : vk::CooperativeMatrixOperandsMaskNone);
340
341 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
342 result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
343 typename CooperativeMatrixAccumulator<ComponentType, scope, rows,
344 columns>::SpirvMatrixType>(
345 a._matrix, b._matrix, c._matrix, operands);
346 return result;
347}
348
349template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
350CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
351cooperativeMatrixSaturatingMultiplyAdd(
352 CooperativeMatrixA<ComponentType, scope, rows, K> a,
353 CooperativeMatrixB<ComponentType, scope, K, columns> b,
354 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c) {
355
356 const vk::CooperativeMatrixOperandsMask allSignedComponents =
357 vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
358 vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
359 vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
360 vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask |
361 vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask;
362
363 const vk::CooperativeMatrixOperandsMask operands =
364 (vk::CooperativeMatrixOperandsMask)(
365 a.hasSignedIntegerComponentType
366 ? allSignedComponents
367 : vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask);
368 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
369 result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
370 typename CooperativeMatrixAccumulator<ComponentType, scope, rows,
371 columns>::SpirvMatrixType>(
372 a._matrix, b._matrix, c._matrix, operands);
373 return result;
374}
375
376} // namespace khr
377} // namespace vk