Sigbjørn Skjæret committed · Commit b54b644 · 1 Parent(s): cfa3731

cuda : implement bf16 cpy ops and enable bf16 cont (llama/14763)

* implement bf16 cpy ops and enable bf16 cont

* deduplicate copy functions

* deduplicate checks
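
The deduplication above amounts to one templated element conversion, convert_flt, which copies directly when the source and destination types match and otherwise converts through float, plus a single templated launcher, ggml_cpy_flt_cuda<src_t, dst_t>, instantiated once per F32/F16/BF16 type pair. The sketch below is a standalone illustration of that pattern, not the ggml sources: the names convert_elem and copy_kernel and the managed-memory test harness are assumptions made for the example (compile with nvcc -std=c++17).

// Standalone sketch of the pattern used in this commit (illustrative names,
// not the ggml code): one templated conversion, instantiated per type pair.
// Build: nvcc -std=c++17 copy_sketch.cu
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdio>
#include <type_traits>

template<typename src_t, typename dst_t>
__device__ __forceinline__ void convert_elem(const src_t * src, dst_t * dst) {
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;                  // same type: plain copy
    } else {
        *dst = (dst_t) float(*src);   // different types: convert through float
    }
}

template<typename src_t, typename dst_t>
__global__ void copy_kernel(const src_t * src, dst_t * dst, int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i < n) {
        convert_elem(src + i, dst + i);
    }
}

int main() {
    const int n = 256;
    half        * src = nullptr;
    nv_bfloat16 * dst = nullptr;
    cudaMallocManaged(&src, n*sizeof(*src));
    cudaMallocManaged(&dst, n*sizeof(*dst));
    for (int i = 0; i < n; ++i) {
        src[i] = __float2half(0.5f*i);
    }
    // one instantiation covers the F16 -> BF16 pair this commit enables
    copy_kernel<<<(n + 255)/256, 256>>>(src, dst, n);
    cudaDeviceSynchronize();
    printf("dst[3] = %f\n", __bfloat162float(dst[3]));
    cudaFree(src);
    cudaFree(dst);
    return 0;
}

Instantiating the same kernel for the remaining pairs (float, half, nv_bfloat16 in either position) yields exactly the copy matrix registered in the dispatch code of the diffs below.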

ggml/src/ggml-cuda/cpy-utils.cuh CHANGED
@@ -2,24 +2,13 @@
 
 #include "ggml-common.h"
 
-static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
-    *dst = *src;
-}
-
-static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
-    *dst = __float2half(*src);
-}
-
-static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
-    *dst = *src;
-}
-
-static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
-    *dst = *src;
-}
-
-static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
-    *dst = *src;
+template<typename src_t, typename dst_t>
+static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
+    if constexpr (std::is_same_v<src_t, dst_t>) {
+        *dst = *src;
+    } else {
+        *dst = float(*src);
+    }
 }
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
@@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
     quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
 }
 
-static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    convert_f32_f32((const float *)cxi, (float *)cdsti);
-}
-
-static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    convert_f32_f16((const float *)cxi, (half *)cdsti);
-}
-
-static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
-    convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
-}
-
-static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    convert_f16_f16((const half *)cxi, (half *)cdsti);
-}
-
-static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    convert_f16_f32((const half *)cxi, (float *)cdsti);
+template<typename src_t, typename dst_t>
+static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
+    convert_flt((const src_t *)cxi, (dst_t *)cdsti);
 }
ggml/src/ggml-cuda/cpy.cu CHANGED
@@ -8,10 +8,10 @@
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
-                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                   const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 #endif
 }
 
-static void ggml_cpy_f16_f32_cuda(
+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_bf16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
@@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-static void ggml_cpy_f16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void*) cpy_flt<cpy_1_flt<half, float>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3242,13 +3242,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 ggml_type src0_type = op->src[0]->type;
                 ggml_type src1_type = op->src[1]->type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
+                    (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
+                ) {
                     return true;
                 }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3284,12 +3280,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
                 if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                     return true;
                 }
@@ -3370,7 +3360,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[0]->ne[1] % 128 == 0;
             }
         case GGML_OP_CONT:
-            return op->src[0]->type != GGML_TYPE_BF16;
+            return true;
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
ggml/src/ggml-cuda/set-rows.cu CHANGED
@@ -4,24 +4,8 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 template<typename src_t, typename dst_t>
-__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    GGML_UNUSED(src_f);
-    GGML_UNUSED(dst_f);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
-    convert_f32_f16(src_f, dst_h);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
-    convert_f32_bf16(src_f, dst_b);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
-    convert_f32_f32(src_f, dst_f);
+__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
+    convert_flt(src_f, dst_f);
 }
 
 // Generic quantized set_rows kernel template