VIANC/pw_plugin/main.cc

304 lines
8.1 KiB
C++
Raw Normal View History

2025-01-16 23:35:35 +01:00
#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <random>
#include <thread>
#include <vector>

#include "pipewire/pipewire.h"
#include "spa/param/audio/format-utils.h"
#include "spa/param/latency-utils.h"

#include "circ_buffer.h"
// Audio sample rate in samples per second; used to size the capture buffers
// and to express the learning/recording durations in samples.
constexpr int samplerate{48000};
2025-01-16 23:35:35 +01:00
struct local_data {
2025-01-27 09:58:36 +01:00
pw_thread_loop* loop;
pw_filter* filter;
2025-01-16 23:35:35 +01:00
void* mic_port; // Input data from mic
2025-01-27 09:58:36 +01:00
void* aec_port; // Echo cancelled data from mic
void* listen_port; // Input music port
void* hp_port; // Output data to headphone
2025-01-27 09:58:36 +01:00
// Initial learning buffers
circ_buffer<samplerate * 5> record_buffer;
circ_buffer<samplerate * 5> mic_buffer;
std::atomic<bool> isearly{true};
2025-01-27 09:58:36 +01:00
// Contains the hp output needed for computing FIR filter
circ_buffer<samplerate> filter_buf;
std::array<float, 128> impulse_response;
int initial_delay = 0;
std::atomic<bool> isavailable{false};
2025-01-27 09:58:36 +01:00
circ_buffer<samplerate*5> aec_outbuf;
int head_index = 0;
uint32_t call_count = 0;
2025-01-27 09:58:36 +01:00
std::atomic<bool> isdone{false};
float phase = 1.0;
2025-01-16 23:35:35 +01:00
};
/**
 * Process callback invoked by the PipeWire filter for every audio cycle.
 *
 * Behaviour per cycle:
 *  - the music input is copied unchanged to the headphone output and appended
 *    to `filter_buf`;
 *  - while no echo model is available (`isavailable` false) the mic signal is
 *    passed through to the AEC port, and during the learning phase (`isearly`)
 *    music and mic are recorded for main() to analyse;
 *  - once the model is available, the FIR estimate of the echo is subtracted
 *    from every mic sample and the result is written to the AEC port (when it
 *    has a buffer) and to `aec_outbuf`.
 *
 * `head_index` counts processed samples; it ends the learning phase after
 * 10 * samplerate samples and stops all processing after 20 * samplerate.
 *
 * @param ld        shared state registered with the filter.
 * @param position  cycle position; `clock.duration` is used as the number of
 *                  samples to process.
 */
