R0CKSTAR committed on
Commit
ec2f307
·
1 Parent(s): f0d6f5c

cuda : organize vendor-specific headers into vendors directory (llama/8746)

Browse files
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -27,255 +27,11 @@
27
  #include <vector>
28
 
29
  #if defined(GGML_USE_HIPBLAS)
30
- #include <hip/hip_runtime.h>
31
- #include <hipblas/hipblas.h>
32
- #include <hip/hip_fp16.h>
33
- #ifdef __HIP_PLATFORM_AMD__
34
- // for rocblas_initialize()
35
- #include "rocblas/rocblas.h"
36
- #endif // __HIP_PLATFORM_AMD__
37
- #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
38
- #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
39
- #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
40
- #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
41
- #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
42
- #define CUBLAS_OP_N HIPBLAS_OP_N
43
- #define CUBLAS_OP_T HIPBLAS_OP_T
44
- #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
45
- #define CUBLAS_TF32_TENSOR_OP_MATH 0
46
- #define CUDA_R_16F HIPBLAS_R_16F
47
- #define CUDA_R_32F HIPBLAS_R_32F
48
- #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
49
- #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
50
- #define cublasCreate hipblasCreate
51
- #define cublasDestroy hipblasDestroy
52
- #define cublasGemmEx hipblasGemmEx
53
- #define cublasGemmBatchedEx hipblasGemmBatchedEx
54
- #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
55
- #define cublasHandle_t hipblasHandle_t
56
- #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
57
- #define cublasSetStream hipblasSetStream
58
- #define cublasSgemm hipblasSgemm
59
- #define cublasStatus_t hipblasStatus_t
60
- #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
61
- #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
62
- #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
63
- #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
64
- #define cudaDeviceProp hipDeviceProp_t
65
- #define cudaDeviceSynchronize hipDeviceSynchronize
66
- #define cudaError_t hipError_t
67
- #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
68
- #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
69
- #define cudaEventCreateWithFlags hipEventCreateWithFlags
70
- #define cudaEventDisableTiming hipEventDisableTiming
71
- #define cudaEventRecord hipEventRecord
72
- #define cudaEventSynchronize hipEventSynchronize
73
- #define cudaEvent_t hipEvent_t
74
- #define cudaEventDestroy hipEventDestroy
75
- #define cudaFree hipFree
76
- #define cudaFreeHost hipHostFree
77
- #define cudaGetDevice hipGetDevice
78
- #define cudaGetDeviceCount hipGetDeviceCount
79
- #define cudaGetDeviceProperties hipGetDeviceProperties
80
- #define cudaGetErrorString hipGetErrorString
81
- #define cudaGetLastError hipGetLastError
82
- #define cudaHostRegister hipHostRegister
83
- #define cudaHostRegisterPortable hipHostRegisterPortable
84
- #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
85
- #define cudaHostUnregister hipHostUnregister
86
- #define cudaLaunchHostFunc hipLaunchHostFunc
87
- #define cudaMalloc hipMalloc
88
- #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
89
- #define cudaMemcpy hipMemcpy
90
- #define cudaMemcpyAsync hipMemcpyAsync
91
- #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
92
- #define cudaMemcpy2DAsync hipMemcpy2DAsync
93
- #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
94
- #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
95
- #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
96
- #define cudaMemcpyKind hipMemcpyKind
97
- #define cudaMemset hipMemset
98
- #define cudaMemsetAsync hipMemsetAsync
99
- #define cudaMemGetInfo hipMemGetInfo
100
- #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
101
- #define cudaSetDevice hipSetDevice
102
- #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
103
- #define cudaStreamDestroy hipStreamDestroy
104
- #define cudaStreamFireAndForget hipStreamFireAndForget
105
- #define cudaStreamNonBlocking hipStreamNonBlocking
106
- #define cudaStreamPerThread hipStreamPerThread
107
- #define cudaStreamSynchronize hipStreamSynchronize
108
- #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
109
- #define cudaStream_t hipStream_t
110
- #define cudaSuccess hipSuccess
111
- #define __trap() do { abort(); __builtin_unreachable(); } while(0)
112
- #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
113
- #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
114
- #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
115
- #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
116
- #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
117
- #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
118
- #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
119
- #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
120
- #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
121
  #elif defined(GGML_USE_MUSA)
