aboutsummaryrefslogtreecommitdiff
path: root/contrib/dxc_2025_07_14/inc/hlsl/vk/khr
diff options
context:
space:
mode:
author3gg <3gg@shellblade.net>2025-12-02 16:39:36 -0800
committer3gg <3gg@shellblade.net>2025-12-02 16:39:36 -0800
commit6c8ae19be66cee247980a48e736a4e05d14de179 (patch)
treed860767907bf0cbe17ec66422e11bea700cf56d9 /contrib/dxc_2025_07_14/inc/hlsl/vk/khr
parent8f594c8ebd11f0e5f8a0c6369c3fe7383d250cbe (diff)
Immediate-mode renderer, triangle demo, shader compilation in cmake, Agility SDKHEADmain
Diffstat (limited to 'contrib/dxc_2025_07_14/inc/hlsl/vk/khr')
-rw-r--r--contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h275
-rw-r--r--contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl377
2 files changed, 652 insertions, 0 deletions
diff --git a/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h
new file mode 100644
index 0000000..a53ab4c
--- /dev/null
+++ b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.h
@@ -0,0 +1,275 @@
1// Copyright (c) 2024 Google LLC
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7#ifndef _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
8#define _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
9
10#if __SPIRV_MAJOR_VERSION__ == 1 && __SPIRV_MINOR_VERSION__ < 6
11#error "CooperativeMatrix requires a minimum of SPIR-V 1.6"
12#endif
13
14#include "vk/spirv.h"
15
16namespace vk {
17namespace khr {
18
19// The base cooperative matrix class. The template arguments correspond to the
20// operands in the OpTypeCooperativeMatrixKHR instruction.
21template <typename ComponentType, Scope scope, uint rows, uint columns,
22 CooperativeMatrixUse use>
23class CooperativeMatrix {
24 template <class NewComponentType>
25 CooperativeMatrix<NewComponentType, scope, rows, columns, use> cast();
26
27 // Apply OpSNegate or OFNegate, depending on ComponentType, in a element by
28 // element manner.
29 CooperativeMatrix negate();
30
31 // Apply OpIAdd or OFAdd, depending on ComponentType, in a element by element
32 // manner.
33 CooperativeMatrix operator+(CooperativeMatrix other);
34
35 // Apply OpISub or OFSub, depending on ComponentType, in a element by element
36 // manner.
37 CooperativeMatrix operator-(CooperativeMatrix other);
38
39 // Apply OpIMul or OFMul, depending on ComponentType, in a element by element
40 // manner.
41 CooperativeMatrix operator*(CooperativeMatrix other);
42
43 // Apply OpSDiv, OpUDiv or OFDiv, depending on ComponentType, in a element by
44 // element manner.
45 CooperativeMatrix operator/(CooperativeMatrix other);
46
47 // Apply OpMatrixTimesScalar in a element by element manner.
48 CooperativeMatrix operator*(ComponentType scalar);
49
50 // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
51 // data using the given memory layout, stride, and memory access operands.
52 // `NonPrivatePointer` and `MakePointerAvailable` with the workgroup scope
53 // will be added to the memory access operands to make the memory coherent.
54 //
55 // This function uses a SPIR-V pointer because HLSL does not allow groupshared
56 // memory object to be passed by reference. The pointer is a hack to get
57 // around that.
58 //
59 // The layout and stride will be passed to the SPIR-V instruction as is. The
60 // precise meaning can be found in the specification for
61 // SPV_KHR_cooperative_matrix.
62 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
63 class Type>
64 void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride);
65
66 // Same as above, but uses MemoryAccessMaskNone for the memory access
67 // operands.
68 template <CooperativeMatrixLayout layout, class Type>
69 void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride) {
70 Store<MemoryAccessMaskNone, layout>(data, stride);
71 }
72
73 // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
74 // data[index] using the given memory layout, stride, and memory access
75 // operands. The layout and stride will be passed to the SPIR-V instruction as
76 // is. The precise meaning can be found in the specification for
77 // SPV_KHR_cooperative_matrix.
78 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
79 class Type>
80 void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride);
81
82 // Same as above, but uses MemoryAccessMaskNone for the memory access
83 // operands.
84 template <CooperativeMatrixLayout layout, class Type>
85 void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
86 Store<MemoryAccessMaskNone, layout>(data, index, stride);
87 }
88
89 // Store the cooperative matrix using OpCooperativeMatrixStoreKHR to
90 // data[index] using the given memory layout, stride, and memory access
91 // operands. `NonPrivatePointer` and `MakePointerAvailable` with the
92 // QueueFamily scope will be added to the memory access operands to make the
93 // memory coherent.
94 //
95 // The layout and stride will be passed to the SPIR-V instruction as is. The
96 // precise meaning can be found in the specification for
97 // SPV_KHR_cooperative_matrix.
98 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
99 class Type>
100 void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
101 uint32_t index, uint32_t stride);
102
103 // Same as above, but uses MemoryAccessMaskNone for the memory access operands
104 // template argument.
105 template <CooperativeMatrixLayout layout, class Type>
106 void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
107 uint32_t index, uint32_t stride) {
108 CoherentStore<MemoryAccessMaskNone, layout>(data, index, stride);
109 }
110
111 // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
112 // data using the given memory layout, stride, and memory access operands.
113 // `NonPrivatePointer` and `MakePointerVisible` with the workgroup scope
114 // will be added to the memory access operands to make the memory coherent.
115 //
116 // This function uses a SPIR-V pointer because HLSL does not allow groupshared
117 // memory object to be passed by reference. The pointer is a hack to get
118 // around that.
119 //
120 // The layout and stride will be passed to the SPIR-V instruction as is. The
121 // precise meaning can be found in the specification for
122 // SPV_KHR_cooperative_matrix.
123 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
124 class Type>
125 static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
126 uint32_t stride);
127
128 // Same as above, but uses MemoryAccessMaskNone for the memory access
129 // operands.
130 template <CooperativeMatrixLayout layout, class Type>
131 static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
132 uint32_t stride) {
133 return Load<MemoryAccessMaskNone, layout>(data, stride);
134 }
135
136 // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
137 // data[index] using the given memory layout, stride, and memory access
138 // operands.
139 //
140 // The layout and stride will be passed to the SPIR-V instruction as is. The
141 // precise meaning can be found in the specification for
142 // SPV_KHR_cooperative_matrix.
143 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
144 class Type>
145 static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
146 uint32_t stride);
147
148 // Same as above, but uses MemoryAccessMaskNone for the memory access
149 // operands.
150 template <CooperativeMatrixLayout layout, class Type>
151 static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
152 uint32_t stride) {
153 return Load<MemoryAccessMaskNone, layout>(data, index, stride);
154 }
155
156 // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
157 // data[index] using the given memory layout, stride, and memory access
158 // operands. `NonPrivatePointer` and `MakePointerVisible` with the QueueFamily
159 // scope will be added to the memory access operands to make the memory
160 // coherent.
161 //
162 //
163 // The layout and stride will be passed to the SPIR-V instruction as is. The
164 // precise meaning can be found in the specification for
165 // SPV_KHR_cooperative_matrix.
166 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
167 class Type>
168 static CooperativeMatrix
169 CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
170 uint32_t stride);
171
172 // Same as above, but uses MemoryAccessMaskNone for the memory access operands
173 // template argument.
174 template <CooperativeMatrixLayout layout, class Type>
175 static CooperativeMatrix
176 CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
177 uint32_t stride) {
178 return CoherentLoad<MemoryAccessMaskNone, layout>(data, index, stride);
179 }
180
181 // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
182 // data[index] using the given memory layout, stride, and memory access
183 // operands. No memory access bits are added to the operands. Since the memory
184 // is readonly, there should be no need.
185 //
186 // The layout and stride will be passed to the SPIR-V instruction as is. The
187 // precise meaning can be found in the specification for
188 // SPV_KHR_cooperative_matrix.
189 template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
190 class Type>
191 static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
192 uint32_t stride);
193
194 // Same as above, but uses MemoryAccessMaskNone for the memory access
195 // operands.
196 template <CooperativeMatrixLayout layout, class Type>
197 static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
198 uint32_t stride) {
199 return Load<MemoryAccessMaskNone, layout>(data, index, stride);
200 }
201
202 // Constructs a cooperative matrix with all values initialized to v. Note that
203 // all threads in scope must have the same value for v.
204 static CooperativeMatrix Splat(ComponentType v);
205
206 // Returns the result of OpCooperativeMatrixLengthKHR on the current type.
207 static uint32_t GetLength();
208
209 // Functions to access the elements of the cooperative matrix. The index must
210 // be less than GetLength().
211 void Set(ComponentType value, uint32_t index);
212 ComponentType Get(uint32_t index);
213
214 static const bool hasSignedIntegerComponentType =
215 (ComponentType(0) - ComponentType(1) < ComponentType(0));
216
217 // clang-format off
218 using SpirvMatrixType = vk::SpirvOpaqueType<
219 /* OpTypeCooperativeMatrixKHR */ 4456, ComponentType,
220 vk::integral_constant<uint, scope>, vk::integral_constant<uint, rows>,
221 vk::integral_constant<uint, columns>, vk::integral_constant<uint, use> >;
222
223 [[vk::ext_extension("SPV_KHR_cooperative_matrix")]]
224 [[vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)]]
225 [[vk::ext_capability(/* VulkanMemoryModel */ 5345)]]
226 SpirvMatrixType _matrix;
227 // clang-format on
228};
229
230// Cooperative matrix that can be used in the "a" position of a multiply add
231// instruction (r = (a * b) + c).
232template <typename ComponentType, Scope scope, uint rows, uint columns>
233using CooperativeMatrixA =
234 CooperativeMatrix<ComponentType, scope, rows, columns,
235 CooperativeMatrixUseMatrixAKHR>;
236
237// Cooperative matrix that can be used in the "b" position of a multiply add
238// instruction (r = (a * b) + c).
239template <typename ComponentType, Scope scope, uint rows, uint columns>
240using CooperativeMatrixB =
241 CooperativeMatrix<ComponentType, scope, rows, columns,
242 CooperativeMatrixUseMatrixBKHR>;
243
244// Cooperative matrix that can be used in the "r" and "c" position of a multiply
245// add instruction (r = (a * b) + c).
246template <typename ComponentType, Scope scope, uint rows, uint columns>
247using CooperativeMatrixAccumulator =
248 CooperativeMatrix<ComponentType, scope, rows, columns,
249 CooperativeMatrixUseMatrixAccumulatorKHR>;
250
251// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
252// c. The cooperative matrix operands are inferred, with the
253// SaturatingAccumulationKHR bit not set.
254template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
255CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
256cooperativeMatrixMultiplyAdd(
257 CooperativeMatrixA<ComponentType, scope, rows, K> a,
258 CooperativeMatrixB<ComponentType, scope, K, columns> b,
259 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);
260
261// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
262// c. The cooperative matrix operands are inferred, with the
263// SaturatingAccumulationKHR bit set.
264template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
265CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
266cooperativeMatrixSaturatingMultiplyAdd(
267 CooperativeMatrixA<ComponentType, scope, rows, K> a,
268 CooperativeMatrixB<ComponentType, scope, K, columns> b,
269 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);
270
271} // namespace khr
272} // namespace vk
273
274#include "cooperative_matrix.impl"
275#endif // _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
diff --git a/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl
new file mode 100644
index 0000000..2acae8e
--- /dev/null
+++ b/contrib/dxc_2025_07_14/inc/hlsl/vk/khr/cooperative_matrix.impl
@@ -0,0 +1,377 @@
1// Copyright (c) 2024 Google LLC
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7#include "vk/opcode_selector.h"
8
9template <typename ResultType, typename ComponentType>
10[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType
11__builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b);
12
13template <typename ComponentType, vk::Scope scope, uint rows, uint columns,
14 vk::CooperativeMatrixUse use>
15[[vk::ext_instruction(/* OpCompositeExtract */ 81)]] ComponentType
16__builtin_spv_ExtractFromCooperativeMatrix(
17 typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
18 use>::SpirvMatrixType matrix,
19 uint32_t index);
20
21template <typename CoopMatrixType, typename ComponentType>
22[[vk::ext_instruction(/* OpCompositeConstruct */ 80)]] CoopMatrixType
23__builtin_spv_ConstructCooperativeMatrix(ComponentType value);
24
25template <class ResultPointerType, class BaseType>
26[[vk::ext_instruction(/* OpAccessChain */ 65)]] ResultPointerType
27__builtin_spv_AccessChain([[vk::ext_reference]] BaseType base, uint32_t index);
28
29template <class ObjectType, class PointerType>
30[[vk::ext_instruction(/* OpLoad */ 61)]] ObjectType
31__builtin_spv_LoadPointer(PointerType base);
32
33template <class PointerType, class ObjectType>
34[[vk::ext_instruction(/* OpLoad */ 62)]] void
35__builtin_spv_StorePointer(PointerType base, ObjectType object);
36
37template <typename ComponentType, vk::Scope scope, uint rows, uint columns,
38 vk::CooperativeMatrixUse use>
39[[vk::ext_instruction(/* OpCompositeInsert */ 82)]]
40typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
41 use>::SpirvMatrixType
42__builtin_spv_InsertIntoCooperativeMatrix(
43 ComponentType value,
44 typename vk::khr::CooperativeMatrix<ComponentType, scope, rows, columns,
45 use>::SpirvMatrixType matrix,
46 uint32_t index);
47
48// Define the load and store instructions
49template <typename ResultType, typename PointerType>
50[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
51__builtin_spv_CooperativeMatrixLoadKHR(
52 [[vk::ext_reference]] PointerType pointer,
53 vk::CooperativeMatrixLayout memory_layout, uint stride,
54 [[vk::ext_literal]] uint32_t memory_operand);
55
56template <typename ResultType, typename PointerType>
57[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
58__builtin_spv_CooperativeMatrixLoadKHR(
59 [[vk::ext_reference]] PointerType pointer,
60 vk::CooperativeMatrixLayout memory_layout, uint stride,
61 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
62
63template <typename ResultType, typename PointerType>
64[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
65__builtin_spv_CooperativeMatrixWorkgroupLoadKHR(
66 vk::WorkgroupSpirvPointer<PointerType> pointer,
67 vk::CooperativeMatrixLayout memory_layout, uint stride,
68 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
69
70template <typename ObjectType, typename PointerType>
71[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
72__builtin_spv_CooperativeMatrixStoreKHR(
73 [[vk::ext_reference]] PointerType pointer, ObjectType object,
74 vk::CooperativeMatrixLayout memory_layout, uint stride,
75 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
76
77template <typename ObjectType, typename PointerType>
78[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
79__builtin_spv_CooperativeMatrixStoreKHR(
80 [[vk::ext_reference]] PointerType pointer, ObjectType object,
81 vk::CooperativeMatrixLayout memory_layout, uint stride,
82 [[vk::ext_literal]] uint32_t memory_operand);
83
84template <typename ObjectType, typename PointerType>
85[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
86__builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
87 vk::WorkgroupSpirvPointer<PointerType> pointer, ObjectType object,
88 vk::CooperativeMatrixLayout memory_layout, uint stride,
89 [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);
90
91// We cannot define `OpCooperativeMatrixLengthKHR` using ext_instruction because
92// one of the operands is a type id. This builtin will have specific code in the
93// compiler to expand it.
94template <class MatrixType> uint __builtin_spv_CooperativeMatrixLengthKHR();
95
96// Arithmetic Instructions
97template <typename ResultType, typename MatrixTypeA, typename MatrixTypeB,
98 typename MatrixTypeC>
99[[vk::ext_instruction(/* OpCooperativeMatrixMulAddKHR */ 4459)]] ResultType
100__builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b,
101 MatrixTypeC c,
102 [[vk::ext_literal]] int operands);
103namespace vk {
104namespace khr {
105
106template <class ComponentType, Scope scope, uint rows, uint columns,
107 CooperativeMatrixUse use>
108template <class NewComponentType>
109CooperativeMatrix<NewComponentType, scope, rows, columns, use>
110CooperativeMatrix<ComponentType, scope, rows, columns, use>::cast() {
111 using ResultType =
112 CooperativeMatrix<NewComponentType, scope, rows, columns, use>;
113 ResultType result;
114 result._matrix = util::ConversionSelector<ComponentType, NewComponentType>::
115 template Convert<typename ResultType::SpirvMatrixType>(_matrix);
116 return result;
117}
118
119template <class ComponentType, Scope scope, uint rows, uint columns,
120 CooperativeMatrixUse use>
121CooperativeMatrix<ComponentType, scope, rows, columns, use>
122CooperativeMatrix<ComponentType, scope, rows, columns, use>::negate() {
123 CooperativeMatrix result;
124 result._matrix = util::ArithmeticSelector<ComponentType>::Negate(_matrix);
125 return result;
126}
127
128template <class ComponentType, Scope scope, uint rows, uint columns,
129 CooperativeMatrixUse use>
130CooperativeMatrix<ComponentType, scope, rows, columns, use>
131CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator+(
132 CooperativeMatrix other) {
133 CooperativeMatrix result;
134 result._matrix =
135 util::ArithmeticSelector<ComponentType>::Add(_matrix, other._matrix);
136 return result;
137}
138
139template <class ComponentType, Scope scope, uint rows, uint columns,
140 CooperativeMatrixUse use>
141CooperativeMatrix<ComponentType, scope, rows, columns, use>
142CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator-(
143 CooperativeMatrix other) {
144 CooperativeMatrix result;
145 result._matrix =
146 util::ArithmeticSelector<ComponentType>::Sub(_matrix, other._matrix);
147 return result;
148}
149
150template <class ComponentType, Scope scope, uint rows, uint columns,
151 CooperativeMatrixUse use>
152CooperativeMatrix<ComponentType, scope, rows, columns, use>
153CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
154 CooperativeMatrix other) {
155 CooperativeMatrix result;
156 result._matrix =
157 util::ArithmeticSelector<ComponentType>::Mul(_matrix, other._matrix);
158 return result;
159}
160
161template <class ComponentType, Scope scope, uint rows, uint columns,
162 CooperativeMatrixUse use>
163CooperativeMatrix<ComponentType, scope, rows, columns, use>
164CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator/(
165 CooperativeMatrix other) {
166 CooperativeMatrix result;
167 result._matrix =
168 util::ArithmeticSelector<ComponentType>::Div(_matrix, other._matrix);
169 return result;
170}
171
172template <class ComponentType, Scope scope, uint rows, uint columns,
173 CooperativeMatrixUse use>
174CooperativeMatrix<ComponentType, scope, rows, columns, use>
175CooperativeMatrix<ComponentType, scope, rows, columns, use>::operator*(
176 ComponentType scalar) {
177 CooperativeMatrix result;
178 result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar);
179 return result;
180}
181
182template <class ComponentType, Scope scope, uint rows, uint columns,
183 CooperativeMatrixUse use>
184template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
185 class Type>
186void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Store(
187 WorkgroupSpirvPointer<Type> data, uint32_t stride) {
188 __builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
189 data, _matrix, layout, stride,
190 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
191 MemoryAccessMakePointerAvailableMask,
192 ScopeWorkgroup);
193}
194
195template <class ComponentType, Scope scope, uint rows, uint columns,
196 CooperativeMatrixUse use>
197template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
198 class Type>
199void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Store(
200 RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
201 __builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix, layout, stride,
202 memoryAccessOperands);
203}
204
205template <class ComponentType, Scope scope, uint rows, uint columns,
206 CooperativeMatrixUse use>
207template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
208 class Type>
209void CooperativeMatrix<ComponentType, scope, rows, columns, use>::CoherentStore(
210 globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
211 uint32_t stride) {
212 __builtin_spv_CooperativeMatrixStoreKHR(
213 data[index], _matrix, layout, stride,
214 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
215 MemoryAccessMakePointerAvailableMask,
216 ScopeQueueFamily);
217}
218
219template <class ComponentType, Scope scope, uint rows, uint columns,
220 CooperativeMatrixUse use>
221template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
222 class Type>
223CooperativeMatrix<ComponentType, scope, rows, columns, use>
224CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
225 vk::WorkgroupSpirvPointer<Type> buffer, uint32_t stride) {
226 CooperativeMatrix result;
227 result._matrix =
228 __builtin_spv_CooperativeMatrixWorkgroupLoadKHR<SpirvMatrixType>(
229 buffer, layout, stride,
230 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
231 MemoryAccessMakePointerVisibleMask,
232 ScopeWorkgroup);
233 return result;
234}
235
236template <class ComponentType, Scope scope, uint rows, uint columns,
237 CooperativeMatrixUse use>
238template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
239 class Type>
240CooperativeMatrix<ComponentType, scope, rows, columns, use>
241CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
242 RWStructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
243 CooperativeMatrix result;
244 result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
245 buffer[index], layout, stride, memoryAccessOperands);
246 return result;
247}
248
249template <class ComponentType, Scope scope, uint rows, uint columns,
250 CooperativeMatrixUse use>
251template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
252 class Type>
253CooperativeMatrix<ComponentType, scope, rows, columns, use>
254CooperativeMatrix<ComponentType, scope, rows, columns, use>::CoherentLoad(
255 RWStructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
256 CooperativeMatrix result;
257 result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
258 buffer[index], layout, stride,
259 memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
260 MemoryAccessMakePointerVisibleMask,
261 ScopeQueueFamily);
262 return result;
263}
264
265template <class ComponentType, Scope scope, uint rows, uint columns,
266 CooperativeMatrixUse use>
267template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
268 class Type>
269CooperativeMatrix<ComponentType, scope, rows, columns, use>
270CooperativeMatrix<ComponentType, scope, rows, columns, use>::Load(
271 StructuredBuffer<Type> buffer, uint32_t index, uint32_t stride) {
272 CooperativeMatrix result;
273 result._matrix = __builtin_spv_CooperativeMatrixLoadKHR<SpirvMatrixType>(
274 buffer[index], layout, stride, MemoryAccessMaskNone);
275 return result;
276}
277
278template <class ComponentType, Scope scope, uint rows, uint columns,
279 CooperativeMatrixUse use>
280CooperativeMatrix<ComponentType, scope, rows, columns, use>
281CooperativeMatrix<ComponentType, scope, rows, columns, use>::Splat(
282 ComponentType v) {
283 CooperativeMatrix result;
284 result._matrix = __builtin_spv_ConstructCooperativeMatrix<SpirvMatrixType>(v);
285 return result;
286}
287
288template <class ComponentType, Scope scope, uint rows, uint columns,
289 CooperativeMatrixUse use>
290uint CooperativeMatrix<ComponentType, scope, rows, columns, use>::GetLength() {
291 return __builtin_spv_CooperativeMatrixLengthKHR<SpirvMatrixType>();
292}
293
294template <class ComponentType, Scope scope, uint rows, uint columns,
295 CooperativeMatrixUse use>
296ComponentType CooperativeMatrix<ComponentType, scope, rows, columns, use>::Get(
297 uint32_t index) {
298 // clang-format off
299 using ComponentPtr = vk::SpirvOpaqueType<
300 /* OpTypePointer */ 32,
301 /* function storage class */ vk::Literal<vk::integral_constant<uint, 7> >,
302 ComponentType>;
303 // clang-format on
304 ComponentPtr ptr = __builtin_spv_AccessChain<ComponentPtr>(_matrix, index);
305 return __builtin_spv_LoadPointer<ComponentType>(ptr);
306}
307
308template <class ComponentType, Scope scope, uint rows, uint columns,
309 CooperativeMatrixUse use>
310void CooperativeMatrix<ComponentType, scope, rows, columns, use>::Set(
311 ComponentType value, uint32_t index) {
312 // clang-format off
313 using ComponentPtr = vk::SpirvOpaqueType<
314 /* OpTypePointer */ 32,
315 /* function storage class */ vk::Literal<vk::integral_constant<uint, 7> >,
316 ComponentType>;
317 // clang-format on
318 ComponentPtr ptr = __builtin_spv_AccessChain<ComponentPtr>(_matrix, index);
319 return __builtin_spv_StorePointer(ptr, value);
320}
321
322template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
323CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
324cooperativeMatrixMultiplyAdd(
325 CooperativeMatrixA<ComponentType, scope, rows, K> a,
326 CooperativeMatrixB<ComponentType, scope, K, columns> b,
327 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c) {
328
329 const vk::CooperativeMatrixOperandsMask allSignedComponents =
330 vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
331 vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
332 vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
333 vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask;
334
335 const vk::CooperativeMatrixOperandsMask operands =
336 (vk::CooperativeMatrixOperandsMask)(
337 a.hasSignedIntegerComponentType
338 ? allSignedComponents
339 : vk::CooperativeMatrixOperandsMaskNone);
340
341 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
342 result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
343 typename CooperativeMatrixAccumulator<ComponentType, scope, rows,
344 columns>::SpirvMatrixType>(
345 a._matrix, b._matrix, c._matrix, operands);
346 return result;
347}
348
349template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
350CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
351cooperativeMatrixSaturatingMultiplyAdd(
352 CooperativeMatrixA<ComponentType, scope, rows, K> a,
353 CooperativeMatrixB<ComponentType, scope, K, columns> b,
354 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c) {
355
356 const vk::CooperativeMatrixOperandsMask allSignedComponents =
357 vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
358 vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
359 vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
360 vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask |
361 vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask;
362
363 const vk::CooperativeMatrixOperandsMask operands =
364 (vk::CooperativeMatrixOperandsMask)(
365 a.hasSignedIntegerComponentType
366 ? allSignedComponents
367 : vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask);
368 CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> result;
369 result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
370 typename CooperativeMatrixAccumulator<ComponentType, scope, rows,
371 columns>::SpirvMatrixType>(
372 a._matrix, b._matrix, c._matrix, operands);
373 return result;
374}
375
376} // namespace khr
377} // namespace vk