Spaces:
Running
Running
R0CKSTAR
commited on
Commit
·
ec2f307
1
Parent(s):
f0d6f5c
cuda : organize vendor-specific headers into vendors directory (llama/8746)
Browse files- ggml/src/ggml-cuda/common.cuh +4 -374
- ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- ggml/src/ggml-cuda/vendors/hip.h +177 -0
- ggml/src/ggml-cuda/vendors/musa.h +171 -0
ggml/src/ggml-cuda/common.cuh
CHANGED
|
@@ -27,255 +27,11 @@
|
|
| 27 |
#include <vector>
|
| 28 |
|
| 29 |
#if defined(GGML_USE_HIPBLAS)
|
| 30 |
-
#include
|
| 31 |
-
#include <hipblas/hipblas.h>
|
| 32 |
-
#include <hip/hip_fp16.h>
|
| 33 |
-
#ifdef __HIP_PLATFORM_AMD__
|
| 34 |
-
// for rocblas_initialize()
|
| 35 |
-
#include "rocblas/rocblas.h"
|
| 36 |
-
#endif // __HIP_PLATFORM_AMD__
|
| 37 |
-
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
| 38 |
-
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
| 39 |
-
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
| 40 |
-
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
| 41 |
-
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
| 42 |
-
#define CUBLAS_OP_N HIPBLAS_OP_N
|
| 43 |
-
#define CUBLAS_OP_T HIPBLAS_OP_T
|
| 44 |
-
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
| 45 |
-
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
| 46 |
-
#define CUDA_R_16F HIPBLAS_R_16F
|
| 47 |
-
#define CUDA_R_32F HIPBLAS_R_32F
|
| 48 |
-
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
| 49 |
-
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
| 50 |
-
#define cublasCreate hipblasCreate
|
| 51 |
-
#define cublasDestroy hipblasDestroy
|
| 52 |
-
#define cublasGemmEx hipblasGemmEx
|
| 53 |
-
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
| 54 |
-
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
| 55 |
-
#define cublasHandle_t hipblasHandle_t
|
| 56 |
-
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
| 57 |
-
#define cublasSetStream hipblasSetStream
|
| 58 |
-
#define cublasSgemm hipblasSgemm
|
| 59 |
-
#define cublasStatus_t hipblasStatus_t
|
| 60 |
-
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
| 61 |
-
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
| 62 |
-
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
| 63 |
-
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
| 64 |
-
#define cudaDeviceProp hipDeviceProp_t
|
| 65 |
-
#define cudaDeviceSynchronize hipDeviceSynchronize
|
| 66 |
-
#define cudaError_t hipError_t
|
| 67 |
-
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
| 68 |
-
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
| 69 |
-
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
| 70 |
-
#define cudaEventDisableTiming hipEventDisableTiming
|
| 71 |
-
#define cudaEventRecord hipEventRecord
|
| 72 |
-
#define cudaEventSynchronize hipEventSynchronize
|
| 73 |
-
#define cudaEvent_t hipEvent_t
|
| 74 |
-
#define cudaEventDestroy hipEventDestroy
|
| 75 |
-
#define cudaFree hipFree
|
| 76 |
-
#define cudaFreeHost hipHostFree
|
| 77 |
-
#define cudaGetDevice hipGetDevice
|
| 78 |
-
#define cudaGetDeviceCount hipGetDeviceCount
|
| 79 |
-
#define cudaGetDeviceProperties hipGetDeviceProperties
|
| 80 |
-
#define cudaGetErrorString hipGetErrorString
|
| 81 |
-
#define cudaGetLastError hipGetLastError
|
| 82 |
-
#define cudaHostRegister hipHostRegister
|
| 83 |
-
#define cudaHostRegisterPortable hipHostRegisterPortable
|
| 84 |
-
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
| 85 |
-
#define cudaHostUnregister hipHostUnregister
|
| 86 |
-
#define cudaLaunchHostFunc hipLaunchHostFunc
|
| 87 |
-
#define cudaMalloc hipMalloc
|
| 88 |
-
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
| 89 |
-
#define cudaMemcpy hipMemcpy
|
| 90 |
-
#define cudaMemcpyAsync hipMemcpyAsync
|
| 91 |
-
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
| 92 |
-
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
| 93 |
-
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
| 94 |
-
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
| 95 |
-
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
| 96 |
-
#define cudaMemcpyKind hipMemcpyKind
|
| 97 |
-
#define cudaMemset hipMemset
|
| 98 |
-
#define cudaMemsetAsync hipMemsetAsync
|
| 99 |
-
#define cudaMemGetInfo hipMemGetInfo
|
| 100 |
-
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
| 101 |
-
#define cudaSetDevice hipSetDevice
|
| 102 |
-
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
| 103 |
-
#define cudaStreamDestroy hipStreamDestroy
|
| 104 |
-
#define cudaStreamFireAndForget hipStreamFireAndForget
|
| 105 |
-
#define cudaStreamNonBlocking hipStreamNonBlocking
|
| 106 |
-
#define cudaStreamPerThread hipStreamPerThread
|
| 107 |
-
#define cudaStreamSynchronize hipStreamSynchronize
|
| 108 |
-
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
| 109 |
-
#define cudaStream_t hipStream_t
|
| 110 |
-
#define cudaSuccess hipSuccess
|
| 111 |
-
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
| 112 |
-
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
| 113 |
-
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
| 114 |
-
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
| 115 |
-
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
| 116 |
-
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
| 117 |
-
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
| 118 |
-
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
| 119 |
-
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
| 120 |
-
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
| 121 |
#elif defined(GGML_USE_MUSA)
|
| 122 |
-
#include
|
| 123 |
-
#include <musa.h>
|
| 124 |
-
#include <mublas.h>
|
| 125 |
-
#include <musa_fp16.h>
|
| 126 |
-
// XXX: Keep the following order the same as hipBLAS
|
| 127 |
-
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
|
| 128 |
-
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
|
| 129 |
-
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
| 130 |
-
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
| 131 |
-
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
| 132 |
-
#define CUBLAS_OP_N MUBLAS_OP_N
|
| 133 |
-
#define CUBLAS_OP_T MUBLAS_OP_T
|
| 134 |
-
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
| 135 |
-
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
|
| 136 |
-
#define CUDA_R_16F MUSA_R_16F
|
| 137 |
-
#define CUDA_R_32F MUSA_R_32F
|
| 138 |
-
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
| 139 |
-
// #define cublasComputeType_t mublasComputeType_t
|
| 140 |
-
#define cublasCreate mublasCreate
|
| 141 |
-
#define cublasDestroy mublasDestroy
|
| 142 |
-
#define cublasGemmEx mublasGemmEx
|
| 143 |
-
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
| 144 |
-
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
| 145 |
-
#define cublasHandle_t mublasHandle_t
|
| 146 |
-
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
| 147 |
-
#define cublasSetMathMode mublasSetMathMode
|
| 148 |
-
#define cublasSetStream mublasSetStream
|
| 149 |
-
#define cublasSgemm mublasSgemm
|
| 150 |
-
#define cublasStatus_t mublasStatus_t
|
| 151 |
-
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
|
| 152 |
-
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
| 153 |
-
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
| 154 |
-
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
| 155 |
-
#define cudaDeviceProp musaDeviceProp
|
| 156 |
-
#define cudaDeviceSynchronize musaDeviceSynchronize
|
| 157 |
-
#define cudaError_t musaError_t
|
| 158 |
-
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
| 159 |
-
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
| 160 |
-
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
| 161 |
-
#define cudaEventDisableTiming musaEventDisableTiming
|
| 162 |
-
#define cudaEventRecord musaEventRecord
|
| 163 |
-
#define cudaEventSynchronize musaEventSynchronize
|
| 164 |
-
#define cudaEvent_t musaEvent_t
|
| 165 |
-
#define cudaEventDestroy musaEventDestroy
|
| 166 |
-
#define cudaFree musaFree
|
| 167 |
-
#define cudaFreeHost musaFreeHost
|
| 168 |
-
#define cudaGetDevice musaGetDevice
|
| 169 |
-
#define cudaGetDeviceCount musaGetDeviceCount
|
| 170 |
-
#define cudaGetDeviceProperties musaGetDeviceProperties
|
| 171 |
-
#define cudaGetErrorString musaGetErrorString
|
| 172 |
-
#define cudaGetLastError musaGetLastError
|
| 173 |
-
#define cudaHostRegister musaHostRegister
|
| 174 |
-
#define cudaHostRegisterPortable musaHostRegisterPortable
|
| 175 |
-
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
| 176 |
-
#define cudaHostUnregister musaHostUnregister
|
| 177 |
-
#define cudaLaunchHostFunc musaLaunchHostFunc
|
| 178 |
-
#define cudaMalloc musaMalloc
|
| 179 |
-
#define cudaMallocHost musaMallocHost
|
| 180 |
-
#define cudaMemcpy musaMemcpy
|
| 181 |
-
#define cudaMemcpyAsync musaMemcpyAsync
|
| 182 |
-
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
| 183 |
-
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
| 184 |
-
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
| 185 |
-
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
| 186 |
-
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
| 187 |
-
#define cudaMemcpyKind musaMemcpyKind
|
| 188 |
-
#define cudaMemset musaMemset
|
| 189 |
-
#define cudaMemsetAsync musaMemsetAsync
|
| 190 |
-
#define cudaMemGetInfo musaMemGetInfo
|
| 191 |
-
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
| 192 |
-
#define cudaSetDevice musaSetDevice
|
| 193 |
-
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
| 194 |
-
#define cudaStreamDestroy musaStreamDestroy
|
| 195 |
-
#define cudaStreamFireAndForget musaStreamFireAndForget
|
| 196 |
-
#define cudaStreamNonBlocking musaStreamNonBlocking
|
| 197 |
-
#define cudaStreamPerThread musaStreamPerThread
|
| 198 |
-
#define cudaStreamSynchronize musaStreamSynchronize
|
| 199 |
-
#define cudaStreamWaitEvent musaStreamWaitEvent
|
| 200 |
-
#define cudaStream_t musaStream_t
|
| 201 |
-
#define cudaSuccess musaSuccess
|
| 202 |
-
|
| 203 |
-
// XXX: Other CUDA => MUSA mapping
|
| 204 |
-
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
| 205 |
-
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
| 206 |
-
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
| 207 |
-
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
| 208 |
-
#define CUdevice MUdevice
|
| 209 |
-
#define CUdeviceptr MUdeviceptr
|
| 210 |
-
#define CUmemAccessDesc MUmemAccessDesc
|
| 211 |
-
#define CUmemAllocationProp MUmemAllocationProp
|
| 212 |
-
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
| 213 |
-
#define cuDeviceGet muDeviceGet
|
| 214 |
-
#define cuDeviceGetAttribute muDeviceGetAttribute
|
| 215 |
-
#define cuMemAddressFree muMemAddressFree
|
| 216 |
-
#define cuMemAddressReserve muMemAddressReserve
|
| 217 |
-
#define cuMemCreate muMemCreate
|
| 218 |
-
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
| 219 |
-
#define cuMemMap muMemMap
|
| 220 |
-
#define cuMemRelease muMemRelease
|
| 221 |
-
#define cuMemSetAccess muMemSetAccess
|
| 222 |
-
#define cuMemUnmap muMemUnmap
|
| 223 |
-
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
| 224 |
-
#define cudaFuncSetAttribute musaFuncSetAttribute
|
| 225 |
-
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
| 226 |
-
#define make_cudaExtent make_musaExtent
|
| 227 |
-
#define make_cudaPitchedPtr make_musaPitchedPtr
|
| 228 |
-
|
| 229 |
-
// XXX: USE_CUDA_GRAPH
|
| 230 |
-
#define CUDA_SUCCESS MUSA_SUCCESS
|
| 231 |
-
#define CUresult MUresult
|
| 232 |
-
#define cuGetErrorString muGetErrorString
|
| 233 |
-
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
| 234 |
-
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
| 235 |
-
#define cudaGraphDestroy musaGraphDestroy
|
| 236 |
-
#define cudaGraphExecDestroy musaGraphExecDestroy
|
| 237 |
-
#define cudaGraphExec_t musaGraphExec_t
|
| 238 |
-
#define cudaGraphExecUpdate musaGraphExecUpdate
|
| 239 |
-
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
| 240 |
-
#define cudaGraphGetNodes musaGraphGetNodes
|
| 241 |
-
#define cudaGraphInstantiate musaGraphInstantiate
|
| 242 |
-
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
| 243 |
-
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
| 244 |
-
#define cudaGraphLaunch musaGraphLaunch
|
| 245 |
-
#define cudaGraphNodeGetType musaGraphNodeGetType
|
| 246 |
-
#define cudaGraphNode_t musaGraphNode_t
|
| 247 |
-
#define cudaGraphNodeType musaGraphNodeType
|
| 248 |
-
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
| 249 |
-
#define cudaGraph_t musaGraph_t
|
| 250 |
-
#define cudaKernelNodeParams musaKernelNodeParams
|
| 251 |
-
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 252 |
-
#define cudaStreamEndCapture musaStreamEndCapture
|
| 253 |
-
|
| 254 |
-
// XXX: cuBLAS => muBLAS mapping
|
| 255 |
-
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
| 256 |
-
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
| 257 |
-
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 258 |
-
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
| 259 |
-
#define cublasComputeType_t cudaDataType_t
|
| 260 |
-
|
| 261 |
-
// XXX: Clang builtins mapping
|
| 262 |
-
#define __vsub4 __vsub4_musa
|
| 263 |
-
#define __vcmpeq4 __vcmpeq4_musa
|
| 264 |
-
#define __vcmpne4 __vcmpne4_musa
|
| 265 |
#else
|
| 266 |
-
#include
|
| 267 |
-
#include <cuda.h>
|
| 268 |
-
#include <cublas_v2.h>
|
| 269 |
-
#include <cuda_fp16.h>
|
| 270 |
-
|
| 271 |
-
#if CUDART_VERSION < 11020
|
| 272 |
-
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
| 273 |
-
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
| 274 |
-
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 275 |
-
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
| 276 |
-
#define cublasComputeType_t cudaDataType_t
|
| 277 |
-
#endif // CUDART_VERSION < 11020
|
| 278 |
-
|
| 279 |
#endif // defined(GGML_USE_HIPBLAS)
|
| 280 |
|
| 281 |
#define STRINGIZE_IMPL(...) #__VA_ARGS__
|
|
@@ -318,11 +74,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
|
| 318 |
|
| 319 |
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
| 320 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
| 321 |
-
#ifndef GGML_USE_MUSA
|
| 322 |
return cublasGetStatusString(err);
|
| 323 |
-
#else
|
| 324 |
-
return mublasStatus_to_string(err);
|
| 325 |
-
#endif // GGML_USE_MUSA
|
| 326 |
}
|
| 327 |
#else
|
| 328 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
@@ -364,129 +116,7 @@ typedef half2 dfloat2;
|
|
| 364 |
#else
|
| 365 |
typedef float dfloat; // dequantize float
|
| 366 |
typedef float2 dfloat2;
|
| 367 |
-
#endif //GGML_CUDA_F16
|
| 368 |
-
|
| 369 |
-
#if defined(GGML_USE_MUSA)
|
| 370 |
-
#ifndef __has_builtin
|
| 371 |
-
#define __has_builtin(x) 0
|
| 372 |
-
#endif
|
| 373 |
-
|
| 374 |
-
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 375 |
-
|
| 376 |
-
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
| 377 |
-
return __vsubss4(a, b);
|
| 378 |
-
}
|
| 379 |
-
|
| 380 |
-
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
| 381 |
-
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 382 |
-
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 383 |
-
unsigned int c;
|
| 384 |
-
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 385 |
-
#pragma unroll
|
| 386 |
-
for (int i = 0; i < 4; ++i) {
|
| 387 |
-
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
| 388 |
-
}
|
| 389 |
-
return c;
|
| 390 |
-
}
|
| 391 |
-
|
| 392 |
-
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
| 393 |
-
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 394 |
-
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 395 |
-
unsigned int c;
|
| 396 |
-
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 397 |
-
#pragma unroll
|
| 398 |
-
for (int i = 0; i < 4; ++i) {
|
| 399 |
-
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
| 400 |
-
}
|
| 401 |
-
return c;
|
| 402 |
-
}
|
| 403 |
-
#endif // defined(GGML_USE_MUSA)
|
| 404 |
-
|
| 405 |
-
#if defined(GGML_USE_HIPBLAS)
|
| 406 |
-
#define __CUDA_ARCH__ 1300
|
| 407 |
-
|
| 408 |
-
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
| 409 |
-
defined(__gfx1150__) || defined(__gfx1151__)
|
| 410 |
-
#define RDNA3
|
| 411 |
-
#endif
|
| 412 |
-
|
| 413 |
-
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
| 414 |
-
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
| 415 |
-
#define RDNA2
|
| 416 |
-
#endif
|
| 417 |
-
|
| 418 |
-
#if defined(__gfx1010__) || defined(__gfx1012__)
|
| 419 |
-
#define RDNA1
|
| 420 |
-
#endif
|
| 421 |
-
|
| 422 |
-
#ifndef __has_builtin
|
| 423 |
-
#define __has_builtin(x) 0
|
| 424 |
-
#endif
|
| 425 |
-
|
| 426 |
-
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
| 427 |
-
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 428 |
-
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
| 429 |
-
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
| 430 |
-
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
| 431 |
-
#if __has_builtin(__builtin_elementwise_sub_sat)
|
| 432 |
-
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
| 433 |
-
return reinterpret_cast<const int &>(c);
|
| 434 |
-
#else
|
| 435 |
-
int8x4_t c;
|
| 436 |
-
int16_t tmp;
|
| 437 |
-
#pragma unroll
|
| 438 |
-
for (int i = 0; i < 4; i++) {
|
| 439 |
-
tmp = va[i] - vb[i];
|
| 440 |
-
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
| 441 |
-
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
| 442 |
-
c[i] = tmp;
|
| 443 |
-
}
|
| 444 |
-
return reinterpret_cast<int &>(c);
|
| 445 |
-
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
| 446 |
-
}
|
| 447 |
-
|
| 448 |
-
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
|
| 449 |
-
return __vsubss4(a, b);
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
|
| 453 |
-
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 454 |
-
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 455 |
-
unsigned int c;
|
| 456 |
-
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 457 |
-
#pragma unroll
|
| 458 |
-
for (int i = 0; i < 4; ++i) {
|
| 459 |
-
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
| 460 |
-
}
|
| 461 |
-
return c;
|
| 462 |
-
}
|
| 463 |
-
|
| 464 |
-
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
|
| 465 |
-
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 466 |
-
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 467 |
-
unsigned int c;
|
| 468 |
-
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 469 |
-
#pragma unroll
|
| 470 |
-
for (int i = 0; i < 4; ++i) {
|
| 471 |
-
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
| 472 |
-
}
|
| 473 |
-
return c;
|
| 474 |
-
}
|
| 475 |
-
|
| 476 |
-
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
| 477 |
-
// __shfl_xor() for half2 was added in ROCm 5.6
|
| 478 |
-
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
|
| 479 |
-
typedef union half2_b32 {
|
| 480 |
-
half2 val;
|
| 481 |
-
int b32;
|
| 482 |
-
} half2_b32_t;
|
| 483 |
-
half2_b32_t tmp;
|
| 484 |
-
tmp.val = var;
|
| 485 |
-
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
|
| 486 |
-
return tmp.val;
|
| 487 |
-
}
|
| 488 |
-
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
| 489 |
-
#endif // defined(GGML_USE_HIPBLAS)
|
| 490 |
|
| 491 |
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
| 492 |
#define FP16_AVAILABLE
|
|
|
|
| 27 |
#include <vector>
|
| 28 |
|
| 29 |
#if defined(GGML_USE_HIPBLAS)
|
| 30 |
+
#include "vendors/hip.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
#elif defined(GGML_USE_MUSA)
|
| 32 |
+
#include "vendors/musa.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
#else
|
| 34 |
+
#include "vendors/cuda.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
#endif // defined(GGML_USE_HIPBLAS)
|
| 36 |
|
| 37 |
#define STRINGIZE_IMPL(...) #__VA_ARGS__
|
|
|
|
| 74 |
|
| 75 |
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
| 76 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
|
|
| 77 |
return cublasGetStatusString(err);
|
|
|
|
|
|
|
|
|
|
| 78 |
}
|
| 79 |
#else
|
| 80 |
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
|
|
|
| 116 |
#else
|
| 117 |
typedef float dfloat; // dequantize float
|
| 118 |
typedef float2 dfloat2;
|
| 119 |
+
#endif // GGML_CUDA_F16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
| 122 |
#define FP16_AVAILABLE
|
ggml/src/ggml-cuda/vendors/cuda.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda_runtime.h>
|
| 4 |
+
#include <cuda.h>
|
| 5 |
+
#include <cublas_v2.h>
|
| 6 |
+
#include <cuda_fp16.h>
|
| 7 |
+
|
| 8 |
+
#if CUDART_VERSION < 11020
|
| 9 |
+
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
| 10 |
+
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
| 11 |
+
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 12 |
+
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
| 13 |
+
#define cublasComputeType_t cudaDataType_t
|
| 14 |
+
#endif // CUDART_VERSION < 11020
|
ggml/src/ggml-cuda/vendors/hip.h
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <hip/hip_runtime.h>
|
| 4 |
+
#include <hipblas/hipblas.h>
|
| 5 |
+
#include <hip/hip_fp16.h>
|
| 6 |
+
#ifdef __HIP_PLATFORM_AMD__
|
| 7 |
+
// for rocblas_initialize()
|
| 8 |
+
#include "rocblas/rocblas.h"
|
| 9 |
+
#endif // __HIP_PLATFORM_AMD__
|
| 10 |
+
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
| 11 |
+
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
| 12 |
+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
| 13 |
+
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
| 14 |
+
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
| 15 |
+
#define CUBLAS_OP_N HIPBLAS_OP_N
|
| 16 |
+
#define CUBLAS_OP_T HIPBLAS_OP_T
|
| 17 |
+
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
| 18 |
+
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
| 19 |
+
#define CUDA_R_16F HIPBLAS_R_16F
|
| 20 |
+
#define CUDA_R_32F HIPBLAS_R_32F
|
| 21 |
+
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
| 22 |
+
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
| 23 |
+
#define cublasCreate hipblasCreate
|
| 24 |
+
#define cublasDestroy hipblasDestroy
|
| 25 |
+
#define cublasGemmEx hipblasGemmEx
|
| 26 |
+
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
| 27 |
+
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
| 28 |
+
#define cublasHandle_t hipblasHandle_t
|
| 29 |
+
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
| 30 |
+
#define cublasSetStream hipblasSetStream
|
| 31 |
+
#define cublasSgemm hipblasSgemm
|
| 32 |
+
#define cublasStatus_t hipblasStatus_t
|
| 33 |
+
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
| 34 |
+
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
| 35 |
+
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
| 36 |
+
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
| 37 |
+
#define cudaDeviceProp hipDeviceProp_t
|
| 38 |
+
#define cudaDeviceSynchronize hipDeviceSynchronize
|
| 39 |
+
#define cudaError_t hipError_t
|
| 40 |
+
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
| 41 |
+
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
| 42 |
+
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
| 43 |
+
#define cudaEventDisableTiming hipEventDisableTiming
|
| 44 |
+
#define cudaEventRecord hipEventRecord
|
| 45 |
+
#define cudaEventSynchronize hipEventSynchronize
|
| 46 |
+
#define cudaEvent_t hipEvent_t
|
| 47 |
+
#define cudaEventDestroy hipEventDestroy
|
| 48 |
+
#define cudaFree hipFree
|
| 49 |
+
#define cudaFreeHost hipHostFree
|
| 50 |
+
#define cudaGetDevice hipGetDevice
|
| 51 |
+
#define cudaGetDeviceCount hipGetDeviceCount
|
| 52 |
+
#define cudaGetDeviceProperties hipGetDeviceProperties
|
| 53 |
+
#define cudaGetErrorString hipGetErrorString
|
| 54 |
+
#define cudaGetLastError hipGetLastError
|
| 55 |
+
#define cudaHostRegister hipHostRegister
|
| 56 |
+
#define cudaHostRegisterPortable hipHostRegisterPortable
|
| 57 |
+
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
| 58 |
+
#define cudaHostUnregister hipHostUnregister
|
| 59 |
+
#define cudaLaunchHostFunc hipLaunchHostFunc
|
| 60 |
+
#define cudaMalloc hipMalloc
|
| 61 |
+
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
| 62 |
+
#define cudaMemcpy hipMemcpy
|
| 63 |
+
#define cudaMemcpyAsync hipMemcpyAsync
|
| 64 |
+
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
| 65 |
+
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
| 66 |
+
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
| 67 |
+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
| 68 |
+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
| 69 |
+
#define cudaMemcpyKind hipMemcpyKind
|
| 70 |
+
#define cudaMemset hipMemset
|
| 71 |
+
#define cudaMemsetAsync hipMemsetAsync
|
| 72 |
+
#define cudaMemGetInfo hipMemGetInfo
|
| 73 |
+
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
| 74 |
+
#define cudaSetDevice hipSetDevice
|
| 75 |
+
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
| 76 |
+
#define cudaStreamDestroy hipStreamDestroy
|
| 77 |
+
#define cudaStreamFireAndForget hipStreamFireAndForget
|
| 78 |
+
#define cudaStreamNonBlocking hipStreamNonBlocking
|
| 79 |
+
#define cudaStreamPerThread hipStreamPerThread
|
| 80 |
+
#define cudaStreamSynchronize hipStreamSynchronize
|
| 81 |
+
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
| 82 |
+
#define cudaStream_t hipStream_t
|
| 83 |
+
#define cudaSuccess hipSuccess
|
| 84 |
+
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
| 85 |
+
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
| 86 |
+
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
| 87 |
+
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
| 88 |
+
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
| 89 |
+
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
| 90 |
+
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
| 91 |
+
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
| 92 |
+
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
| 93 |
+
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
| 94 |
+
|
| 95 |
+
#define __CUDA_ARCH__ 1300
|
| 96 |
+
|
| 97 |
+
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
| 98 |
+
defined(__gfx1150__) || defined(__gfx1151__)
|
| 99 |
+
#define RDNA3
|
| 100 |
+
#endif
|
| 101 |
+
|
| 102 |
+
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
| 103 |
+
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
| 104 |
+
#define RDNA2
|
| 105 |
+
#endif
|
| 106 |
+
|
| 107 |
+
#if defined(__gfx1010__) || defined(__gfx1012__)
|
| 108 |
+
#define RDNA1
|
| 109 |
+
#endif
|
| 110 |
+
|
| 111 |
+
#ifndef __has_builtin
|
| 112 |
+
#define __has_builtin(x) 0
|
| 113 |
+
#endif
|
| 114 |
+
|
| 115 |
+
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
| 116 |
+
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 117 |
+
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
| 118 |
+
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
| 119 |
+
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
| 120 |
+
#if __has_builtin(__builtin_elementwise_sub_sat)
|
| 121 |
+
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
| 122 |
+
return reinterpret_cast<const int &>(c);
|
| 123 |
+
#else
|
| 124 |
+
int8x4_t c;
|
| 125 |
+
int16_t tmp;
|
| 126 |
+
#pragma unroll
|
| 127 |
+
for (int i = 0; i < 4; i++) {
|
| 128 |
+
tmp = va[i] - vb[i];
|
| 129 |
+
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
| 130 |
+
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
| 131 |
+
c[i] = tmp;
|
| 132 |
+
}
|
| 133 |
+
return reinterpret_cast<int &>(c);
|
| 134 |
+
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
|
| 138 |
+
return __vsubss4(a, b);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
|
| 142 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 143 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 144 |
+
unsigned int c;
|
| 145 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 146 |
+
#pragma unroll
|
| 147 |
+
for (int i = 0; i < 4; ++i) {
|
| 148 |
+
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
| 149 |
+
}
|
| 150 |
+
return c;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
|
| 154 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 155 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 156 |
+
unsigned int c;
|
| 157 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 158 |
+
#pragma unroll
|
| 159 |
+
for (int i = 0; i < 4; ++i) {
|
| 160 |
+
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
| 161 |
+
}
|
| 162 |
+
return c;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
| 166 |
+
// __shfl_xor() for half2 was added in ROCm 5.6
|
| 167 |
+
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
|
| 168 |
+
typedef union half2_b32 {
|
| 169 |
+
half2 val;
|
| 170 |
+
int b32;
|
| 171 |
+
} half2_b32_t;
|
| 172 |
+
half2_b32_t tmp;
|
| 173 |
+
tmp.val = var;
|
| 174 |
+
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
|
| 175 |
+
return tmp.val;
|
| 176 |
+
}
|
| 177 |
+
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
ggml/src/ggml-cuda/vendors/musa.h
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <musa_runtime.h>
|
| 4 |
+
#include <musa.h>
|
| 5 |
+
#include <mublas.h>
|
| 6 |
+
#include <musa_fp16.h>
|
| 7 |
+
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 8 |
+
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
| 9 |
+
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
| 10 |
+
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
| 11 |
+
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
| 12 |
+
#define CUBLAS_OP_N MUBLAS_OP_N
|
| 13 |
+
#define CUBLAS_OP_T MUBLAS_OP_T
|
| 14 |
+
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
| 15 |
+
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
| 16 |
+
#define CUDA_R_16F MUSA_R_16F
|
| 17 |
+
#define CUDA_R_32F MUSA_R_32F
|
| 18 |
+
#define cublasComputeType_t cudaDataType_t
|
| 19 |
+
#define cublasCreate mublasCreate
|
| 20 |
+
#define cublasDestroy mublasDestroy
|
| 21 |
+
#define cublasGemmEx mublasGemmEx
|
| 22 |
+
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
| 23 |
+
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
| 24 |
+
#define cublasHandle_t mublasHandle_t
|
| 25 |
+
#define cublasSetMathMode mublasSetMathMode
|
| 26 |
+
#define cublasSetStream mublasSetStream
|
| 27 |
+
#define cublasSgemm mublasSgemm
|
| 28 |
+
#define cublasStatus_t mublasStatus_t
|
| 29 |
+
#define cublasGetStatusString mublasStatus_to_string
|
| 30 |
+
#define cudaDataType_t musaDataType_t
|
| 31 |
+
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
| 32 |
+
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
| 33 |
+
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
| 34 |
+
#define cudaDeviceProp musaDeviceProp
|
| 35 |
+
#define cudaDeviceSynchronize musaDeviceSynchronize
|
| 36 |
+
#define cudaError_t musaError_t
|
| 37 |
+
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
| 38 |
+
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
| 39 |
+
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
| 40 |
+
#define cudaEventDisableTiming musaEventDisableTiming
|
| 41 |
+
#define cudaEventRecord musaEventRecord
|
| 42 |
+
#define cudaEventSynchronize musaEventSynchronize
|
| 43 |
+
#define cudaEvent_t musaEvent_t
|
| 44 |
+
#define cudaEventDestroy musaEventDestroy
|
| 45 |
+
#define cudaFree musaFree
|
| 46 |
+
#define cudaFreeHost musaFreeHost
|
| 47 |
+
#define cudaGetDevice musaGetDevice
|
| 48 |
+
#define cudaGetDeviceCount musaGetDeviceCount
|
| 49 |
+
#define cudaGetDeviceProperties musaGetDeviceProperties
|
| 50 |
+
#define cudaGetErrorString musaGetErrorString
|
| 51 |
+
#define cudaGetLastError musaGetLastError
|
| 52 |
+
#define cudaHostRegister musaHostRegister
|
| 53 |
+
#define cudaHostRegisterPortable musaHostRegisterPortable
|
| 54 |
+
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
| 55 |
+
#define cudaHostUnregister musaHostUnregister
|
| 56 |
+
#define cudaLaunchHostFunc musaLaunchHostFunc
|
| 57 |
+
#define cudaMalloc musaMalloc
|
| 58 |
+
#define cudaMallocHost musaMallocHost
|
| 59 |
+
#define cudaMemcpy musaMemcpy
|
| 60 |
+
#define cudaMemcpyAsync musaMemcpyAsync
|
| 61 |
+
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
| 62 |
+
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
| 63 |
+
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
| 64 |
+
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
| 65 |
+
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
| 66 |
+
#define cudaMemcpyKind musaMemcpyKind
|
| 67 |
+
#define cudaMemset musaMemset
|
| 68 |
+
#define cudaMemsetAsync musaMemsetAsync
|
| 69 |
+
#define cudaMemGetInfo musaMemGetInfo
|
| 70 |
+
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
| 71 |
+
#define cudaSetDevice musaSetDevice
|
| 72 |
+
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
| 73 |
+
#define cudaStreamDestroy musaStreamDestroy
|
| 74 |
+
#define cudaStreamFireAndForget musaStreamFireAndForget
|
| 75 |
+
#define cudaStreamNonBlocking musaStreamNonBlocking
|
| 76 |
+
#define cudaStreamPerThread musaStreamPerThread
|
| 77 |
+
#define cudaStreamSynchronize musaStreamSynchronize
|
| 78 |
+
#define cudaStreamWaitEvent musaStreamWaitEvent
|
| 79 |
+
#define cudaStream_t musaStream_t
|
| 80 |
+
#define cudaSuccess musaSuccess
|
| 81 |
+
|
| 82 |
+
// Additional mappings for MUSA virtual memory pool
|
| 83 |
+
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
| 84 |
+
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
| 85 |
+
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
| 86 |
+
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
| 87 |
+
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
| 88 |
+
#define CUdevice MUdevice
|
| 89 |
+
#define CUdeviceptr MUdeviceptr
|
| 90 |
+
#define CUmemAccessDesc MUmemAccessDesc
|
| 91 |
+
#define CUmemAllocationProp MUmemAllocationProp
|
| 92 |
+
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
| 93 |
+
#define cuDeviceGet muDeviceGet
|
| 94 |
+
#define cuDeviceGetAttribute muDeviceGetAttribute
|
| 95 |
+
#define cuMemAddressFree muMemAddressFree
|
| 96 |
+
#define cuMemAddressReserve muMemAddressReserve
|
| 97 |
+
#define cuMemCreate muMemCreate
|
| 98 |
+
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
| 99 |
+
#define cuMemMap muMemMap
|
| 100 |
+
#define cuMemRelease muMemRelease
|
| 101 |
+
#define cuMemSetAccess muMemSetAccess
|
| 102 |
+
#define cuMemUnmap muMemUnmap
|
| 103 |
+
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
| 104 |
+
#define cudaFuncSetAttribute musaFuncSetAttribute
|
| 105 |
+
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
| 106 |
+
#define make_cudaExtent make_musaExtent
|
| 107 |
+
#define make_cudaPitchedPtr make_musaPitchedPtr
|
| 108 |
+
|
| 109 |
+
// Additional mappings for MUSA graphs
|
| 110 |
+
#define CUDA_SUCCESS MUSA_SUCCESS
|
| 111 |
+
#define CUresult MUresult
|
| 112 |
+
#define cuGetErrorString muGetErrorString
|
| 113 |
+
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
| 114 |
+
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
| 115 |
+
#define cudaGraphDestroy musaGraphDestroy
|
| 116 |
+
#define cudaGraphExecDestroy musaGraphExecDestroy
|
| 117 |
+
#define cudaGraphExec_t musaGraphExec_t
|
| 118 |
+
#define cudaGraphExecUpdate musaGraphExecUpdate
|
| 119 |
+
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
| 120 |
+
#define cudaGraphGetNodes musaGraphGetNodes
|
| 121 |
+
#define cudaGraphInstantiate musaGraphInstantiate
|
| 122 |
+
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
| 123 |
+
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
| 124 |
+
#define cudaGraphLaunch musaGraphLaunch
|
| 125 |
+
#define cudaGraphNodeGetType musaGraphNodeGetType
|
| 126 |
+
#define cudaGraphNode_t musaGraphNode_t
|
| 127 |
+
#define cudaGraphNodeType musaGraphNodeType
|
| 128 |
+
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
| 129 |
+
#define cudaGraph_t musaGraph_t
|
| 130 |
+
#define cudaKernelNodeParams musaKernelNodeParams
|
| 131 |
+
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 132 |
+
#define cudaStreamEndCapture musaStreamEndCapture
|
| 133 |
+
|
| 134 |
+
// XXX: Clang builtins mapping
|
| 135 |
+
#define __vsub4 __vsub4_musa
|
| 136 |
+
#define __vcmpeq4 __vcmpeq4_musa
|
| 137 |
+
#define __vcmpne4 __vcmpne4_musa
|
| 138 |
+
|
| 139 |
+
#ifndef __has_builtin
|
| 140 |
+
#define __has_builtin(x) 0
|
| 141 |
+
#endif
|
| 142 |
+
|
| 143 |
+
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 144 |
+
|
| 145 |
+
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
| 146 |
+
return __vsubss4(a, b);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
| 150 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 151 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 152 |
+
unsigned int c;
|
| 153 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 154 |
+
#pragma unroll
|
| 155 |
+
for (int i = 0; i < 4; ++i) {
|
| 156 |
+
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
| 157 |
+
}
|
| 158 |
+
return c;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
| 162 |
+
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
| 163 |
+
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
| 164 |
+
unsigned int c;
|
| 165 |
+
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
| 166 |
+
#pragma unroll
|
| 167 |
+
for (int i = 0; i < 4; ++i) {
|
| 168 |
+
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
| 169 |
+
}
|
| 170 |
+
return c;
|
| 171 |
+
}
|