aboutsummaryrefslogtreecommitdiff
path: root/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h')
-rw-r--r--contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h198
1 files changed, 198 insertions, 0 deletions
diff --git a/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h b/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h
new file mode 100644
index 0000000..4f5e620
--- /dev/null
+++ b/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h
@@ -0,0 +1,198 @@
1// Header for linear algebra APIs.
2
3#if __spirv__
4#error "Cooperative vectors not (yet) supported for SPIRV"
5#endif
6
7#if ((__SHADER_TARGET_MAJOR > 6) || \
8 (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 9)) && \
9 (__HLSL_VERSION >= 2021)
10
11namespace dx {
12namespace linalg {
13
// Element-type tags used as non-type template arguments throughout this
// header and forwarded unchanged to the __builtin_* intrinsics.
//
// This is deliberately a plain (unscoped) enum. Declaring it as `enum class`
// produces:
//   error: non-type template argument of type 'dx::linalg::DataType' is not
//   an integral constant expression
//
// Each enumerator's numeric value corresponds to the ComponentType named in
// the comment above it.
enum DataType {
  // ComponentType::I16
  DATA_TYPE_SINT16 = 2,
  // ComponentType::U16
  DATA_TYPE_UINT16 = 3,
  // ComponentType::I32
  DATA_TYPE_SINT32 = 4,
  // ComponentType::U32
  DATA_TYPE_UINT32 = 5,
  // ComponentType::F16
  DATA_TYPE_FLOAT16 = 8,
  // ComponentType::F32
  DATA_TYPE_FLOAT32 = 9,
  // ComponentType::PackedS8x32
  DATA_TYPE_SINT8_T4_PACKED = 17,
  // ComponentType::PackedU8x32
  DATA_TYPE_UINT8_T4_PACKED = 18,
  // ComponentType::U8
  DATA_TYPE_UINT8 = 19,
  // ComponentType::I8
  DATA_TYPE_SINT8 = 20,
  // ComponentType::F8_E4M3 (1 sign, 4 exp, 3 mantissa bits)
  DATA_TYPE_FLOAT8_E4M3 = 21,
  // ComponentType::F8_E5M2 (1 sign, 5 exp, 2 mantissa bits)
  DATA_TYPE_FLOAT8_E5M2 = 22,
};
34
// Layout tag for a matrix held in a buffer. Used as a non-type template
// argument of the matrix reference types and forwarded unchanged to the
// __builtin_* intrinsics. Kept as a plain (unscoped) enum, like DataType.
enum MatrixLayout {
  MATRIX_LAYOUT_ROW_MAJOR = 0,
  MATRIX_LAYOUT_COLUMN_MAJOR = 1,
  MATRIX_LAYOUT_MUL_OPTIMAL = 2,
  MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3
};
41
//
// Signedness helper
//
// details::IsUnsigned<T>::Value is the compile-time flag that Mul/MulAdd pass
// to the matrix-vector builtins to describe the element type of the output
// and input vectors.
//
namespace details {

// Primary template intentionally has no `Value` member, so using an element
// type that is not listed below fails to compile.
template <typename T> struct IsUnsigned {};

// NOTE(review): int8_t4_packed is flagged unsigned exactly like
// uint8_t4_packed. Presumably this is because packed vectors are carried as
// unsigned 32-bit values -- confirm against what the builtins expect.
template <> struct IsUnsigned<uint8_t4_packed> {
  static const bool Value = true;
};
template <> struct IsUnsigned<int8_t4_packed> {
  static const bool Value = true;
};

// 32-bit scalar element types.
template <> struct IsUnsigned<uint32_t> {
  static const bool Value = true;
};
template <> struct IsUnsigned<int32_t> {
  static const bool Value = false;
};
template <> struct IsUnsigned<float32_t> {
  static const bool Value = false;
};

#ifdef __HLSL_ENABLE_16_BIT
// 16-bit element types, only specialized when __HLSL_ENABLE_16_BIT is defined.
template <> struct IsUnsigned<uint16_t> {
  static const bool Value = true;
};
template <> struct IsUnsigned<int16_t> {
  static const bool Value = false;
};
template <> struct IsUnsigned<float16_t> {
  static const bool Value = false;
};
#else  // !__HLSL_ENABLE_16_BIT
// Without __HLSL_ENABLE_16_BIT only `half` gets a specialization here.
template <> struct IsUnsigned<half> {
  static const bool Value = false;
};
#endif // __HLSL_ENABLE_16_BIT

} // namespace details
71
//
// (RW)MatrixRef
//
// Handle to a matrix that lives in a byte-address buffer. The element type
// (DT), dimensions (M, K), layout (ML) and transpose flag are part of the
// type; the buffer, the starting offset and the stride are ordinary members.
// Mul/MulAdd return a vector of M elements for a matrix declared with
// dimensions M x K.
//
// NOTE(review): StartOffset and Stride are presumably byte quantities, since
// Buffer is a (RW)ByteAddressBuffer -- confirm against the builtin contract.
//

template <typename BufferTy, DataType DT, uint M, uint K, MatrixLayout ML,
          bool Transpose>
struct MatrixRefImpl {
  BufferTy Buffer;  // Buffer object holding the matrix data.
  uint StartOffset; // Where the matrix begins inside Buffer.
  uint Stride;      // Stride value forwarded to the builtins.
};

// Matrix reference backed by a read-only ByteAddressBuffer.
template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
using MatrixRef = MatrixRefImpl<ByteAddressBuffer, DT, M, K, ML, Transpose>;

// Matrix reference backed by a RWByteAddressBuffer.
template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
using RWMatrixRef = MatrixRefImpl<RWByteAddressBuffer, DT, M, K, ML, Transpose>;
89
//
// (RW)VectorRef
//
// Handle to a vector that lives in a byte-address buffer. The element type
// (DT) is part of the type; the buffer and the starting offset are members.
// MulAdd takes one of these as its bias vector.
//

template <typename BufferTy, DataType DT> struct VectorRefImpl {
  BufferTy Buffer;  // Buffer object holding the vector data.
  uint StartOffset; // Where the vector begins inside Buffer.
};

// Vector reference backed by a read-only ByteAddressBuffer.
template <DataType DT> using VectorRef = VectorRefImpl<ByteAddressBuffer, DT>;

// Vector reference backed by a RWByteAddressBuffer.
template <DataType DT>
using RWVectorRef = VectorRefImpl<RWByteAddressBuffer, DT>;
103
//
// Vector
//
// Wraps a native vector<T, N> together with a compile-time DataType tag. Mul
// and MulAdd forward that tag (DT) to the builtin alongside the vector data.
//

template <typename T, int N, DataType DT> struct InterpretedVector {
  vector<T, N> Data; // The wrapped vector value.
};
111
// Builds an InterpretedVector<T, N, DT> holding a copy of Vec.
//
// DT is the leading template parameter, so callers spell it explicitly while
// T and N are deduced from the argument, e.g.
//   MakeInterpretedVector<DATA_TYPE_FLOAT16>(v)
template <DataType DT, typename T, int N>
InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
  InterpretedVector<T, N, DT> Wrapped;
  Wrapped.Data = Vec;
  return Wrapped;
}
117
//
// Mul
//
// Forwards a matrix reference and an interpreted input vector to
// __builtin_MatVecMul and returns the vector the builtin writes.
//
// OutputElTy appears only in the return type, so it cannot be deduced and
// must be given explicitly as the first template argument. All remaining
// template parameters are deduced from the two function arguments.
//
// Returns a vector<OutputElTy, MatrixM>. The signedness flags handed to the
// builtin come from details::IsUnsigned, so OutputElTy and InputElTy must be
// types that helper is specialized for.
//

template <typename OutputElTy, typename InputElTy, int InputElCount,
          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
          uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout,
          bool MatrixTranspose>
vector<OutputElTy, MatrixM>
Mul(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout,
                  MatrixTranspose>
        Matrix,
    InterpretedVector<InputElTy, InputElCount, InputDT> InputVector) {
  vector<OutputElTy, MatrixM> Result;

  __builtin_MatVecMul(
      // Destination vector and its signedness.
      /*out*/ Result, details::IsUnsigned<OutputElTy>::Value,
      // Input vector, its signedness and its interpretation.
      InputVector.Data, details::IsUnsigned<InputElTy>::Value, InputDT,
      // Matrix location and compile-time description.
      Matrix.Buffer, Matrix.StartOffset, MatrixDT, MatrixM, MatrixK,
      MatrixLayout, MatrixTranspose, Matrix.Stride);

  return Result;
}
142
//
// MulAdd
//
// Same call shape as Mul, plus a bias vector reference: forwards everything
// to __builtin_MatVecMulAdd and returns the vector the builtin writes.
//
// OutputElTy appears only in the return type, so it cannot be deduced and
// must be given explicitly as the first template argument. All remaining
// template parameters are deduced from the three function arguments.
//
// Returns a vector<OutputElTy, MatrixM>. The signedness flags handed to the
// builtin come from details::IsUnsigned, so OutputElTy and InputElTy must be
// types that helper is specialized for.
//

template <typename OutputElTy, typename InputElTy, int InputElCount,
          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
          uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout,
          bool MatrixTranspose, typename BiasVectorBufferTy,
          DataType BiasVectorDT>
vector<OutputElTy, MatrixM>
MulAdd(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout,
                     MatrixTranspose>
           Matrix,
       InterpretedVector<InputElTy, InputElCount, InputDT> InputVector,
       VectorRefImpl<BiasVectorBufferTy, BiasVectorDT> BiasVector) {
  vector<OutputElTy, MatrixM> Result;

  __builtin_MatVecMulAdd(
      // Destination vector and its signedness.
      /*out*/ Result, details::IsUnsigned<OutputElTy>::Value,
      // Input vector, its signedness and its interpretation.
      InputVector.Data, details::IsUnsigned<InputElTy>::Value, InputDT,
      // Matrix location and compile-time description.
      Matrix.Buffer, Matrix.StartOffset, MatrixDT, MatrixM, MatrixK,
      MatrixLayout, MatrixTranspose, Matrix.Stride,
      // Bias vector location and element type.
      BiasVector.Buffer, BiasVector.StartOffset, BiasVectorDT);

  return Result;
}
170
//
// OuterProductAccumulate
//
// Forwards two vectors and a writable, non-transposed matrix reference to
// __builtin_OuterProductAccumulate. The matrix type ties its dimensions to
// the vector lengths: MatrixM is the length of InputVector1 and MatrixN the
// length of InputVector2. Returns nothing; Matrix is a RWMatrixRef, so the
// builtin presumably accumulates into Matrix.Buffer.
//

template <typename ElTy, int MatrixM, int MatrixN, DataType MatrixDT,
          MatrixLayout MatrixLayout>
void OuterProductAccumulate(
    vector<ElTy, MatrixM> InputVector1, vector<ElTy, MatrixN> InputVector2,
    RWMatrixRef<MatrixDT, MatrixM, MatrixN, MatrixLayout, false> Matrix) {
  __builtin_OuterProductAccumulate(
      // The two operand vectors.
      InputVector1, InputVector2,
      // Destination matrix location and description.
      Matrix.Buffer, Matrix.StartOffset, MatrixDT, MatrixLayout,
      Matrix.Stride);
}
184
//
// VectorAccumulate
//
// Thin wrapper that hands InputVector, Buffer and Offset straight to
// __builtin_VectorAccumulate. Returns nothing; Buffer is a
// RWByteAddressBuffer, so the builtin presumably accumulates the vector into
// Buffer starting at Offset.
//

template <typename ElTy, int ElCount>
void VectorAccumulate(vector<ElTy, ElCount> InputVector,
                      RWByteAddressBuffer Buffer, uint Offset) {
  __builtin_VectorAccumulate(InputVector, Buffer, Offset);
}
194
195} // namespace linalg
196} // namespace dx
197
198#endif // SM 6.9 check and HV version check