122
- #include <musa_runtime.h>
123
- #include <musa.h>
124
- #include <mublas.h>
125
- #include <musa_fp16.h>
126
- // XXX: Keep the following order the same as hipBLAS
127
- // #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
128
- // #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
129
- #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
130
- #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
131
- #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
132
- #define CUBLAS_OP_N MUBLAS_OP_N
133
- #define CUBLAS_OP_T MUBLAS_OP_T
134
- #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
135
- // #define CUBLAS_TF32_TENSOR_OP_MATH 0
136
- #define CUDA_R_16F MUSA_R_16F
137
- #define CUDA_R_32F MUSA_R_32F
138
- // #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
139
- // #define cublasComputeType_t mublasComputeType_t
140
- #define cublasCreate mublasCreate
141
- #define cublasDestroy mublasDestroy
142
- #define cublasGemmEx mublasGemmEx
143
- #define cublasGemmBatchedEx mublasGemmBatchedEx
144
- #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
145
- #define cublasHandle_t mublasHandle_t
146
- // #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
147
- #define cublasSetMathMode mublasSetMathMode
148
- #define cublasSetStream mublasSetStream
149
- #define cublasSgemm mublasSgemm
150
- #define cublasStatus_t mublasStatus_t
151
- #define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
152
- #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
153
- #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
154
- #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
155
- #define cudaDeviceProp musaDeviceProp
156
- #define cudaDeviceSynchronize musaDeviceSynchronize
157
- #define cudaError_t musaError_t
158
- #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
159
- #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
160
- #define cudaEventCreateWithFlags musaEventCreateWithFlags
161
- #define cudaEventDisableTiming musaEventDisableTiming
162
- #define cudaEventRecord musaEventRecord
163
- #define cudaEventSynchronize musaEventSynchronize
164
- #define cudaEvent_t musaEvent_t
165
- #define cudaEventDestroy musaEventDestroy
166
- #define cudaFree musaFree
167
- #define cudaFreeHost musaFreeHost
168
- #define cudaGetDevice musaGetDevice
169
- #define cudaGetDeviceCount musaGetDeviceCount
170
- #define cudaGetDeviceProperties musaGetDeviceProperties
171
- #define cudaGetErrorString musaGetErrorString
172
- #define cudaGetLastError musaGetLastError
173
- #define cudaHostRegister musaHostRegister
174
- #define cudaHostRegisterPortable musaHostRegisterPortable
175
- #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
176
- #define cudaHostUnregister musaHostUnregister
177
- #define cudaLaunchHostFunc musaLaunchHostFunc
178
- #define cudaMalloc musaMalloc
179
- #define cudaMallocHost musaMallocHost
180
- #define cudaMemcpy musaMemcpy
181
- #define cudaMemcpyAsync musaMemcpyAsync
182
- #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
183
- #define cudaMemcpy2DAsync musaMemcpy2DAsync
184
- #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
185
- #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
186
- #define cudaMemcpyHostToDevice musaMemcpyHostToDevice
187
- #define cudaMemcpyKind musaMemcpyKind
188
- #define cudaMemset musaMemset
189
- #define cudaMemsetAsync musaMemsetAsync
190
- #define cudaMemGetInfo musaMemGetInfo
191
- #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
192
- #define cudaSetDevice musaSetDevice
193
- #define cudaStreamCreateWithFlags musaStreamCreateWithFlags
194
- #define cudaStreamDestroy musaStreamDestroy
195
- #define cudaStreamFireAndForget musaStreamFireAndForget
196
- #define cudaStreamNonBlocking musaStreamNonBlocking
197
- #define cudaStreamPerThread musaStreamPerThread
198
- #define cudaStreamSynchronize musaStreamSynchronize
199
- #define cudaStreamWaitEvent musaStreamWaitEvent
200
- #define cudaStream_t musaStream_t
201
- #define cudaSuccess musaSuccess
202
-
203
- // XXX: Other CUDA => MUSA mapping
204
- #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
205
- #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
206
- #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
207
- #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
208
- #define CUdevice MUdevice
209
- #define CUdeviceptr MUdeviceptr
210
- #define CUmemAccessDesc MUmemAccessDesc
211
- #define CUmemAllocationProp MUmemAllocationProp
212
- #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
213
- #define cuDeviceGet muDeviceGet
214
- #define cuDeviceGetAttribute muDeviceGetAttribute
215
- #define cuMemAddressFree muMemAddressFree
216
- #define cuMemAddressReserve muMemAddressReserve
217
- #define cuMemCreate muMemCreate
218
- #define cuMemGetAllocationGranularity muMemGetAllocationGranularity
219
- #define cuMemMap muMemMap
220
- #define cuMemRelease muMemRelease
221
- #define cuMemSetAccess muMemSetAccess
222
- #define cuMemUnmap muMemUnmap
223
- #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
224
- #define cudaFuncSetAttribute musaFuncSetAttribute
225
- #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
226
- #define make_cudaExtent make_musaExtent
227
- #define make_cudaPitchedPtr make_musaPitchedPtr
228
-
229
- // XXX: USE_CUDA_GRAPH
230
- #define CUDA_SUCCESS MUSA_SUCCESS
231
- #define CUresult MUresult
232
- #define cuGetErrorString muGetErrorString
233
- #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
234
- #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
235
- #define cudaGraphDestroy musaGraphDestroy
236
- #define cudaGraphExecDestroy musaGraphExecDestroy
237
- #define cudaGraphExec_t musaGraphExec_t
238
- #define cudaGraphExecUpdate musaGraphExecUpdate
239
- #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
240
- #define cudaGraphGetNodes musaGraphGetNodes
241
- #define cudaGraphInstantiate musaGraphInstantiate
242
- #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
243
- #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
244
- #define cudaGraphLaunch musaGraphLaunch
245
- #define cudaGraphNodeGetType musaGraphNodeGetType
246
- #define cudaGraphNode_t musaGraphNode_t
247
- #define cudaGraphNodeType musaGraphNodeType
248
- #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
249
- #define cudaGraph_t musaGraph_t
250
- #define cudaKernelNodeParams musaKernelNodeParams
251
- #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
252
- #define cudaStreamEndCapture musaStreamEndCapture
253
-
254
- // XXX: cuBLAS => muBLAS mapping
255
- #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
256
- #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
257
- #define CUBLAS_COMPUTE_16F CUDA_R_16F
258
- #define CUBLAS_COMPUTE_32F CUDA_R_32F
259
- #define cublasComputeType_t cudaDataType_t
260
-
261
- // XXX: Clang builtins mapping
262
- #define __vsub4 __vsub4_musa
263
- #define __vcmpeq4 __vcmpeq4_musa
264
- #define __vcmpne4 __vcmpne4_musa
265
  #else