void on_process(local_data* ld, spa_io_position* position) {
    const int32_t n_samples = static_cast<int32_t>(position->clock.duration);

    float* mic_in    = static_cast<float*>(pw_filter_get_dsp_buffer(ld->mic_port, n_samples));
    float* aec_out   = static_cast<float*>(pw_filter_get_dsp_buffer(ld->aec_port, n_samples));
    float* listen_in = static_cast<float*>(pw_filter_get_dsp_buffer(ld->listen_port, n_samples));
    float* hp_out    = static_cast<float*>(pw_filter_get_dsp_buffer(ld->hp_port, n_samples));

    if (ld->isdone) {
        return;
    }
    // aec_out is allowed to be null (it is checked before every use below);
    // the other three buffers are mandatory.
    if (mic_in == nullptr || listen_in == nullptr || hp_out == nullptr) {
        return;
    }

    std::memcpy(hp_out, listen_in, n_samples * sizeof(float));
    ld->filter_buf.enqueue(listen_in, n_samples); // Inefficient, ugly, and bad

    if (!ld->isavailable) {
        // No echo model yet: pass the microphone through untouched.
        if (aec_out != nullptr) {
            std::memcpy(aec_out, mic_in, n_samples * sizeof(float));
        }
        if (ld->isearly) {
            ld->record_buffer.enqueue(listen_in, n_samples);
            ld->mic_buffer.enqueue(mic_in, n_samples);
        }
    } else {
        ld->record_buffer.enqueue(listen_in, n_samples);
        ld->mic_buffer.enqueue(mic_in, n_samples);

        const int taps  = static_cast<int>(ld->impulse_response.size());
        const int delay = ld->initial_delay;

        for (int i = 0; i < n_samples; ++i) {
            // Sample i of this cycle lies (n_samples-1 - i) samples before the
            // newest entry of filter_buf.
            const int age = n_samples - 1 - i;

            double corr = 0.0;
            for (int d = 0; d < taps; ++d) {
                corr += ld->impulse_response[d] * ld->filter_buf[-d - delay - age];
            }

            // Read the mic sample exactly once per iteration. The previous
            // code advanced the input pointer twice when aec_out was set,
            // skipping every other sample and running past the buffer.
            float aec = static_cast<float>(mic_in[i] - corr);
            if (aec_out != nullptr) {
                aec_out[i] = aec;
            }
            ld->aec_outbuf.enqueue(&aec, 1);
        }
    }

    ld->head_index += n_samples;
    if (ld->head_index >= 10 * samplerate) {
        ld->isearly = false;
    }
    if (ld->head_index >= 20 * samplerate) {
        ld->isdone = true;
    }
}
2025-01-16 23:35:35 +01:00
const struct pw_filter_events filter_events {
PW_VERSION_FILTER_EVENTS,
.process = (void(*)(void*, spa_io_position*))on_process,
};
2025-01-16 23:35:35 +01:00
2025-01-27 09:58:36 +01:00
/**
 * Mean of the squared values over the whole backing storage of `buf`.
 *
 * All N slots of `buf.raw` are included, regardless of how many samples have
 * actually been enqueued.
 *
 * @tparam N   capacity of the circular buffer (must be non-zero).
 * @tparam DT  sample type.
 * @param buf  buffer whose raw storage is measured.
 * @return sum(raw[i]^2) / N, accumulated and divided in double precision and
 *         only then converted to DT.
 */
template<size_t N, typename DT=float>
DT compute_power(circ_buffer<N, DT>& buf) {
    static_assert(N > 0, "compute_power needs a non-empty buffer");

    double acc = 0.0;
    for (size_t i = 0; i < N; ++i) {
        const double sample = buf.raw[i];
        acc += sample * sample;
    }
    // Divide before narrowing: `(DT)acc/N` converted the sum to DT first.
    return static_cast<DT>(acc / static_cast<double>(N));
}
2025-01-27 09:58:36 +01:00
/**
 * Average cross-correlation between two circular buffers at a given lag.
 *
 * Computes the mean of bufHP[-i - delta] * bufMIC[-i] for i = 0 .. N-delta-1,
 * i.e. over every pair for which both (non-positive) indices stay within the
 * last N samples.
 *
 * @tparam N      capacity of both buffers.
 * @tparam DT     sample type.
 * @param bufHP   reference signal, read `delta` samples further in the past.
 * @param bufMIC  signal read at the current position.
 * @param delta   lag in samples; valid range is [0, N).
 * @return the averaged product, or 0 when `delta` is outside [0, N) and no
 *         pair overlaps.
 */
