coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] (#2979)
* coreml: fix Whisper to CoreML conversion by disabling SDPA
This commit disables the use of PyTorch's
`scaled_dot_product_attention` in the Whisper model to avoid
compatibility issues during CoreML conversion.
The issue occurs because coremltools requires PyTorch 2.5.0, but the
Whisper implementation may expect behavior from newer PyTorch versions.
By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to
use its fallback manual attention implementation, which works correctly
with PyTorch 2.5.0 during the tracing process.
Refs: https://github.com/ggerganov/whisper.cpp/issues/2783
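
As an illustration only (not part of the conversion script itself), the workaround
boils down to a one-line monkeypatch that must run before the model is loaded or traced:

# Minimal sketch of the workaround; assumes the openai-whisper package is installed.
# Patching the class attribute makes every MultiHeadAttention instance take the
# manual attention path instead of torch.nn.functional.scaled_dot_product_attention,
# so tracing behaves the same under PyTorch 2.5.0 (the version coremltools requires).
import whisper.model

whisper.model.MultiHeadAttention.use_sdpa = False

# Any model loaded after this point (e.g. via whisper.load_model) is traced
# through the manual fallback implementation.
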
* coreml: fix audio shape in whisper decoder conversion
This commit fixes the audio shape in the whisper decoder conversion
script.
The motivation for this is that the audio shape was incorrect, which
caused the decoder conversion to fail.
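
As a rough sketch of the corrected example inputs (a hypothetical standalone helper,
not the script itself; the shapes come from whisper's ModelDimensions, e.g.
1 x 1500 x 384 for the tiny model):

# Sketch only: build dummy decoder inputs with the corrected audio shape for tracing.
import torch
from whisper.model import ModelDimensions  # provides n_audio_ctx, n_audio_state, n_vocab, ...

def make_decoder_trace_inputs(hparams: ModelDimensions):
    tokens_shape = (1, 1)                                            # a single decoded token
    audio_shape  = (1, hparams.n_audio_ctx, hparams.n_audio_state)   # e.g. (1, 1500, 384)
    audio_data = torch.randn(audio_shape)                            # stand-in encoder output
    token_data = torch.randint(hparams.n_vocab, tokens_shape).long() # random token ids
    return token_data, audio_data

# Usage (illustrative): traced = torch.jit.trace(model.decoder, make_decoder_trace_inputs(model.dims))
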
* coreml : set -e in generate-coreml-interface.sh
The commit sets the -e flag in the generate-coreml-interface.sh script
so that it exits immediately if any command fails.
* coreml : update generated encoder/decoder interfaces
This commit updates the generated encoder/decoder interfaces for the
whisper model, which are the result of running the
generate-coreml-interface.sh script.

models/convert-whisper-to-coreml.py:

@@ -12,6 +12,15 @@ from coremltools.models.neural_network.quantization_utils import quantize_weight
 from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
 from whisper import load_model
 
+# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
+# The Whisper implementation expects a specific behavior from
+# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
+# versions. Setting use_sdpa=False forces Whisper to use its manual attention
+# implementation instead, which is more stable across different PyTorch versions
+# (2.5.0 required by coremltools vs newer versions).
+import whisper.model
+whisper.model.MultiHeadAttention.use_sdpa = False
+
 # Use for changing dim of input in encoder and decoder embeddings
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):

@@ -260,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
     model.eval()
 
     tokens_shape = (1, 1)
+    audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)
 
     audio_data = torch.randn(audio_shape)
+    token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
+
     traced_model = torch.jit.trace(model, (token_data, audio_data))
 
     model = ct.convert(
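
For context, here is a hedged sketch of how a traced decoder is handed to coremltools
and how the converted model can be sanity-checked; the model choice, ct.convert
arguments and shapes below are illustrative assumptions rather than the script's exact call:

# Illustrative only; requires torch, openai-whisper and coremltools (predict needs macOS).
import coremltools as ct
import numpy as np
import torch
import whisper
import whisper.model

whisper.model.MultiHeadAttention.use_sdpa = False   # same workaround as above
model = whisper.load_model("tiny")                  # hypothetical model choice
hparams = model.dims

token_data = torch.zeros((1, 1), dtype=torch.long)
audio_data = torch.randn(1, hparams.n_audio_ctx, hparams.n_audio_state)
traced_decoder = torch.jit.trace(model.decoder.eval(), (token_data, audio_data))

# Input names mirror the generated CoreML interface (token_data, audio_data).
decoder_mlmodel = ct.convert(
    traced_decoder,
    inputs=[
        ct.TensorType(name="token_data", shape=token_data.shape, dtype=np.int32),
        ct.TensorType(name="audio_data", shape=audio_data.shape, dtype=np.float32),
    ],
)

# Quick sanity check that the converted decoder runs.
out = decoder_mlmodel.predict({
    "token_data": np.zeros((1, 1), dtype=np.int32),
    "audio_data": np.random.randn(1, hparams.n_audio_ctx, hparams.n_audio_state).astype(np.float32),
})
print(list(out.keys()))  # the generated decoder interface exposes its output as "cast_76"
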

models/generate-coreml-interface.sh:

@@ -5,6 +5,8 @@
 # - src/coreml/whisper-decoder-impl.h and src/coreml/whisper-decoder-impl.m
 #
 
+set -e
+
 wd=$(dirname "$0")
 cd "$wd/../" || exit
 

src/coreml/whisper-decoder-impl.h:

@@ -11,36 +11,33 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
 /// Model Prediction Input Type
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
 
+/// token_data as 1 by 1 matrix of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * token_data;
 
+/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
 - (instancetype)init NS_UNAVAILABLE;
 - (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
 
 @end
 
 /// Model Prediction Output Type
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
 
+/// cast_76 as multidimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
 - (instancetype)init NS_UNAVAILABLE;
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;
 
 @end
 
 /// Class for model loading and prediction
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 

@@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.

@@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Make a prediction using the standard interface

@@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
 */
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_decoder_implInput to predict from
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_decoder_implInput to predict from
+    @param options prediction options
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
     Make a prediction using the convenience interface
+    @param token_data 1 by 1 matrix of floats
+    @param audio_data 1 × 1500 × 384 3-dimensional array of floats
     @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
     @return the prediction as whisper_decoder_implOutput
 */

src/coreml/whisper-decoder-impl.m:

@@ -39,21 +39,21 @@
 
 @implementation whisper_decoder_implOutput
 
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
     self = [super init];
     if (self) {
+        _cast_76 = cast_76;
     }
     return self;
 }
 
 - (NSSet<NSString *> *)featureNames {
+    return [NSSet setWithArray:@[@"cast_76"]];
 }
 
 - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
+    if ([featureName isEqualToString:@"cast_76"]) {
+        return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
     }
     return nil;
 }

@@ -80,10 +80,13 @@
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 

@@ -177,7 +180,29 @@
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
     if (!outFeatures) { return nil; }
+    return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
 }
 
 - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {

@@ -192,7 +217,7 @@
     NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
     for (NSInteger i = 0; i < outBatch.count; i++) {
         id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
+        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
         [results addObject:result];
     }
     return results;

src/coreml/whisper-encoder-impl.h:

@@ -11,9 +11,8 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
 /// Model Prediction Input Type
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
 
 /// logmel_data as 1 × 80 × 3000 3-dimensional array of floats

@@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
 
 @end
 
 /// Model Prediction Output Type
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
 
 /// output as multidimensional array of floats

@@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
 
 @end
 
 /// Class for model loading and prediction
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 

@@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.

@@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Make a prediction using the standard interface

@@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
 */
 - (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_encoder_implInput to predict from
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_encoder_implInput to predict from
+    @param options prediction options
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
     Make a prediction using the convenience interface
+    @param logmel_data 1 × 80 × 3000 3-dimensional array of floats
     @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
     @return the prediction as whisper_encoder_implOutput
 */

src/coreml/whisper-encoder-impl.m:

@@ -76,10 +76,13 @@
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 

@@ -176,6 +179,28 @@
     return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
 }
 
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
 - (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
     return [self predictionFromFeatures:input_ error:error];