[Audio] Fix PCM format and use PipedAudioStream in sources (#16111)

* [Audio] Fix PCM format and use PipedAudioStream
* fix rustpotter format changes

---------

Signed-off-by: Miguel Álvarez <miguelwork92@gmail.com>
Signed-off-by: Ciprian Pascu <contact@ciprianpascu.ro>
This commit is contained in:
GiviMAD 2024-02-04 13:07:54 -08:00 committed by Ciprian Pascu
parent 999a1f9a1c
commit 3d0d115260
6 changed files with 82 additions and 168 deletions

View File

@ -14,13 +14,8 @@ package org.openhab.binding.pulseaudio.internal;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.net.Socket;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
@ -31,6 +26,7 @@ import org.openhab.core.audio.AudioException;
import org.openhab.core.audio.AudioFormat;
import org.openhab.core.audio.AudioSource;
import org.openhab.core.audio.AudioStream;
import org.openhab.core.audio.PipedAudioStream;
import org.openhab.core.common.ThreadPoolManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -45,25 +41,23 @@ import org.slf4j.LoggerFactory;
public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implements AudioSource {
private final Logger logger = LoggerFactory.getLogger(PulseAudioAudioSource.class);
private final ConcurrentLinkedQueue<PipedOutputStream> pipeOutputs = new ConcurrentLinkedQueue<>();
private final PipedAudioStream.Group streamGroup;
private final ScheduledExecutorService executor;
private final AudioFormat streamFormat;
private @Nullable Future<?> pipeWriteTask;
public PulseAudioAudioSource(PulseaudioHandler pulseaudioHandler, ScheduledExecutorService scheduler) {
super(pulseaudioHandler, scheduler);
streamFormat = pulseaudioHandler.getSourceAudioFormat();
executor = ThreadPoolManager
.getScheduledPool("OH-binding-" + pulseaudioHandler.getThing().getUID() + "-source");
streamGroup = PipedAudioStream.newGroup(streamFormat);
}
@Override
public Set<AudioFormat> getSupportedFormats() {
var supportedFormats = new HashSet<AudioFormat>();
var audioFormat = pulseaudioHandler.getSourceAudioFormat();
if (audioFormat != null) {
supportedFormats.add(audioFormat);
}
return supportedFormats;
return Set.of(streamFormat);
}
@Override
@ -76,27 +70,18 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
if (clientSocketLocal == null) {
break;
}
var sourceFormat = pulseaudioHandler.getSourceAudioFormat();
if (sourceFormat == null) {
throw new AudioException("Unable to get source audio format");
}
if (!audioFormat.isCompatible(sourceFormat)) {
if (!audioFormat.isCompatible(streamFormat)) {
throw new AudioException("Incompatible audio format requested");
}
var pipeOutput = new PipedOutputStream();
var pipeInput = new PipedInputStream(pipeOutput, 1024 * 10) {
@Override
public void close() throws IOException {
unregisterPipe(pipeOutput);
super.close();
}
};
registerPipe(pipeOutput);
// get raw audio from the pulse audio socket
return new PulseAudioStream(sourceFormat, pipeInput, () -> {
// ensure pipe is writing
startPipeWrite();
var audioStream = streamGroup.getAudioStreamInGroup();
audioStream.onClose(() -> {
minusClientCount();
stopPipeWriteTask();
});
addClientCount();
startPipeWrite();
// get raw audio from the pulse audio socket
return audioStream;
} catch (IOException e) {
disconnect(); // disconnect to force clear connection in case of socket not cleanly shutdown
if (countAttempt == 2) { // we won't retry : log and quit
@ -120,14 +105,6 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
throw new AudioException("Unable to create input stream");
}
/**
 * Adds the given pipe output to the registered outputs and then calls {@code startPipeWrite()}.
 * The client count is only increased when the queue reports the output as added.
 *
 * @param pipeOutput pipe output to register
 */
private synchronized void registerPipe(PipedOutputStream pipeOutput) {
    if (this.pipeOutputs.add(pipeOutput)) {
        addClientCount();
    }
    startPipeWrite();
}
/**
* As startPipeWrite is called for every chunk read,
* this wrapper method makes the test before effectively
@ -143,35 +120,16 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
if (this.pipeWriteTask == null) {
this.pipeWriteTask = executor.submit(() -> {
int lengthRead;
byte[] buffer = new byte[1024];
byte[] buffer = new byte[1200];
int readRetries = 3;
while (!pipeOutputs.isEmpty()) {
while (!streamGroup.isEmpty()) {
var stream = getSourceInputStream();
if (stream != null) {
try {
lengthRead = stream.read(buffer);
readRetries = 3;
for (var output : pipeOutputs) {
try {
output.write(buffer, 0, lengthRead);
if (pipeOutputs.contains(output)) {
output.flush();
}
} catch (InterruptedIOException e) {
if (pipeOutputs.isEmpty()) {
// task has been ended while writing
return;
}
logger.warn("InterruptedIOException while writing from pulse source to pipe: {}",
getExceptionMessage(e));
} catch (IOException e) {
logger.warn("IOException while writing from pulse source to pipe: {}",
getExceptionMessage(e));
} catch (RuntimeException e) {
logger.warn("RuntimeException while writing from pulse source to pipe: {}",
getExceptionMessage(e));
}
}
streamGroup.write(buffer, 0, lengthRead);
streamGroup.flush();
} catch (IOException e) {
logger.warn("IOException while reading from pulse source: {}", getExceptionMessage(e));
if (readRetries == 0) {
@ -192,25 +150,9 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
}
}
/**
 * Removes the given pipe output from the registered outputs, asks {@code stopPipeWriteTask()} to
 * cancel the write task and finally closes the pipe output.
 * <p>
 * The client count is only decreased when the output was actually registered. If the calling
 * thread is interrupted, its interrupt status is preserved for the caller.
 *
 * @param pipeOutput pipe output to unregister
 */
private synchronized void unregisterPipe(PipedOutputStream pipeOutput) {
    if (this.pipeOutputs.remove(pipeOutput)) {
        minusClientCount();
    }
    try {
        Thread.sleep(0);
    } catch (InterruptedException e) {
        // sleep() clears the interrupt status when it throws; restore it so the caller still sees it
        Thread.currentThread().interrupt();
    }
    stopPipeWriteTask();
    try {
        pipeOutput.close();
    } catch (IOException ignored) {
        // best effort: the output is being discarded, a failing close changes nothing for the caller
    }
}
private synchronized void stopPipeWriteTask() {
var pipeWriteTask = this.pipeWriteTask;
if (pipeOutputs.isEmpty() && pipeWriteTask != null) {
if (streamGroup.isEmpty() && pipeWriteTask != null) {
pipeWriteTask.cancel(true);
this.pipeWriteTask = null;
}
@ -243,58 +185,4 @@ public class PulseAudioAudioSource extends PulseaudioSimpleProtocolStream implem
stopPipeWriteTask();
super.disconnect();
}
/**
 * {@link AudioStream} that serves the bytes of a wrapped {@link InputStream} and reports a fixed
 * {@link AudioFormat}. An activity callback is run before every read that reaches the wrapped
 * stream.
 */
static class PulseAudioStream extends AudioStream {
    private final Logger logger = LoggerFactory.getLogger(PulseAudioAudioSource.class);
    private final AudioFormat format;
    private final InputStream input;
    private final Runnable activity;
    private boolean closed = false;

    /**
     * @param format format reported by {@link #getFormat()}
     * @param input stream the audio bytes are read from
     * @param activity callback run before each read from {@code input}
     */
    public PulseAudioStream(AudioFormat format, InputStream input, Runnable activity) {
        this.input = input;
        this.format = format;
        this.activity = activity;
    }

    @Override
    public AudioFormat getFormat() {
        return format;
    }

    /**
     * Reads a single byte.
     *
     * @return the byte as an unsigned value in the range 0..255, or -1 when the wrapped stream
     *         reports its end
     * @throws IOException if the stream is closed or the wrapped stream fails
     */
    @Override
    public int read() throws IOException {
        byte[] b = new byte[1];
        int bytesRead = read(b);
        if (-1 == bytesRead) {
            return bytesRead;
        }
        // mask instead of widening the signed byte: a plain conversion would turn 0x80..0xFF into
        // negative values, and 0xFF into -1, which callers interpret as end of stream
        return b[0] & 0xFF;
    }

    @Override
    public int read(byte @Nullable [] b) throws IOException {
        // a null buffer is forwarded with length 0 and rejected by the three argument variant
        return read(b, 0, b == null ? 0 : b.length);
    }

    /**
     * Reads up to {@code len} bytes into {@code b}, running the activity callback first.
     *
     * @throws IOException if {@code b} is null, the stream is closed or the wrapped stream fails
     */
    @Override
    public int read(byte @Nullable [] b, int off, int len) throws IOException {
        if (b == null) {
            throw new IOException("Buffer is null");
        }
        logger.trace("reading from pulseaudio stream");
        if (closed) {
            throw new IOException("Stream is closed");
        }
        activity.run();
        return input.read(b, off, len);
    }

    /**
     * Marks this stream as closed, so later reads fail, and closes the wrapped stream.
     */
    @Override
    public void close() throws IOException {
        closed = true;
        input.close();
    }
}
}

View File

@ -469,39 +469,50 @@ public class PulseaudioHandler extends BaseThingHandler {
.orElse(simpleTcpPort);
}
public @Nullable AudioFormat getSourceAudioFormat() {
public AudioFormat getSourceAudioFormat() {
String simpleFormat = ((String) getThing().getConfiguration().get(DEVICE_PARAMETER_AUDIO_SOURCE_FORMAT));
BigDecimal simpleRate = ((BigDecimal) getThing().getConfiguration().get(DEVICE_PARAMETER_AUDIO_SOURCE_RATE));
BigDecimal simpleChannels = ((BigDecimal) getThing().getConfiguration()
.get(DEVICE_PARAMETER_AUDIO_SOURCE_CHANNELS));
AudioFormat fallback = new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16,
16 * 16000, 16000L, 1);
if (simpleFormat == null || simpleRate == null || simpleChannels == null) {
return null;
return fallback;
}
int sampleRateAllChannels = simpleRate.intValue() * simpleChannels.intValue();
switch (simpleFormat) {
case "u8":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_UNSIGNED, null, 8, 1,
simpleRate.longValue(), simpleChannels.intValue());
case "s16le":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, 1,
simpleRate.longValue(), simpleChannels.intValue());
case "s16be":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, true, 16, 1,
simpleRate.longValue(), simpleChannels.intValue());
case "s24le":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 24, 1,
simpleRate.longValue(), simpleChannels.intValue());
case "s24be":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, true, 24, 1,
simpleRate.longValue(), simpleChannels.intValue());
case "s32le":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 32, 1,
simpleRate.longValue(), simpleChannels.intValue());
case "s32be":
return new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, true, 32, 1,
simpleRate.longValue(), simpleChannels.intValue());
default:
case "u8" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_UNSIGNED, null, 8,
8 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
case "s16le" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16,
16 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
case "s16be" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, true, 16,
16 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
case "s24le" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 24,
24 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
case "s24be" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, true, 24,
24 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
case "s32le" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 32,
32 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
case "s32be" -> {
return new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, true, 32,
32 * sampleRateAllChannels, simpleRate.longValue(), simpleChannels.intValue());
}
default -> {
logger.warn("unsupported format {}", simpleFormat);
return null;
return fallback;
}
}
}

