// Copyright (c) 2024 Google LLC
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "vk/opcode_selector.h"

// Declarations of the SPIR-V instruction builtins used by the cooperative
// matrix code in this file. Each declaration is bound to a numeric opcode
// through `vk::ext_instruction`; the inline comment next to the number names
// the instruction that the number is meant to select.

// Matrix-times-scalar product of `a` and `b`.
template
[[vk::ext_instruction(/* OpMatrixTimesScalar */ 143)]] ResultType
__builtin_spv_MatrixTimesScalar(ResultType a, ComponentType b);

// Extracts the component at `index` from the cooperative matrix `matrix`.
template
[[vk::ext_instruction(/* OpCompositeExtract */ 81)]] ComponentType
__builtin_spv_ExtractFromCooperativeMatrix(
    typename vk::khr::CooperativeMatrix::SpirvMatrixType matrix,
    uint32_t index);

// Constructs a cooperative matrix from the single component `value`.
template
[[vk::ext_instruction(/* OpCompositeConstruct */ 80)]] CoopMatrixType
__builtin_spv_ConstructCooperativeMatrix(ComponentType value);

// Forms a pointer to the element of `base` selected by `index`. `base` is
// marked `vk::ext_reference`, so the object itself (not a copy of its value)
// is handed to the instruction.
template
[[vk::ext_instruction(/* OpAccessChain */ 65)]] ResultPointerType
__builtin_spv_AccessChain([[vk::ext_reference]] BaseType base, uint32_t index);

// Loads the object that `base` points to.
template
[[vk::ext_instruction(/* OpLoad */ 61)]] ObjectType
__builtin_spv_LoadPointer(PointerType base);

// Stores `object` through the pointer `base`.
// Opcode 62 is OpStore; OpLoad is opcode 61 (see the declaration above).
template
[[vk::ext_instruction(/* OpStore */ 62)]] void
__builtin_spv_StorePointer(PointerType base, ObjectType object);

// Returns a copy of `matrix` in which the component at `index` has been
// replaced by `value`.
template
[[vk::ext_instruction(/* OpCompositeInsert */ 82)]]
typename vk::khr::CooperativeMatrix::SpirvMatrixType
__builtin_spv_InsertIntoCooperativeMatrix(
    ComponentType value,
    typename vk::khr::CooperativeMatrix::SpirvMatrixType matrix,
    uint32_t index);

// Define the load and store instructions