266
- #include <cuda_runtime.h>
267
- #include <cuda.h>
268
- #include <cublas_v2.h>
269
- #include <cuda_fp16.h>
270
-
271
- #if CUDART_VERSION < 11020
272
- #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
273
- #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
274
- #define CUBLAS_COMPUTE_16F CUDA_R_16F
275
- #define CUBLAS_COMPUTE_32F CUDA_R_32F
276
- #define cublasComputeType_t cudaDataType_t
277
- #endif // CUDART_VERSION < 11020
278
-
279
  #endif // defined(GGML_USE_HIPBLAS)
280
 
281
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
@@ -318,11 +74,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
318
 
319
  #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
320
  static const char * cublas_get_error_str(const cublasStatus_t err) {
321
- #ifndef GGML_USE_MUSA
322
  return cublasGetStatusString(err);
323
- #else
324
- return mublasStatus_to_string(err);
325
- #endif // GGML_USE_MUSA
326
  }
327
  #else
328
  static const char * cublas_get_error_str(const cublasStatus_t err) {
@@ -364,129 +116,7 @@ typedef half2 dfloat2;
364
  #else
365
  typedef float dfloat; // dequantize float
366
  typedef float2 dfloat2;
367
- #endif //GGML_CUDA_F16
368
-
369
- #if defined(GGML_USE_MUSA)
370
- #ifndef __has_builtin
371
- #define __has_builtin(x) 0
372
- #endif
373
-
374
- typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
375
-
376
- static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
377
- return __vsubss4(a, b);
378
- }
379
-
380
- static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
381
- const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
382
- const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
383
- unsigned int c;
384
- uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
385
- #pragma unroll
386
- for (int i = 0; i < 4; ++i) {
387
- vc[i] = va[i] == vb[i] ? 0xff : 0x00;
388
- }
389
- return c;
390
- }
391
-
392
- static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
393
- const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
394
- const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
395
- unsigned int c;
396
- uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
397
- #pragma unroll
398
- for (int i = 0; i < 4; ++i) {
399
- vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
400
- }
401
- return c;
402
- }
403
- #endif // defined(GGML_USE_MUSA)
404
-
405
- #if defined(GGML_USE_HIPBLAS)
406
- #define __CUDA_ARCH__ 1300
407
-
408
- #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
409
- defined(__gfx1150__) || defined(__gfx1151__)
410
- #define RDNA3
411
- #endif
412
-
413
- #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
414
- defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
415
- #define RDNA2
416
- #endif
417
-
418
- #if defined(__gfx1010__) || defined(__gfx1012__)
419
- #define RDNA1
420
- #endif
421
-
422
- #ifndef __has_builtin
423
- #define __has_builtin(x) 0
424
- #endif
425
-
426
- typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
427
- typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
428
- static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
429
- const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
430
- const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
431
- #if __has_builtin(__builtin_elementwise_sub_sat)
432
- const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
433
- return reinterpret_cast<const int &>(c);
434
- #else
435
- int8x4_t c;
436
- int16_t tmp;
437
- #pragma unroll
438
- for (int i = 0; i < 4; i++) {
439
- tmp = va[i] - vb[i];
440
- if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
441
- if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
442
- c[i] = tmp;
443
- }
444
- return reinterpret_cast<int &>(c);
445
- #endif // __has_builtin(__builtin_elementwise_sub_sat)
446
- }
447
-
448
- static __device__ __forceinline__ int __vsub4(const int a, const int b) {
449
- return __vsubss4(a, b);
450
- }
451
-
452
- static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
453
- const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
454
- const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
455
- unsigned int c;
456
- uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
457
- #pragma unroll
458
- for (int i = 0; i < 4; ++i) {
459
- vc[i] = va[i] == vb[i] ? 0xff : 0x00;
460
- }
461
- return c;
462
- }
463
-
464
- static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
465
- const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
466
- const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
467
- unsigned int c;
468
- uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
469
- #pragma unroll
470
- for (int i = 0; i < 4; ++i) {
471
- vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
472
- }
473
- return c;
474
- }
475
-
476
- #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
477
- // __shfl_xor() for half2 was added in ROCm 5.6
478
- static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
479
- typedef union half2_b32 {
480
- half2 val;
481
- int b32;
482
- } half2_b32_t;
483
- half2_b32_t tmp;
484
- tmp.val = var;
485
- tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
486
- return tmp.val;
487
- }
488
- #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
489
- #endif // defined(GGML_USE_HIPBLAS)
490
 