template<size_t N, typename DT=float>
DT avg_correlation(circ_buffer<N, DT>& bufHP, circ_buffer<N, DT>& bufMIC, int delta) {
    // Do all bound arithmetic in a signed type. The original condition
    // `-i-delta > -N` compared against an unsigned `-N`, which made the loop
    // body unreachable for delta == 0 (0 is never greater than 2^64 - N).
    const long long capacity = static_cast<long long>(N);
    if (delta < 0 || delta >= capacity) {
        return DT{0};
    }

    const int terms = static_cast<int>(capacity - delta);
    double acc = 0.0;
    for (int i = 0; i < terms; ++i) {
        acc += bufHP[-i - delta] * bufMIC[-i];
    }
    return static_cast<DT>(acc / terms);
}
2025-01-16 23:35:35 +01:00
int main(int argc, char** argv) {
local_data local{nullptr,nullptr,};
2025-01-16 23:35:35 +01:00
pw_init(&argc, &argv);
2025-01-27 09:58:36 +01:00
local.loop = pw_thread_loop_new("thread", NULL);
if(local.loop == NULL) {
std::cerr << "Could not create loop!\n";
return 1;
}
local.filter = pw_filter_new_simple(
2025-01-27 09:58:36 +01:00
pw_thread_loop_get_loop(local.loop),
"audio-filter",
2025-01-16 23:35:35 +01:00
pw_properties_new(
PW_KEY_MEDIA_TYPE, "Audio",
PW_KEY_MEDIA_CATEGORY, "Filter",
PW_KEY_MEDIA_ROLE, "DSP",
2025-01-16 23:35:35 +01:00
NULL),
&filter_events,
&local);
2025-01-16 23:35:35 +01:00
local.mic_port = pw_filter_add_port(local.filter,
PW_DIRECTION_INPUT,
PW_FILTER_PORT_FLAG_MAP_BUFFERS,
sizeof(void*),
pw_properties_new(
PW_KEY_FORMAT_DSP, "32 bit float mono audio",
PW_KEY_PORT_NAME, "micinput",
NULL),
NULL, 0);
2025-01-27 09:58:36 +01:00
local.aec_port = pw_filter_add_port(local.filter,
PW_DIRECTION_OUTPUT,
PW_FILTER_PORT_FLAG_MAP_BUFFERS,
sizeof(void*),
pw_properties_new(
PW_KEY_FORMAT_DSP, "32 bit float mono audio",
PW_KEY_PORT_NAME, "AEC",
PW_KEY_MEDIA_CATEGORY, "Capture",
NULL),
NULL, 0);
local.hp_port = pw_filter_add_port(local.filter,
PW_DIRECTION_OUTPUT,
PW_FILTER_PORT_FLAG_MAP_BUFFERS,
sizeof(void*),
pw_properties_new(
PW_KEY_FORMAT_DSP, "32 bit float mono audio",
PW_KEY_PORT_NAME, "output",
NULL),
NULL, 0);
local.listen_port = pw_filter_add_port(local.filter,
PW_DIRECTION_INPUT,
PW_FILTER_PORT_FLAG_MAP_BUFFERS,
sizeof(void*),
pw_properties_new(
PW_KEY_FORMAT_DSP, "32 bit float mono audio",
PW_KEY_PORT_NAME, "musinput",
NULL),
NULL, 0);
const spa_pod* params[2];
2025-01-16 23:35:35 +01:00
uint8_t buffer[1024];
spa_pod_builder b = SPA_POD_BUILDER_INIT(buffer, sizeof(buffer));
auto latinfo = SPA_PROCESS_LATENCY_INFO_INIT(.ns = 10 * SPA_NSEC_PER_MSEC);
auto formatinfo = SPA_AUDIO_INFO_RAW_INIT(.format = SPA_AUDIO_FORMAT_DSP_F32,
.rate = 48000,
.channels = 1);
2025-01-16 23:35:35 +01:00
params[0] = spa_process_latency_build(&b, SPA_PARAM_ProcessLatency, &latinfo);
params[1] = spa_format_audio_raw_build(&b, SPA_PARAM_EnumFormat, &formatinfo);
2025-01-16 23:35:35 +01:00
if(pw_filter_connect(local.filter,
PW_FILTER_FLAG_RT_PROCESS,
params, 1) < 0) {
std::fprintf(stderr, "Cannot connect\n");
return -1;
}
std::printf("Waiting for connection\n");
2025-01-27 09:58:36 +01:00
pw_thread_loop_start(local.loop);
while(local.isearly) {};
float max_crosscor = -1.0f;
int max_crosscor_index = 0;
std::array<float, samplerate/5> xcors;
for(int delta = 0; delta < samplerate/5; ++delta) {
xcors[delta] = avg_correlation(local.record_buffer, local.mic_buffer, delta);
if(std::abs(xcors[delta]) > max_crosscor) {
max_crosscor = std::abs(xcors[delta]);
max_crosscor_index = delta;
}
}
2025-01-16 23:35:35 +01:00
2025-01-27 09:58:36 +01:00
// If the x-fer coeff was 1, then avg_correlation would return the power input
float input_power = compute_power(local.record_buffer);
std::printf("input power : %e (%f dB)\n", input_power, 10*std::log10(input_power));
max_crosscor_index -= 16; // Allow for non-zero rise time (can you be smarter?)
std::printf("Estimated latency : %d samples\n", max_crosscor_index);
std::printf("h = [");
for(int i = 0; i < 128; ++i) {
std::printf("%e,", xcors[max_crosscor_index+i]);
//local.impulse_response[i] = xcors[max_crosscor_index+i]/input_power;
local.impulse_response[i] = 0.0;
}
for(int i = 128; i < 4096; ++i) {
std::printf("%e,", xcors[max_crosscor_index+i]);
}
std::printf("]\n");
local.initial_delay = max_crosscor_index;
local.isavailable = true;
while(!local.isdone) {}
pw_thread_loop_stop(local.loop);
pw_filter_destroy(local.filter);
2025-01-27 09:58:36 +01:00
pw_thread_loop_destroy(local.loop);
pw_deinit();
2025-01-16 23:35:35 +01:00
2025-01-27 09:58:36 +01:00
std::printf("hMIC = [");
for(int i = 0; i < 4096; ++i) {
float x = avg_correlation(local.record_buffer, local.mic_buffer, i + local.initial_delay);
std::printf("%e,", x);
}
std::printf("]\n");
2025-01-27 09:58:36 +01:00
std::printf("hAEC = [");
for(int i = 0; i < 4096; ++i) {
float x = avg_correlation(local.record_buffer, local.aec_outbuf, i + local.initial_delay);
std::printf("%e,", x);
}
std::printf("]\n");
float interleavedData[samplerate*5][2];
for(int i = 0; i < samplerate*5; ++i) {
interleavedData[i][0] = local.aec_outbuf[-i];
interleavedData[i][1] = local.mic_buffer[-i];
}
float mic_RMS = std::sqrt(compute_power(local.mic_buffer));
float aec_RMS = std::sqrt(compute_power(local.aec_outbuf));
std::printf("RMS power mic : %f dB (with AEC : %f dB)\n",
20*std::log10(mic_RMS),
20*std::log10(aec_RMS));
std::FILE* f = std::fopen("record.wav", "wb");
if(f == nullptr) {
std::fprintf(stderr, "Could not open record file\n");
return 1;
}
std::fwrite("RIFFxxxxWAVEfmt ", 16, 1, f);
uint16_t header[] = {16, 0, // BlocSize
3, // AudioFormat : IEEE 754 float
2, // 2 Channels
(uint16_t)samplerate, 0, // Samplerate
(samplerate * sizeof(float) * 2) % 65536, (samplerate * sizeof(float) * 2)/65536, // BitsPerSec
2 * sizeof(float),
sizeof(float)*8,
};
std::fwrite(header, sizeof(header), 1, f);
std::fwrite("data", 4, 1, f);
std::fwrite(interleavedData, sizeof(interleavedData), 1, f);
uint32_t size = (uint32_t)std::ftell(f) - 8;
std::fseek(f, 4, SEEK_SET);
std::fwrite(&size, sizeof(uint32_t), 1, f);
std::fclose(f);
2025-01-16 23:35:35 +01:00
return 0;
}