// Cooperative matrix load through `pointer`, with a memory operand mask that
// needs no trailing scope operand. `memory_operand` is marked
// `vk::ext_literal`, i.e. it is emitted as a literal operand.
template
[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
__builtin_spv_CooperativeMatrixLoadKHR(
    [[vk::ext_reference]] PointerType pointer,
    vk::CooperativeMatrixLayout memory_layout, uint stride,
    [[vk::ext_literal]] uint32_t memory_operand);

// Same load as above with an additional trailing `scope` operand.
template
[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
__builtin_spv_CooperativeMatrixLoadKHR(
    [[vk::ext_reference]] PointerType pointer,
    vk::CooperativeMatrixLayout memory_layout, uint stride,
    [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);

// Load variant that takes a `vk::WorkgroupSpirvPointer` by value instead of
// an `ext_reference` object.
template
[[vk::ext_instruction(/* OpCooperativeMatrixLoadKHR */ 4457)]] ResultType
__builtin_spv_CooperativeMatrixWorkgroupLoadKHR(
    vk::WorkgroupSpirvPointer pointer,
    vk::CooperativeMatrixLayout memory_layout, uint stride,
    [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);

// Cooperative matrix store of `object` through `pointer`, with a trailing
// `scope` operand.
template
[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
__builtin_spv_CooperativeMatrixStoreKHR(
    [[vk::ext_reference]] PointerType pointer, ObjectType object,
    vk::CooperativeMatrixLayout memory_layout, uint stride,
    [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);

// Same store as above without the trailing `scope` operand.
template
[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
__builtin_spv_CooperativeMatrixStoreKHR(
    [[vk::ext_reference]] PointerType pointer, ObjectType object,
    vk::CooperativeMatrixLayout memory_layout, uint stride,
    [[vk::ext_literal]] uint32_t memory_operand);

// Store variant that takes a `vk::WorkgroupSpirvPointer` by value instead of
// an `ext_reference` object.
template
[[vk::ext_instruction(/* OpCooperativeMatrixStoreKHR */ 4458)]] void
__builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
    vk::WorkgroupSpirvPointer pointer, ObjectType object,
    vk::CooperativeMatrixLayout memory_layout, uint stride,
    [[vk::ext_literal]] uint32_t memory_operand, vk::Scope scope);

// We cannot define `OpCooperativeMatrixLengthKHR` using ext_instruction because
// one of the operands is a type id. This builtin will have specific code in the
// compiler to expand it.
// Returns the value produced by the compiler's expansion of
// `OpCooperativeMatrixLengthKHR`. Unlike the other builtins it carries no
// `vk::ext_instruction` attribute.
template
uint __builtin_spv_CooperativeMatrixLengthKHR();

// Arithmetic Instructions

// Cooperative matrix multiply-add on `a`, `b` and `c`. `operands` is marked
// `vk::ext_literal`, i.e. it is emitted as a literal operand mask.
template
[[vk::ext_instruction(/* OpCooperativeMatrixMulAddKHR */ 4459)]] ResultType
__builtin_spv_CooperativeMatrixMulAddKHR(MatrixTypeA a, MatrixTypeB b,
                                         MatrixTypeC c,
                                         [[vk::ext_literal]] int operands);

namespace vk {
namespace khr {

// Converts this matrix to another CooperativeMatrix type by running
// `util::ConversionSelector::Convert` on the underlying `_matrix` and
// returning the result wrapped in a new object. `*this` is not modified.
template
template
CooperativeMatrix
CooperativeMatrix::cast() {
  using ResultType =
      CooperativeMatrix;
  ResultType result;
  result._matrix = util::ConversionSelector::
      template Convert(_matrix);
  return result;
}

// Returns a new matrix holding `util::ArithmeticSelector::Negate(_matrix)`.
template
CooperativeMatrix
CooperativeMatrix::negate() {
  CooperativeMatrix result;
  result._matrix = util::ArithmeticSelector::Negate(_matrix);
  return result;
}

// The four binary operators below share one shape: they forward the
// underlying matrices of `*this` and `other` to the matching
// `util::ArithmeticSelector` operation (Add, Sub, Mul, Div) and return the
// outcome in a new matrix. Neither operand is modified.
// NOTE(review): presumably each operation is applied component by component;
// confirm against util::ArithmeticSelector in vk/opcode_selector.h.

// Returns `*this + other`.
template
CooperativeMatrix
CooperativeMatrix::operator+(
    CooperativeMatrix other) {
  CooperativeMatrix result;
  result._matrix = util::ArithmeticSelector::Add(_matrix, other._matrix);
  return result;
}

// Returns `*this - other`.
template
CooperativeMatrix
CooperativeMatrix::operator-(
    CooperativeMatrix other) {
  CooperativeMatrix result;
  result._matrix = util::ArithmeticSelector::Sub(_matrix, other._matrix);
  return result;
}

// Returns `*this * other` as computed by `ArithmeticSelector::Mul`. This is a
// different operation from the matrix multiply-add helpers defined at the end
// of this namespace.
template
CooperativeMatrix
CooperativeMatrix::operator*(
    CooperativeMatrix other) {
  CooperativeMatrix result;
  result._matrix = util::ArithmeticSelector::Mul(_matrix, other._matrix);
  return result;
}

// Returns `*this / other`.
template
CooperativeMatrix
CooperativeMatrix::operator/(
    CooperativeMatrix other) {
  CooperativeMatrix result;
  result._matrix = util::ArithmeticSelector::Div(_matrix, other._matrix);
  return result;
}

// Returns this matrix multiplied by `scalar`, using the OpMatrixTimesScalar
// builtin.
template
CooperativeMatrix
CooperativeMatrix::operator*(
    ComponentType scalar) {
  CooperativeMatrix result;
  result._matrix = __builtin_spv_MatrixTimesScalar(_matrix, scalar);
  return result;
}

// Stores this matrix through the workgroup pointer `data` with the given
// `stride`. The caller's `memoryAccessOperands` are extended with the
// NonPrivatePointer and MakePointerAvailable bits, and the store is issued
// with `ScopeWorkgroup` as its scope operand.
template
template
void CooperativeMatrix::Store(
    WorkgroupSpirvPointer data, uint32_t stride) {
  __builtin_spv_CooperativeMatrixWorkgroupStoreKHR(
      data, _matrix, layout, stride,
      memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
          MemoryAccessMakePointerAvailableMask,
      ScopeWorkgroup);
}

// Stores this matrix into `data`, starting at element `index`, with the given
// `stride`. `memoryAccessOperands` is passed through unchanged and no scope
// operand is supplied.
template
template
void CooperativeMatrix::Store(
    RWStructuredBuffer data, uint32_t index, uint32_t stride) {
  __builtin_spv_CooperativeMatrixStoreKHR(data[index], _matrix, layout, stride,
                                          memoryAccessOperands);
}

// Same as the RWStructuredBuffer `Store`, for a `globallycoherent` buffer:
// the NonPrivatePointer and MakePointerAvailable bits are added to
// `memoryAccessOperands`, and `ScopeQueueFamily` is passed as the scope.
template
template
void CooperativeMatrix::CoherentStore(
    globallycoherent RWStructuredBuffer data, uint32_t index,
    uint32_t stride) {
  __builtin_spv_CooperativeMatrixStoreKHR(
      data[index], _matrix, layout, stride,
      memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
          MemoryAccessMakePointerAvailableMask,
      ScopeQueueFamily);
}

// Loads a matrix through the workgroup pointer `buffer` with the given
// `stride`. The caller's `memoryAccessOperands` are extended with the
// NonPrivatePointer and MakePointerVisible bits, and the load is issued with
// `ScopeWorkgroup` as its scope operand.
template
template
CooperativeMatrix
CooperativeMatrix::Load(
    vk::WorkgroupSpirvPointer buffer, uint32_t stride) {
  CooperativeMatrix result;
  result._matrix = __builtin_spv_CooperativeMatrixWorkgroupLoadKHR(
      buffer, layout, stride,
      memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
          MemoryAccessMakePointerVisibleMask,
      ScopeWorkgroup);
  return result;
}

// Loads a matrix from `buffer`, starting at element `index`, with the given
// `stride`. `memoryAccessOperands` is passed through unchanged and no scope
// operand is supplied.
template
template
CooperativeMatrix
CooperativeMatrix::Load(
    RWStructuredBuffer buffer, uint32_t index, uint32_t stride) {
  CooperativeMatrix result;
  result._matrix = __builtin_spv_CooperativeMatrixLoadKHR(
      buffer[index], layout, stride, memoryAccessOperands);
  return result;
}

// Same as the RWStructuredBuffer `Load`, with the NonPrivatePointer and
// MakePointerVisible bits added to `memoryAccessOperands` and
// `ScopeQueueFamily` passed as the scope.
// NOTE(review): unlike CoherentStore, the `buffer` parameter here is not
// qualified `globallycoherent` — confirm against the declaration in the
// header whether that is intentional.
template
template
CooperativeMatrix
CooperativeMatrix::CoherentLoad(
    RWStructuredBuffer buffer, uint32_t index, uint32_t stride) {
  CooperativeMatrix result;
  result._matrix = __builtin_spv_CooperativeMatrixLoadKHR(
      buffer[index], layout, stride,
      memoryAccessOperands | MemoryAccessNonPrivatePointerMask |
          MemoryAccessMakePointerVisibleMask,
      ScopeQueueFamily);
  return result;
}

// Loads a matrix from the read-only `buffer`, starting at element `index`,
// with the given `stride`. This overload always passes
// `MemoryAccessMaskNone`; it does not consult `memoryAccessOperands`.
template
template
CooperativeMatrix
CooperativeMatrix::Load(
    StructuredBuffer buffer, uint32_t index, uint32_t stride) {
  CooperativeMatrix result;
  result._matrix = __builtin_spv_CooperativeMatrixLoadKHR(
      buffer[index], layout, stride, MemoryAccessMaskNone);
  return result;
}

// Returns a matrix constructed from the single value `v` via the
// OpCompositeConstruct builtin.
// NOTE(review): presumably every component of the result equals `v`; confirm
// against the SPIR-V cooperative matrix rules for OpCompositeConstruct.
template
CooperativeMatrix
CooperativeMatrix::Splat(
    ComponentType v) {
  CooperativeMatrix result;
  result._matrix = __builtin_spv_ConstructCooperativeMatrix(v);
  return result;
}

// Returns the result of the OpCooperativeMatrixLengthKHR builtin for this
// matrix type.
template
uint
CooperativeMatrix::GetLength() {
  return __builtin_spv_CooperativeMatrixLengthKHR();
}

// Returns the component at `index`. The value is read by forming a pointer
// into `_matrix` with OpAccessChain and loading through it. No range check on
// `index` is performed here.
template
ComponentType CooperativeMatrix::Get(
    uint32_t index) {
  // clang-format off
  // Opaque SPIR-V pointer type (OpTypePointer) to a ComponentType.
  using ComponentPtr = vk::SpirvOpaqueType<
      /* OpTypePointer */ 32,
      /* function storage class */ vk::Literal >,
      ComponentType>;
  // clang-format on
  ComponentPtr ptr = __builtin_spv_AccessChain(_matrix, index);
  return __builtin_spv_LoadPointer(ptr);
}

// Writes `value` to the component at `index`, modifying this matrix in place.
// The write goes through a pointer into `_matrix` formed with OpAccessChain.
// No range check on `index` is performed here.
template
void CooperativeMatrix::Set(
    ComponentType value, uint32_t index) {
  // clang-format off
  // Opaque SPIR-V pointer type (OpTypePointer) to a ComponentType.
  using ComponentPtr = vk::SpirvOpaqueType<
      /* OpTypePointer */ 32,
      /* function storage class */ vk::Literal >,
      ComponentType>;
  // clang-format on
  ComponentPtr ptr = __builtin_spv_AccessChain(_matrix, index);
  // The builtin returns void, so this `return` only forwards the call.
  return __builtin_spv_StorePointer(ptr, value);
}

// Issues OpCooperativeMatrixMulAddKHR on `a`, `b` and `c` and returns the
// result as a new accumulator matrix.
//
// The operand mask is either "all four signed-component bits" (A, B, C and
// Result) or none, and the choice is made from
// `a.hasSignedIntegerComponentType` alone.
// NOTE(review): the component types of `b` and `c` are not consulted, so
// mixed signed/unsigned inputs get the mask implied by `a` — confirm this is
// intended.
template
CooperativeMatrixAccumulator
cooperativeMatrixMultiplyAdd(CooperativeMatrixA a, CooperativeMatrixB b,
                             CooperativeMatrixAccumulator c) {

  const vk::CooperativeMatrixOperandsMask allSignedComponents =
      vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask;

  const vk::CooperativeMatrixOperandsMask operands =
      (vk::CooperativeMatrixOperandsMask)(
          a.hasSignedIntegerComponentType
              ? allSignedComponents
              : vk::CooperativeMatrixOperandsMaskNone);

  CooperativeMatrixAccumulator result;
  result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
      typename CooperativeMatrixAccumulator::SpirvMatrixType>(
      a._matrix, b._matrix, c._matrix, operands);
  return result;
}

// Same as cooperativeMatrixMultiplyAdd, except that the
// SaturatingAccumulation bit is always part of the operand mask: it is
// included in `allSignedComponents` and is also the value used when `a` does
// not have a signed integer component type.
// NOTE(review): as above, only `a` decides whether the signed-component bits
// are set — confirm this is intended.
template
CooperativeMatrixAccumulator
cooperativeMatrixSaturatingMultiplyAdd(CooperativeMatrixA a,
                                       CooperativeMatrixB b,
                                       CooperativeMatrixAccumulator c) {

  const vk::CooperativeMatrixOperandsMask allSignedComponents =
      vk::CooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask |
      vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask;

  const vk::CooperativeMatrixOperandsMask operands =
      (vk::CooperativeMatrixOperandsMask)(
          a.hasSignedIntegerComponentType
              ? allSignedComponents
              : vk::CooperativeMatrixOperandsSaturatingAccumulationKHRMask);
  CooperativeMatrixAccumulator result;
  result._matrix = __builtin_spv_CooperativeMatrixMulAddKHR<
      typename CooperativeMatrixAccumulator::SpirvMatrixType>(
      a._matrix, b._matrix, c._matrix, operands);
  return result;
}

} // namespace khr
} // namespace vk