491
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
492
  #define FP16_AVAILABLE
 
27
  #include <vector>
28
 
29
  #if defined(GGML_USE_HIPBLAS)
30
+ #include "vendors/hip.h"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  #elif defined(GGML_USE_MUSA)
32
+ #include "vendors/musa.h"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  #else
34
+ #include "vendors/cuda.h"
 
 
 
 
 
 
 
 
 
 
 
 
35
  #endif // defined(GGML_USE_HIPBLAS)
36
 
37
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
 
74
 
75
  #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
76
  static const char * cublas_get_error_str(const cublasStatus_t err) {
 
77
  return cublasGetStatusString(err);
 
 
 
78
  }
79
  #else
80
  static const char * cublas_get_error_str(const cublasStatus_t err) {
 
116
  #else
117
  typedef float dfloat; // dequantize float
118
  typedef float2 dfloat2;
119
+ #endif // GGML_CUDA_F16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
122
  #define FP16_AVAILABLE
ggml/src/ggml-cuda/vendors/cuda.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda_runtime.h>
4
+ #include <cuda.h>
5
+ #include <cublas_v2.h>
6
+ #include <cuda_fp16.h>
7
+
8
+ #if CUDART_VERSION < 11020
9
+ #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
10
+ #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
11
+ #define CUBLAS_COMPUTE_16F CUDA_R_16F
12
+ #define CUBLAS_COMPUTE_32F CUDA_R_32F
13
+ #define cublasComputeType_t cudaDataType_t
14
+ #endif // CUDART_VERSION < 11020
ggml/src/ggml-cuda/vendors/hip.h ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <hip/hip_runtime.h>
4
+ #include <hipblas/hipblas.h>
5
+ #include <hip/hip_fp16.h>
6
+ #ifdef __HIP_PLATFORM_AMD__
7
+ // for rocblas_initialize()
8
+ #include "rocblas/rocblas.h"
9
+ #endif // __HIP_PLATFORM_AMD__
10
+ #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
11
+ #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
12
+ #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
13
+ #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
14
+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
15
+ #define CUBLAS_OP_N HIPBLAS_OP_N
16
+ #define CUBLAS_OP_T HIPBLAS_OP_T
17
+ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
18
+ #define CUBLAS_TF32_TENSOR_OP_MATH 0
19
+ #define CUDA_R_16F HIPBLAS_R_16F
20
+ #define CUDA_R_32F HIPBLAS_R_32F
21
+ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
22
+ #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
23
+ #define cublasCreate hipblasCreate
24
+ #define cublasDestroy hipblasDestroy
25
+ #define cublasGemmEx hipblasGemmEx
26
+ #define cublasGemmBatchedEx hipblasGemmBatchedEx
27
+ #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
28
+ #define cublasHandle_t hipblasHandle_t
29
+ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
30
+ #define cublasSetStream hipblasSetStream
31
+ #define cublasSgemm hipblasSgemm
32
+ #define cublasStatus_t hipblasStatus_t
33
+ #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
34
+ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
35
+ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
36
+ #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
37
+ #define cudaDeviceProp hipDeviceProp_t
38
+ #define cudaDeviceSynchronize hipDeviceSynchronize
39
+ #define cudaError_t hipError_t
40
+ #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
41
+ #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
42
+ #define cudaEventCreateWithFlags hipEventCreateWithFlags
43
+ #define cudaEventDisableTiming hipEventDisableTiming
44
+ #define cudaEventRecord hipEventRecord
45
+ #define cudaEventSynchronize hipEventSynchronize
46
+ #define cudaEvent_t hipEvent_t
47
+ #define cudaEventDestroy hipEventDestroy
48
+ #define cudaFree hipFree
49
+ #define cudaFreeHost hipHostFree
50
+ #define cudaGetDevice hipGetDevice
51
+ #define cudaGetDeviceCount hipGetDeviceCount
52
+ #define cudaGetDeviceProperties hipGetDeviceProperties
53
+ #define cudaGetErrorString hipGetErrorString
54
+ #define cudaGetLastError hipGetLastError
55
+ #define cudaHostRegister hipHostRegister
56
+ #define cudaHostRegisterPortable hipHostRegisterPortable
57
+ #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
58
+ #define cudaHostUnregister hipHostUnregister
59
+ #define cudaLaunchHostFunc hipLaunchHostFunc
60
+ #define cudaMalloc hipMalloc
61
+ #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
62
+ #define cudaMemcpy hipMemcpy
63
+ #define cudaMemcpyAsync hipMemcpyAsync
64
+ #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
65
+ #define cudaMemcpy2DAsync hipMemcpy2DAsync
66
+ #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
67
+ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
68
+ #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
69
+ #define cudaMemcpyKind hipMemcpyKind
70
+ #define cudaMemset hipMemset
71
+ #define cudaMemsetAsync hipMemsetAsync
72
+ #define cudaMemGetInfo hipMemGetInfo
73
+ #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
74
+ #define cudaSetDevice hipSetDevice
75
+ #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
76
+ #define cudaStreamDestroy hipStreamDestroy
77
+ #define cudaStreamFireAndForget hipStreamFireAndForget
78
+ #define cudaStreamNonBlocking hipStreamNonBlocking
79
+ #define cudaStreamPerThread hipStreamPerThread
80
+ #define cudaStreamSynchronize hipStreamSynchronize
81
+ #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
82
+ #define cudaStream_t hipStream_t
83
+ #define cudaSuccess hipSuccess
84
+ #define __trap() do { abort(); __builtin_unreachable(); } while(0)
85
+ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
86
+ #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
87
+ #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
88
+ #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
89
+ #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
90
+ #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
91
+ #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
92
+ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
93
+ #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
94
+
95
+ #define __CUDA_ARCH__ 1300
96
+
97
+ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
98
+ defined(__gfx1150__) || defined(__gfx1151__)
99
+ #define RDNA3
100
+ #endif
101
+
102
+ #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
103
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
104
+ #define RDNA2
105
+ #endif
106
+
107
+ #if defined(__gfx1010__) || defined(__gfx1012__)
108
+ #define RDNA1
109
+ #endif
110
+
111
+ #ifndef __has_builtin
112
+ #define __has_builtin(x) 0
113
+ #endif
114
+
115
+ typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
116
+ typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
117
+ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
118
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
119
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
120
+ #if __has_builtin(__builtin_elementwise_sub_sat)
121
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
122
+ return reinterpret_cast<const int &>(c);
123
+ #else
124
+ int8x4_t c;
125
+ int16_t tmp;
126
+ #pragma unroll
127
+ for (int i = 0; i < 4; i++) {
128
+ tmp = va[i] - vb[i];
129
+ if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
130
+ if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
131
+ c[i] = tmp;
132
+ }
133
+ return reinterpret_cast<int &>(c);
134
+ #endif // __has_builtin(__builtin_elementwise_sub_sat)
135
+ }
136
+
137
+ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
138
+ return __vsubss4(a, b);
139
+ }
140
+
141
+ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
142
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
143
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
144
+ unsigned int c;
145
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
146
+ #pragma unroll
147
+ for (int i = 0; i < 4; ++i) {
148
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
149
+ }
150
+ return c;
151
+ }
152
+
153
+ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
154
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
155
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
156
+ unsigned int c;
157
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
158
+ #pragma unroll
159
+ for (int i = 0; i < 4; ++i) {
160
+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
161
+ }
162
+ return c;
163
+ }
164
+
165
+ #if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
166
+ // __shfl_xor() for half2 was added in ROCm 5.6
167
+ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
168
+ typedef union half2_b32 {
169
+ half2 val;
170
+ int b32;
171
+ } half2_b32_t;
172
+ half2_b32_t tmp;
173
+ tmp.val = var;
174
+ tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
175
+ return tmp.val;
176
+ }
177
+ #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
ggml/src/ggml-cuda/vendors/musa.h ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_fp16.h>

