diff options
| author | 3gg <3gg@shellblade.net> | 2025-12-02 16:39:36 -0800 |
|---|---|---|
| committer | 3gg <3gg@shellblade.net> | 2025-12-02 16:39:36 -0800 |
| commit | 6c8ae19be66cee247980a48e736a4e05d14de179 (patch) | |
| tree | d860767907bf0cbe17ec66422e11bea700cf56d9 /contrib/dxc_2025_07_14/inc/hlsl/dx | |
| parent | 8f594c8ebd11f0e5f8a0c6369c3fe7383d250cbe (diff) | |
Diffstat (limited to 'contrib/dxc_2025_07_14/inc/hlsl/dx')
| -rw-r--r-- | contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h b/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h new file mode 100644 index 0000000..4f5e620 --- /dev/null +++ b/contrib/dxc_2025_07_14/inc/hlsl/dx/linalg.h | |||
| @@ -0,0 +1,198 @@ | |||
| 1 | // Header for linear algebra APIs. | ||
| 2 | |||
| 3 | #if __spirv__ | ||
| 4 | #error "Cooperative vectors not (yet) supported for SPIRV" | ||
| 5 | #endif | ||
| 6 | |||
| 7 | #if ((__SHADER_TARGET_MAJOR > 6) || \ | ||
| 8 | (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 9)) && \ | ||
| 9 | (__HLSL_VERSION >= 2021) | ||
| 10 | |||
| 11 | namespace dx { | ||
| 12 | namespace linalg { | ||
| 13 | |||
| 14 | // NOTE: can't be an enum class because we get this error: | ||
| 15 | // error: non-type template argument of type 'dx::linalg::DataType' is not | ||
| 16 | // an integral constant expression | ||
| 17 | // | ||
| 18 | enum DataType { | ||
| 19 | DATA_TYPE_SINT16 = 2, // ComponentType::I16 | ||
| 20 | DATA_TYPE_UINT16 = 3, // ComponentType::U16 | ||
| 21 | DATA_TYPE_SINT32 = 4, // ComponentType::I32 | ||
| 22 | DATA_TYPE_UINT32 = 5, // ComponentType::U32 | ||
| 23 | DATA_TYPE_FLOAT16 = 8, // ComponentType::F16 | ||
| 24 | DATA_TYPE_FLOAT32 = 9, // ComponentType::F32 | ||
| 25 | DATA_TYPE_SINT8_T4_PACKED = 17, // ComponentType::PackedS8x32 | ||
| 26 | DATA_TYPE_UINT8_T4_PACKED = 18, // ComponentType::PackedU8x32 | ||
| 27 | DATA_TYPE_UINT8 = 19, // ComponentType::U8 | ||
| 28 | DATA_TYPE_SINT8 = 20, // ComponentType::I8 | ||
| 29 | DATA_TYPE_FLOAT8_E4M3 = 21, // ComponentType::F8_E4M3 | ||
| 30 | // (1 sign, 4 exp, 3 mantissa bits) | ||
| 31 | DATA_TYPE_FLOAT8_E5M2 = 22, // ComponentType::F8_E5M2 | ||
| 32 | // (1 sign, 5 exp, 2 mantissa bits) | ||
| 33 | }; | ||
| 34 | |||
| 35 | enum MatrixLayout { | ||
| 36 | MATRIX_LAYOUT_ROW_MAJOR = 0, | ||
| 37 | MATRIX_LAYOUT_COLUMN_MAJOR = 1, | ||
| 38 | MATRIX_LAYOUT_MUL_OPTIMAL = 2, | ||
| 39 | MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3 | ||
| 40 | }; | ||
| 41 | |||
| 42 | // | ||
| 43 | // Helper for signedness | ||
| 44 | // | ||
| 45 | namespace details { | ||
| 46 | |||
| 47 | template <typename T> struct IsUnsigned {}; | ||
| 48 | |||
| 49 | #define _SPECIALIZE_ISUNSIGNED(type, value) \ | ||
| 50 | template <> struct IsUnsigned<type> { \ | ||
| 51 | static const bool Value = value; \ | ||
| 52 | } | ||
| 53 | |||
| 54 | _SPECIALIZE_ISUNSIGNED(uint8_t4_packed, true); | ||
| 55 | _SPECIALIZE_ISUNSIGNED(int8_t4_packed, true); | ||
| 56 | _SPECIALIZE_ISUNSIGNED(uint32_t, true); | ||
| 57 | _SPECIALIZE_ISUNSIGNED(int32_t, false); | ||
| 58 | _SPECIALIZE_ISUNSIGNED(float32_t, false); | ||
| 59 | |||
| 60 | #ifdef __HLSL_ENABLE_16_BIT | ||
| 61 | _SPECIALIZE_ISUNSIGNED(uint16_t, true); | ||
| 62 | _SPECIALIZE_ISUNSIGNED(int16_t, false); | ||
| 63 | _SPECIALIZE_ISUNSIGNED(float16_t, false); | ||
| 64 | #else // //__HLSL_ENABLE_16_BIT | ||
| 65 | _SPECIALIZE_ISUNSIGNED(half, false); | ||
| 66 | #endif //__HLSL_ENABLE_16_BIT | ||
| 67 | |||
| 68 | #undef _SPECIALIZE_ISUNSIGNED | ||
| 69 | |||
| 70 | } // namespace details | ||
| 71 | |||
| 72 | // | ||
| 73 | // (RW)MatrixRef | ||
| 74 | // | ||
| 75 | |||
| 76 | template <typename BufferTy, DataType DT, uint M, uint K, MatrixLayout ML, | ||
| 77 | bool Transpose> | ||
| 78 | struct MatrixRefImpl { | ||
| 79 | BufferTy Buffer; | ||
| 80 | uint StartOffset; | ||
| 81 | uint Stride; | ||
| 82 | }; | ||
| 83 | |||
| 84 | template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false> | ||
| 85 | using MatrixRef = MatrixRefImpl<ByteAddressBuffer, DT, M, K, ML, Transpose>; | ||
| 86 | |||
| 87 | template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false> | ||
| 88 | using RWMatrixRef = MatrixRefImpl<RWByteAddressBuffer, DT, M, K, ML, Transpose>; | ||
| 89 | |||
| 90 | // | ||
| 91 | // (RW)VectorRef | ||
| 92 | // | ||
| 93 | |||
| 94 | template <typename BufferTy, DataType DT> struct VectorRefImpl { | ||
| 95 | BufferTy Buffer; | ||
| 96 | uint StartOffset; | ||
| 97 | }; | ||
| 98 | |||
| 99 | template <DataType DT> using VectorRef = VectorRefImpl<ByteAddressBuffer, DT>; | ||
| 100 | |||
| 101 | template <DataType DT> | ||
| 102 | using RWVectorRef = VectorRefImpl<RWByteAddressBuffer, DT>; | ||
| 103 | |||
| 104 | // | ||
| 105 | // Vector | ||
| 106 | // | ||
| 107 | |||
| 108 | template <typename T, int N, DataType DT> struct InterpretedVector { | ||
| 109 | vector<T, N> Data; | ||
| 110 | }; | ||
| 111 | |||
| 112 | template <DataType DT, typename T, int N> | ||
| 113 | InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) { | ||
| 114 | InterpretedVector<T, N, DT> IV = {Vec}; | ||
| 115 | return IV; | ||
| 116 | } | ||
| 117 | |||
| 118 | // | ||
| 119 | // Mul | ||
| 120 | // | ||
| 121 | |||
| 122 | template <typename OutputElTy, typename InputElTy, int InputElCount, | ||
| 123 | typename MatrixBufferTy, DataType InputDT, DataType MatrixDT, | ||
| 124 | uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout, | ||
| 125 | bool MatrixTranspose> | ||
| 126 | vector<OutputElTy, MatrixM> | ||
| 127 | Mul(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout, | ||
| 128 | MatrixTranspose> | ||
| 129 | Matrix, | ||
| 130 | InterpretedVector<InputElTy, InputElCount, InputDT> InputVector) { | ||
| 131 | |||
| 132 | vector<OutputElTy, MatrixM> OutputVector; | ||
| 133 | |||
| 134 | __builtin_MatVecMul( | ||
| 135 | /*out*/ OutputVector, details::IsUnsigned<OutputElTy>::Value, | ||
| 136 | InputVector.Data, details::IsUnsigned<InputElTy>::Value, InputDT, | ||
| 137 | Matrix.Buffer, Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, | ||
| 138 | MatrixLayout, MatrixTranspose, Matrix.Stride); | ||
| 139 | |||
| 140 | return OutputVector; | ||
| 141 | } | ||
| 142 | |||
| 143 | // | ||
| 144 | // MulAdd | ||
| 145 | // | ||
| 146 | |||
| 147 | template <typename OutputElTy, typename InputElTy, int InputElCount, | ||
| 148 | typename MatrixBufferTy, DataType InputDT, DataType MatrixDT, | ||
| 149 | uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout, | ||
| 150 | bool MatrixTranspose, typename BiasVectorBufferTy, | ||
| 151 | DataType BiasVectorDT> | ||
| 152 | vector<OutputElTy, MatrixM> | ||
| 153 | MulAdd(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout, | ||
| 154 | MatrixTranspose> | ||
| 155 | Matrix, | ||
| 156 | InterpretedVector<InputElTy, InputElCount, InputDT> InputVector, | ||
| 157 | VectorRefImpl<BiasVectorBufferTy, BiasVectorDT> BiasVector) { | ||
| 158 | |||
| 159 | vector<OutputElTy, MatrixM> OutputVector; | ||
| 160 | |||
| 161 | __builtin_MatVecMulAdd( | ||
| 162 | /*out*/ OutputVector, details::IsUnsigned<OutputElTy>::Value, | ||
| 163 | InputVector.Data, details::IsUnsigned<InputElTy>::Value, InputDT, | ||
| 164 | Matrix.Buffer, Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, | ||
| 165 | MatrixLayout, MatrixTranspose, Matrix.Stride, BiasVector.Buffer, | ||
| 166 | BiasVector.StartOffset, BiasVectorDT); | ||
| 167 | |||
| 168 | return OutputVector; | ||
| 169 | } | ||
| 170 | |||
| 171 | // | ||
| 172 | // OuterProductAccumulate | ||
| 173 | // | ||
| 174 | |||
| 175 | template <typename ElTy, int MatrixM, int MatrixN, DataType MatrixDT, | ||
| 176 | MatrixLayout MatrixLayout> | ||
| 177 | void OuterProductAccumulate( | ||
| 178 | vector<ElTy, MatrixM> InputVector1, vector<ElTy, MatrixN> InputVector2, | ||
| 179 | RWMatrixRef<MatrixDT, MatrixM, MatrixN, MatrixLayout, false> Matrix) { | ||
| 180 | __builtin_OuterProductAccumulate(InputVector1, InputVector2, Matrix.Buffer, | ||
| 181 | Matrix.StartOffset, MatrixDT, MatrixLayout, | ||
| 182 | Matrix.Stride); | ||
| 183 | } | ||
| 184 | |||
| 185 | // | ||
| 186 | // VectorAccumulate | ||
| 187 | // | ||
| 188 | |||
| 189 | template <typename ElTy, int ElCount> | ||
| 190 | void VectorAccumulate(vector<ElTy, ElCount> InputVector, | ||
| 191 | RWByteAddressBuffer Buffer, uint Offset) { | ||
| 192 | __builtin_VectorAccumulate(InputVector, Buffer, Offset); | ||
| 193 | } | ||
| 194 | |||
| 195 | } // namespace linalg | ||
| 196 | } // namespace dx | ||
| 197 | |||
| 198 | #endif // SM 6.9 check and HV version check | ||
