File size: 3,300 Bytes
5e456c0
 
 
 
 
 
 
 
 
 
 
 
 
728fcbe
5e456c0
 
 
 
 
 
 
 
728fcbe
 
 
 
 
4a56623
 
5e456c0
728fcbe
5e456c0
 
 
 
 
 
728fcbe
5e456c0
 
 
 
 
 
 
 
 
 
 
728fcbe
5e456c0
728fcbe
5e456c0
728fcbe
5e456c0
728fcbe
 
 
 
 
5e456c0
728fcbe
 
5e456c0
 
 
 
728fcbe
89e8314
6eea7b7
5e456c0
 
 
89e8314
 
5e456c0
728fcbe
 
 
 
 
 
 
 
 
 
 
 
 
5e456c0
 
 
728fcbe
5e456c0
 
 
728fcbe
 
 
 
 
5e456c0
728fcbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e456c0
728fcbe
 
 
5e456c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package main

import (
	"fmt"
	"io"
	"os"
	"time"

	// Package imports
	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
	wav "github.com/go-audio/wav"
)

// Process transcribes the WAV file at path with the given model and writes the
// result to standard output.
//
// A new context is created from model and configured through flags.SetParams.
// The file must be mono and sampled at whisper.SampleRate; any other layout is
// rejected with an "unsupported ..." error. Progress messages go to
// flags.Output(), while the system information and the final transcript go to
// standard output. The transcript format is chosen by flags.GetOut(): "srt"
// emits SubRip cues, "none" emits nothing, anything else emits plain text.
//
// The first error from context creation, parameter setup, file access, WAV
// decoding, processing, or transcript output is returned unchanged.
func Process(model whisper.Model, path string, flags *Flags) error {
	// Build the processing context and apply the command-line parameters.
	ctx, err := model.NewContext()
	if err != nil {
		return err
	}
	if err = flags.SetParams(ctx); err != nil {
		return err
	}

	fmt.Printf("\n%s\n", ctx.SystemInfo())

	// Open the input file.
	fmt.Fprintf(flags.Output(), "Loading %q\n", path)
	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()

	// Decode the whole WAV file into memory; the format checks run only after
	// the decoder has read the file.
	decoder := wav.NewDecoder(file)
	pcm, err := decoder.FullPCMBuffer()
	if err != nil {
		return err
	}
	if decoder.SampleRate != whisper.SampleRate {
		return fmt.Errorf("unsupported sample rate: %d", decoder.SampleRate)
	}
	if decoder.NumChans != 1 {
		return fmt.Errorf("unsupported number of channels: %d", decoder.NumChans)
	}
	samples := pcm.AsFloat32Buffer().Data

	// A per-segment callback is installed only when token output is requested;
	// otherwise it stays nil.
	var onSegment whisper.SegmentCallback
	if flags.IsTokens() {
		onSegment = func(segment whisper.Segment) {
			fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
			for _, token := range segment.Tokens {
				if !flags.IsColorize() || !ctx.IsText(token) {
					fmt.Fprint(flags.Output(), token.Text, " ")
					continue
				}
				fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
			}
			fmt.Fprintln(flags.Output(), "")
			fmt.Fprintln(flags.Output(), "")
		}
	}

	// Run the model over the samples.
	fmt.Fprintf(flags.Output(), "  ...processing %q\n", path)
	ctx.ResetTimings()
	if err = ctx.Process(samples, nil, onSegment, nil); err != nil {
		return err
	}

	ctx.PrintTimings()

	// Emit the transcript in the requested format.
	if flags.GetOut() == "srt" {
		return OutputSRT(os.Stdout, ctx)
	}
	if flags.GetOut() == "none" {
		return nil
	}
	return Output(os.Stdout, ctx, flags.IsColorize())
}

// OutputSRT writes every remaining segment of context to w as SubRip (SRT)
// cues: a 1-based cue number, a "start --> end" timing line, the segment text
// and a blank separator line.
//
// The timing line is produced with an explicit format string rather than
// Fprintln, because Fprintln inserts a space between operands and would
// surround the arrow with two spaces on each side.
//
// It returns nil once NextSegment reports io.EOF; otherwise it returns the
// first error from NextSegment or from writing to w.
func OutputSRT(w io.Writer, context whisper.Context) error {
	for n := 1; ; n++ {
		segment, err := context.NextSegment()
		if err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}
		// One write per cue so a failing writer is reported instead of ignored.
		if _, err := fmt.Fprintf(w, "%d\n%s --> %s\n%s\n\n", n, srtTimestamp(segment.Start), srtTimestamp(segment.End), segment.Text); err != nil {
			return err
		}
	}
}

// Output writes every remaining segment of context to w, one line per segment,
// prefixed with "[start->end]" where both times are truncated to milliseconds.
//
// When colorize is false the segment text follows the prefix. When it is true
// the segment's tokens are written individually through Colorize (called with
// token.P scaled by 24), skipping any token for which context.IsText is false.
//
// It returns nil once NextSegment reports io.EOF, or the first other error
// NextSegment yields.
func Output(w io.Writer, context whisper.Context, colorize bool) error {
	for {
		segment, err := context.NextSegment()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		start := segment.Start.Truncate(time.Millisecond)
		end := segment.End.Truncate(time.Millisecond)
		fmt.Fprintf(w, "[%6s->%6s]", start, end)

		// Plain mode: the whole segment text in one go.
		if !colorize {
			fmt.Fprintln(w, " ", segment.Text)
			continue
		}

		// Colour mode: emit only the text tokens, each preceded by a space.
		for _, token := range segment.Tokens {
			if context.IsText(token) {
				fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
			}
		}
		fmt.Fprint(w, "\n")
	}
}

// srtTimestamp formats t as an SRT timestamp of the form "HH:MM:SS,mmm".
// Hours are not wrapped, so durations of 100 hours or more use extra digits,
// and anything below a millisecond is discarded by integer division.
func srtTimestamp(t time.Duration) string {
	hours := t / time.Hour
	minutes := t % time.Hour / time.Minute
	seconds := t % time.Minute / time.Second
	millis := t % time.Second / time.Millisecond
	return fmt.Sprintf("%02d:%02d:%02d,%03d", hours, minutes, seconds, millis)
}