whisper_flutter_new/macos/Classes/whisper_flutter_new.cpp

284 lines
7.9 KiB
C++
Raw Permalink Normal View History

#include "whisper_flutter_new.h"
#include "whisper.cpp/whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "whisper.cpp/examples/dr_wav.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdio.h>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "json/json.hpp"
using json = nlohmann::json;
/// Serializes `jsonData` (compact form, no indentation) into a newly allocated,
/// NUL-terminated C string.
///
/// The buffer is allocated with `new[]` and ownership passes to the caller.
/// NOTE(review): presumably released with `delete[]` by an exported free
/// function on the Dart side — confirm.
///
/// String values that are not valid UTF-8 (e.g. transcription text that may be
/// cut in the middle of a multi-byte character) are replaced with U+FFFD rather
/// than making `dump()` throw: this function is `noexcept`, so a throw here
/// would call std::terminate.
char *jsonToChar(json jsonData) noexcept
{
    const std::string result = jsonData.dump(-1, ' ', false, json::error_handler_t::replace);
    char *ch = new char[result.size() + 1];
    // Copy the payload plus its terminating NUL in one go; the length is already known.
    std::memcpy(ch, result.c_str(), result.size() + 1);
    return ch;
}
/// Options for one transcription run. Most fields mirror the whisper.cpp
/// example CLI; transcribe() only fills a subset from the request JSON.
struct whisper_params
{
    int32_t seed = -1; // RNG seed, not used currently
    // std::thread::hardware_concurrency() may return 0 when the core count is
    // unknown; clamp the default to the range [1, 4] so it is never zero.
    int32_t n_threads = std::max<int32_t>(
        1, std::min<int32_t>(4, static_cast<int32_t>(std::thread::hardware_concurrency())));
    int32_t n_processors = 1;
    int32_t offset_t_ms = 0;
    int32_t offset_n = 0;
    int32_t duration_ms = 0;
    int32_t max_context = -1;
    int32_t max_len = 0;
    int32_t best_of = 5;
    int32_t beam_size = -1;
    float word_thold = 0.01f;
    float entropy_thold = 2.40f;
    float logprob_thold = -1.00f;
    bool verbose = false;
    bool print_special_tokens = false;
    bool speed_up = false;
    bool translate = false;
    bool diarize = false;
    bool no_fallback = false;
    bool output_txt = false;
    bool output_vtt = false;
    bool output_srt = false;
    bool output_wts = false;
    bool output_csv = false;
    bool print_special = false;
    bool print_colors = false;
    bool print_progress = false;
    bool no_timestamps = false;
    bool split_on_word = false;
    std::string language = "auto";
    std::string prompt;
    std::string model = "models/ggml-tiny.bin";
    std::string audio = "samples/jfk.wav";
    std::vector<std::string> fname_inp = {};
    std::vector<std::string> fname_outp = {};
};
// Pair of read-only pointers grouped into one aggregate.
// NOTE(review): nothing in this file references this type; it looks like a
// leftover from the whisper.cpp example program (where it is passed as callback
// user data) — confirm before relying on or removing it.
struct whisper_print_user_data {
    const whisper_params *params;                   // options of the run (non-const pointee not allowed)
    const std::vector<std::vector<float>> *pcmf32s; // presumably per-channel float samples — verify
};
json transcribe(json jsonBody) noexcept
{
whisper_params params;
params.n_threads = jsonBody["threads"];
params.verbose = jsonBody["is_verbose"];
params.translate = jsonBody["is_translate"];
params.language = jsonBody["language"];
params.print_special_tokens = jsonBody["is_special_tokens"];
params.no_timestamps = jsonBody["is_no_timestamps"];
params.model = jsonBody["model"];
params.audio = jsonBody["audio"];
params.split_on_word = jsonBody["split_on_word"];
json jsonResult;
jsonResult["@type"] = "transcribe";
if (params.language != "" && params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1)
{
jsonResult["@type"] = "error";
jsonResult["message"] = "error: unknown language = " + params.language;
return jsonResult;
}
if (params.seed < 0)
{
params.seed = time(NULL);
}
// whisper init
struct whisper_context *ctx = whisper_init_from_file(params.model.c_str());
std::string text_result = "";
const auto fname_inp = params.audio;
// WAV input
std::vector<float> pcmf32;
{
drwav wav;
if (!drwav_init_file(&wav, fname_inp.c_str(), NULL))
{
jsonResult["@type"] = "error";
jsonResult["message"] = " failed to open WAV file ";
return jsonResult;
}
if (wav.channels != 1 && wav.channels != 2)
{
jsonResult["@type"] = "error";
jsonResult["message"] = "must be mono or stereo";
return jsonResult;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE)
{
jsonResult["@type"] = "error";
jsonResult["message"] = "WAV file must be 16 kHz";
return jsonResult;
}
if (wav.bitsPerSample != 16)
{
jsonResult["@type"] = "error";
jsonResult["message"] = "WAV file must be 16 bit";
return jsonResult;
}
int n = wav.totalPCMFrameCount;
std::vector<int16_t> pcm16;
pcm16.resize(n * wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1)
{
for (int i = 0; i < n; i++)
{
pcmf32[i] = float(pcm16[i]) / 32768.0f;
}
}
else
{
for (int i = 0; i < n; i++)
{
pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f;
}
}
}
{
if (params.language == "" && params.language == "auto")
{
params.language = "auto";
params.translate = false;
}
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps;
// wparams.print_special_tokens = params.print_special_tokens;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.split_on_word = params.split_on_word;
if (params.split_on_word) {
wparams.max_len = 1;
wparams.token_timestamps = true;
}
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0)
{
jsonResult["@type"] = "error";
jsonResult["message"] = "failed to process audio";
return jsonResult;
}
// print result;
if (!wparams.print_realtime)
{
const int n_segments = whisper_full_n_segments(ctx);
std::vector<json> segmentsJson = {};
for (int i = 0; i < n_segments; ++i)
{
const char *text = whisper_full_get_segment_text(ctx, i);
std::string str(text);
text_result += str;
if (params.no_timestamps)
{
// printf("%s", text);
// fflush(stdout);
} else {
json jsonSegment;
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
// printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
jsonSegment["from_ts"] = t0;
jsonSegment["to_ts"] = t1;
jsonSegment["text"] = text;
segmentsJson.push_back(jsonSegment);
}
}
if (!params.no_timestamps) {
jsonResult["segments"] = segmentsJson;
}
}
}
jsonResult["text"] = text_result;
whisper_free(ctx);
return jsonResult;
}
extern "C"
{
    /// FFI entry point. Parses `body` as a JSON request and dispatches on "@type":
    ///   - "getTextFromWavFile": runs transcribe() and returns its reply;
    ///   - "getVersion": returns the library version string;
    ///   - anything else, including a missing "@type": a "method not found" error.
    ///
    /// The reply is the JSON string produced by jsonToChar() (allocated with new[]).
    /// NOTE(review): presumably freed by the Dart caller — confirm where.
    ///
    /// No C++ exception may escape: unwinding out of an extern "C" function into
    /// the FFI caller is not supported. Every failure is therefore reported as an
    /// {"@type":"error"} reply.
    FUNCTION_ATTRIBUTE
    char *request(char *body)
    {
        try
        {
            json jsonResult;
            if (body == nullptr)
            {
                jsonResult["@type"] = "error";
                jsonResult["message"] = "request body is null";
                return jsonToChar(jsonResult);
            }

            json jsonBody = json::parse(body);
            // find() does not insert a null "@type" (unlike operator[]) and
            // yields end() when the body is not a JSON object.
            const auto type = jsonBody.find("@type");
            const bool hasType = type != jsonBody.end();

            if (hasType && *type == "getTextFromWavFile")
            {
                // transcribe() converts its own failures into an error reply.
                return jsonToChar(transcribe(jsonBody));
            }
            if (hasType && *type == "getVersion")
            {
                jsonResult["@type"] = "version";
                jsonResult["message"] = "lib version: v1.0.1";
                return jsonToChar(jsonResult);
            }
            jsonResult["@type"] = "error";
            jsonResult["message"] = "method not found";
            return jsonToChar(jsonResult);
        }
        catch (const std::exception &e)
        {
            json jsonResult;
            jsonResult["@type"] = "error";
            jsonResult["message"] = e.what();
            return jsonToChar(jsonResult);
        }
        catch (...)
        {
            json jsonResult;
            jsonResult["@type"] = "error";
            jsonResult["message"] = "unknown error";
            return jsonToChar(jsonResult);
        }
    }
}