author    3gg <3gg@shellblade.net>  2025-12-02 16:39:36 -0800
committer 3gg <3gg@shellblade.net>  2025-12-02 16:39:36 -0800
commit    6c8ae19be66cee247980a48e736a4e05d14de179 (patch)
tree      d860767907bf0cbe17ec66422e11bea700cf56d9 /contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h
parent    8f594c8ebd11f0e5f8a0c6369c3fe7383d250cbe (diff)
Immediate-mode renderer, triangle demo, shader compilation in cmake, Agility SDK (HEAD, main)
Diffstat (limited to 'contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h')
-rw-r--r--  contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h  275
1 file changed, 275 insertions(+), 0 deletions(-)
diff --git a/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h
new file mode 100644
index 0000000..a53ab4c
--- /dev/null
+++ b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h
@@ -0,0 +1,275 @@
// Copyright (c) 2024 Google LLC
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
#define _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_

#if __SPIRV_MAJOR_VERSION__ == 1 && __SPIRV_MINOR_VERSION__ < 6
#error "CooperativeMatrix requires a minimum of SPIR-V 1.6"
#endif

#include "vk/spirv.h"

namespace vk {
namespace khr {

// The base cooperative matrix class. The template arguments correspond to the
// operands in the OpTypeCooperativeMatrixKHR instruction.
template <typename ComponentType, Scope scope, uint rows, uint columns,
          CooperativeMatrixUse use>
class CooperativeMatrix {
  // Returns a copy of this matrix with its elements converted to
  // NewComponentType; the scope, dimensions, and use are unchanged.
  template <class NewComponentType>
  CooperativeMatrix<NewComponentType, scope, rows, columns, use> cast();

  // Apply OpSNegate or OpFNegate, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix negate();

  // Apply OpIAdd or OpFAdd, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator+(CooperativeMatrix other);

  // Apply OpISub or OpFSub, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator-(CooperativeMatrix other);

  // Apply OpIMul or OpFMul, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator*(CooperativeMatrix other);

  // Apply OpSDiv, OpUDiv, or OpFDiv, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator/(CooperativeMatrix other);

  // Apply OpMatrixTimesScalar, multiplying each element by the scalar.
  CooperativeMatrix operator*(ComponentType scalar);
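
  // Example (sketch): element-wise arithmetic on matrices x and y of the same
  // shape, scope, and use. CooperativeMatrixAccumulator is the alias defined
  // later in this header; the scope enumerant vk::ScopeSubgroup is assumed to
  // come from vk/spirv.h.
  //
  //   using FragMat =
  //       vk::khr::CooperativeMatrixAccumulator<float, vk::ScopeSubgroup, 16, 16>;
  //   FragMat sum = x + y;         // OpFAdd, element by element
  //   FragMat halved = sum * 0.5;  // OpMatrixTimesScalar
  //   FragMat negated = sum.negate();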

  // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
  // data using the given memory layout, stride, and memory access operands.
  // `NonPrivatePointer` and `MakePointerAvailable` with the workgroup scope
  // will be added to the memory access operands to make the memory coherent.
  //
  // This function uses a SPIR-V pointer because HLSL does not allow
  // groupshared memory objects to be passed by reference. The pointer is a
  // workaround for that.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride) {
    Store<MemoryAccessMaskNone, layout>(data, stride);
  }
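
  // Example (sketch): staging a tile to groupshared memory with the overloads
  // above, where mat is a CooperativeMatrix instance. The pointer helper (here
  // called vk::GetGroupSharedAddress) and the layout enumerant
  // CooperativeMatrixLayoutRowMajorKHR are assumed to be provided by
  // vk/spirv.h; treat both names as assumptions of this sketch.
  //
  //   groupshared float tile[16 * 16];
  //
  //   vk::WorkgroupSpirvPointer<float> p = vk::GetGroupSharedAddress(tile[0]);
  //   mat.Store<CooperativeMatrixLayoutRowMajorKHR>(p, /* stride */ 16);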

  // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
  // data[index] using the given memory layout, stride, and memory access
  // operands. The layout and stride will be passed to the SPIR-V instruction as
  // is. The precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
    Store<MemoryAccessMaskNone, layout>(data, index, stride);
  }
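
  // Example (sketch): writing a 16x16 tile of floats into an
  // RWStructuredBuffer, starting at element `index` and with a row stride of
  // 16 elements, where mat is a CooperativeMatrix instance.
  // CooperativeMatrixLayoutRowMajorKHR is assumed to be the row-major layout
  // enumerant from vk/spirv.h.
  //
  //   RWStructuredBuffer<float> result;
  //
  //   mat.Store<CooperativeMatrixLayoutRowMajorKHR>(result, index, /* stride */ 16);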

  // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
  // data[index] using the given memory layout, stride, and memory access
  // operands. `NonPrivatePointer` and `MakePointerAvailable` with the
  // QueueFamily scope will be added to the memory access operands to make the
  // memory coherent.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
                     uint32_t index, uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access operands
  // template argument.
  template <CooperativeMatrixLayout layout, class Type>
  void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
                     uint32_t index, uint32_t stride) {
    CoherentStore<MemoryAccessMaskNone, layout>(data, index, stride);
  }

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data using the given memory layout, stride, and memory access operands.
  // `NonPrivatePointer` and `MakePointerVisible` with the workgroup scope
  // will be added to the memory access operands to make the memory coherent.
  //
  // This function uses a SPIR-V pointer because HLSL does not allow
  // groupshared memory objects to be passed by reference. The pointer is a
  // workaround for that.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
                                uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
                                uint32_t stride) {
    return Load<MemoryAccessMaskNone, layout>(data, stride);
  }

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data[index] using the given memory layout, stride, and memory access
  // operands.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride) {
    return Load<MemoryAccessMaskNone, layout>(data, index, stride);
  }
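
  // Example (sketch): loading an accumulator tile back from an
  // RWStructuredBuffer at element `index` with a row stride of 16 elements.
  // AccMat stands for a CooperativeMatrixAccumulator instantiation (see the
  // aliases later in this header), buffer is an RWStructuredBuffer<float>, and
  // CooperativeMatrixLayoutRowMajorKHR is assumed to come from vk/spirv.h.
  //
  //   AccMat acc =
  //       AccMat::Load<CooperativeMatrixLayoutRowMajorKHR>(buffer, index, 16);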

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data[index] using the given memory layout, stride, and memory access
  // operands. `NonPrivatePointer` and `MakePointerVisible` with the QueueFamily
  // scope will be added to the memory access operands to make the memory
  // coherent.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix
  CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
               uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access operands
  // template argument.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix
  CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
               uint32_t stride) {
    return CoherentLoad<MemoryAccessMaskNone, layout>(data, index, stride);
  }

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data[index] using the given memory layout, stride, and memory access
  // operands. No memory access bits are added to the operands; since the
  // memory is read-only, none are needed.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride) {
    return Load<MemoryAccessMaskNone, layout>(data, index, stride);
  }
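
  // Example (sketch): loading the read-only "a" operand of a multiply-add from
  // a StructuredBuffer; no coherence operands are needed for read-only data.
  // MatA stands for a CooperativeMatrixA instantiation (see the aliases later
  // in this header); CooperativeMatrixLayoutRowMajorKHR is assumed to come
  // from vk/spirv.h.
  //
  //   StructuredBuffer<float> inputA;
  //
  //   MatA a = MatA::Load<CooperativeMatrixLayoutRowMajorKHR>(inputA, 0, 16);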

  // Constructs a cooperative matrix with all values initialized to v. Note
  // that all threads in the scope must have the same value for v.
  static CooperativeMatrix Splat(ComponentType v);

  // Returns the result of OpCooperativeMatrixLengthKHR on the current type.
  static uint32_t GetLength();

  // Functions to access the elements of the cooperative matrix. The index must
  // be less than GetLength().
  void Set(ComponentType value, uint32_t index);
  ComponentType Get(uint32_t index);
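
  // Example (sketch): applying a per-element operation in registers, where acc
  // is an AccMat instance as in the earlier sketches. Each invocation only
  // sees its own slice of the matrix, so the loop bound is GetLength(), not
  // rows * columns.
  //
  //   for (uint32_t i = 0; i < AccMat::GetLength(); ++i) {
  //     acc.Set(max(acc.Get(i), 0.0), i);  // e.g. a ReLU on this thread's slice
  //   }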

  // Evaluates to true when ComponentType(0) - ComponentType(1) is negative,
  // which holds for signed integer (and floating-point) component types; for
  // unsigned integers the subtraction wraps around to a large positive value.
  static const bool hasSignedIntegerComponentType =
      (ComponentType(0) - ComponentType(1) < ComponentType(0));

  // clang-format off
  using SpirvMatrixType = vk::SpirvOpaqueType<
      /* OpTypeCooperativeMatrixKHR */ 4456, ComponentType,
      vk::integral_constant<uint, scope>, vk::integral_constant<uint, rows>,
      vk::integral_constant<uint, columns>, vk::integral_constant<uint, use> >;

  [[vk::ext_extension("SPV_KHR_cooperative_matrix")]]
  [[vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)]]
  [[vk::ext_capability(/* VulkanMemoryModel */ 5345)]]
  SpirvMatrixType _matrix;
  // clang-format on
};
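
// Example (sketch): a concrete instantiation. The template arguments map
// directly onto the OpTypeCooperativeMatrixKHR operands (component type,
// scope, rows, columns, use). CooperativeMatrixAccumulator is the alias
// defined just below; vk::ScopeSubgroup is assumed to come from vk/spirv.h.
//
//   using AccMat =
//       vk::khr::CooperativeMatrixAccumulator<float, vk::ScopeSubgroup, 16, 16>;
//   AccMat acc = AccMat::Splat(0.0);  // zero-initialized accumulator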

// Cooperative matrix that can be used in the "a" position of a multiply-add
// instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixA =
    CooperativeMatrix<ComponentType, scope, rows, columns,
                      CooperativeMatrixUseMatrixAKHR>;

// Cooperative matrix that can be used in the "b" position of a multiply-add
// instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixB =
    CooperativeMatrix<ComponentType, scope, rows, columns,
                      CooperativeMatrixUseMatrixBKHR>;

// Cooperative matrix that can be used in the "r" and "c" positions of a
// multiply-add instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixAccumulator =
    CooperativeMatrix<ComponentType, scope, rows, columns,
                      CooperativeMatrixUseMatrixAccumulatorKHR>;

// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
// c. The cooperative matrix operands are inferred, with the
// SaturatingAccumulationKHR bit not set.
template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
cooperativeMatrixMultiplyAdd(
    CooperativeMatrixA<ComponentType, scope, rows, K> a,
    CooperativeMatrixB<ComponentType, scope, K, columns> b,
    CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);

// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
// c. The cooperative matrix operands are inferred, with the
// SaturatingAccumulationKHR bit set.
template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
cooperativeMatrixSaturatingMultiplyAdd(
    CooperativeMatrixA<ComponentType, scope, rows, K> a,
    CooperativeMatrixB<ComponentType, scope, K, columns> b,
    CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);
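
// Example (sketch): a single 16x16x16 tile multiply-accumulate, r = a * b + c,
// in a compute shader. The buffer declarations, the numthreads value, the
// scope enumerant vk::ScopeSubgroup, and the layout enumerant
// CooperativeMatrixLayoutRowMajorKHR are assumptions of this sketch; only the
// Load, Store, Splat, and cooperativeMatrixMultiplyAdd calls come from this
// header.
//
//   StructuredBuffer<float> bufA;
//   StructuredBuffer<float> bufB;
//   RWStructuredBuffer<float> bufC;
//
//   using MatA = vk::khr::CooperativeMatrixA<float, vk::ScopeSubgroup, 16, 16>;
//   using MatB = vk::khr::CooperativeMatrixB<float, vk::ScopeSubgroup, 16, 16>;
//   using MatC =
//       vk::khr::CooperativeMatrixAccumulator<float, vk::ScopeSubgroup, 16, 16>;
//
//   [numthreads(32, 1, 1)]
//   void main() {
//     MatA a = MatA::Load<CooperativeMatrixLayoutRowMajorKHR>(bufA, 0, 16);
//     MatB b = MatB::Load<CooperativeMatrixLayoutRowMajorKHR>(bufB, 0, 16);
//     MatC c = MatC::Splat(0.0);
//     c = vk::khr::cooperativeMatrixMultiplyAdd(a, b, c);
//     c.Store<CooperativeMatrixLayoutRowMajorKHR>(bufC, 0, 16);
//   }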

} // namespace khr
} // namespace vk

#include "cooperative_matrix.impl"
#endif // _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_