// Vendor header for the MUSA toolkit: lets the common CUDA code path compile
// unchanged by #define-mapping every CUDA / cuBLAS / CUDA driver identifier
// used in ggml-cuda onto its MUSA / muBLAS / MUSA driver equivalent.

// cuBLAS enums, types and entry points -> muBLAS
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cublasComputeType_t cudaDataType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cublasGetStatusString mublasStatus_to_string

// CUDA runtime API -> MUSA runtime API
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess

// Additional mappings for MUSA virtual memory pool
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr

// Additional mappings for MUSA graphs
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture

// XXX: Clang builtins mapping
// Rename the byte-SIMD emulation functions defined later in this header so
// they do not collide with Clang builtins of the same name.
#define __vsub4 __vsub4_musa
#define __vcmpeq4 __vcmpeq4_musa
#define __vcmpne4 __vcmpne4_musa

// Fallback so __has_builtin(...) can be used unconditionally on compilers
// that do not provide it.
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

// 4-lane byte vector used to reinterpret a 32-bit word as its byte lanes.
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
144
+
145
// Emulation of CUDA's per-byte SIMD subtract intrinsic for MUSA (exposed to
// callers under the __vsub4 name via the #define mapping in this header).
// NOTE(review): delegates to __vsubss4 (saturating signed byte subtract),
// whereas CUDA's __vsub4 wraps on overflow — presumably callers only hit
// non-saturating cases; confirm before reusing in new code.
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
    return __vsubss4(a, b);
}
148
+
149
// Emulation of CUDA's __vcmpeq4 for MUSA (exposed as __vcmpeq4 via #define):
// each byte of the result is 0xff where the corresponding bytes of a and b
// match, 0x00 otherwise.
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
    const unsigned int diff = a ^ b; // a byte of diff is zero iff that lane matches
    unsigned int result = 0;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        if (((diff >> (8*i)) & 0xffu) == 0) {
            result |= 0xffu << (8*i);
        }
    }
    return result;
}
160
+
161
// Emulation of CUDA's __vcmpne4 for MUSA (exposed as __vcmpne4 via #define):
// each byte of the result is 0xff where the corresponding bytes of a and b
// differ, 0x00 otherwise.
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
    const unsigned int diff = a ^ b; // a byte of diff is nonzero iff that lane differs
    unsigned int result = 0;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        if (((diff >> (8*i)) & 0xffu) != 0) {
            result |= 0xffu << (8*i);
        }
    }
    return result;
}