View File

@ -29,6 +29,7 @@ import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import org.openhab.core.audio.AudioFormat;
import org.openhab.core.audio.AudioStream;
import org.openhab.core.audio.utils.AudioWaveUtils;
import org.openhab.core.auth.client.oauth2.AccessTokenResponse;
import org.openhab.core.auth.client.oauth2.OAuthClientService;
import org.openhab.core.auth.client.oauth2.OAuthException;
@ -144,12 +145,8 @@ public class GoogleSTTService implements STTService {
@Override
public Set<AudioFormat> getSupportedFormats() {
return Set.of(
new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 8000L),
new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 12000L),
new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 16000L),
new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 24000L),
new AudioFormat(AudioFormat.CONTAINER_OGG, "OPUS", null, null, null, 48000L));
new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L));
}
@Override
@ -248,8 +245,6 @@ public class GoogleSTTService implements STTService {
RecognitionConfig.AudioEncoding streamEncoding;
if (AudioFormat.WAV.isCompatible(streamFormat)) {
streamEncoding = RecognitionConfig.AudioEncoding.LINEAR16;
} else if (AudioFormat.OGG.isCompatible(streamFormat)) {
streamEncoding = RecognitionConfig.AudioEncoding.OGG_OPUS;
} else {
logger.debug("Unsupported format {}", streamFormat);
return;
@ -271,6 +266,9 @@ public class GoogleSTTService implements STTService {
final int bufferSize = 6400;
int numBytesRead;
int remaining = bufferSize;
if (AudioFormat.CONTAINER_WAVE.equals(streamFormat.getContainer())) {
AudioWaveUtils.removeFMT(audioStream);
}
byte[] audioBuffer = new byte[bufferSize];
while (!aborted.get() && !responseObserver.isDone()) {
numBytesRead = audioStream.read(audioBuffer, bufferSize - remaining, remaining);

View File

@ -109,6 +109,9 @@ public class RustpotterKSService implements KSService {
@Override
public Set<AudioFormat> getSupportedFormats() {
return Set.of(
new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, null, 16, null, null),
new AudioFormat(AudioFormat.CONTAINER_NONE, AudioFormat.CODEC_PCM_SIGNED, null, 32, null, null),
new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, false, 16, null, 16000L),
new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 16, null, null),
new AudioFormat(AudioFormat.CONTAINER_WAVE, AudioFormat.CODEC_PCM_SIGNED, null, 32, null, null));

View File

@ -30,6 +30,7 @@ import org.eclipse.jdt.annotation.Nullable;
import org.openhab.core.OpenHAB;
import org.openhab.core.audio.AudioFormat;
import org.openhab.core.audio.AudioStream;
import org.openhab.core.audio.utils.AudioWaveUtils;
import org.openhab.core.common.ThreadPoolManager;
import org.openhab.core.config.core.ConfigurableService;
import org.openhab.core.config.core.Configuration;
@ -159,6 +160,7 @@ public class VoskSTTService implements STTService {
/**
 * Returns the two accepted input formats: signed PCM, not big-endian, at 16000 Hz, once without a
 * container and once in a WAVE container. Bit depth and bit rate are left unspecified.
 */
@Override
public Set<AudioFormat> getSupportedFormats() {
    final AudioFormat pcmWithoutContainer = new AudioFormat(AudioFormat.CONTAINER_NONE,
            AudioFormat.CODEC_PCM_SIGNED, false, null, null, 16000L);
    final AudioFormat pcmInWave = new AudioFormat(AudioFormat.CONTAINER_WAVE,
            AudioFormat.CODEC_PCM_SIGNED, false, null, null, 16000L);
    return Set.of(pcmWithoutContainer, pcmInWave);
}
@ -167,10 +169,14 @@ public class VoskSTTService implements STTService {
throws STTException {
AtomicBoolean aborted = new AtomicBoolean(false);
try {
var frequency = audioStream.getFormat().getFrequency();
AudioFormat format = audioStream.getFormat();
var frequency = format.getFrequency();
if (frequency == null) {
throw new IOException("missing audio stream frequency");
}
if (AudioFormat.CONTAINER_WAVE.equals(format.getContainer())) {
AudioWaveUtils.removeFMT(audioStream);
}
backgroundRecognize(sttListener, audioStream, frequency, aborted);
} catch (IOException e) {
throw new STTException(e);

View File

@ -14,6 +14,7 @@ package org.openhab.voice.watsonstt.internal;
import static org.openhab.voice.watsonstt.internal.WatsonSTTConstants.*;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@ -28,6 +29,7 @@ import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import org.openhab.core.audio.AudioFormat;
import org.openhab.core.audio.AudioStream;
import org.openhab.core.audio.utils.AudioWaveUtils;
import org.openhab.core.common.ThreadPoolManager;
import org.openhab.core.config.core.ConfigurableService;
import org.openhab.core.config.core.Configuration;
@ -122,8 +124,7 @@ public class WatsonSTTService implements STTService {
@Override
public Set<AudioFormat> getSupportedFormats() {
return Set.of(AudioFormat.WAV, AudioFormat.OGG, new AudioFormat("OGG", "OPUS", null, null, null, null),
AudioFormat.MP3);
return Set.of(AudioFormat.PCM_SIGNED, AudioFormat.WAV);
}
@Override
@ -147,6 +148,13 @@ public class WatsonSTTService implements STTService {
final AtomicReference<@Nullable WebSocket> socketRef = new AtomicReference<>();
final AtomicBoolean aborted = new AtomicBoolean(false);
executor.submit(() -> {
if (AudioFormat.CONTAINER_WAVE.equals(audioStream.getFormat().getContainer())) {
try {
AudioWaveUtils.removeFMT(audioStream);
} catch (IOException e) {
logger.warn("Error removing format header: {}", e.getMessage());
}
}
socketRef.set(stt.recognizeUsingWebSocket(wsOptions,
new TranscriptionListener(socketRef, sttListener, config, aborted)));
});