feat: Re-introduce locking the ratchet map/device list

This makes the locking much more intelligent, allowing us
to encrypt to/decrypt from groups while still being able to
bypass the lock for unaffiliated JIDs.
This commit is contained in:
PapaTutuWawa 2023-06-15 21:02:53 +02:00
parent 6e734ec0c3
commit da11e60f79
3 changed files with 203 additions and 76 deletions

View File

@ -20,6 +20,7 @@ import 'package:omemo_dart/src/omemo/encryption_result.dart';
import 'package:omemo_dart/src/omemo/errors.dart';
import 'package:omemo_dart/src/omemo/events.dart';
import 'package:omemo_dart/src/omemo/fingerprint.dart';
import 'package:omemo_dart/src/omemo/queue.dart';
import 'package:omemo_dart/src/omemo/ratchet_map_key.dart';
import 'package:omemo_dart/src/omemo/stanza.dart';
import 'package:omemo_dart/src/protobuf/schema.pb.dart';
@ -71,21 +72,20 @@ class OmemoManager {
final Future<void> Function(String jid) subscribeToDeviceListNodeImpl;
/// Map bare JID to its known devices
Map<String, List<int>> _deviceList = {};
final Map<String, List<int>> _deviceList = {};
/// Map bare JIDs to whether we already requested the device list once
final Map<String, bool> _deviceListRequested = {};
/// Map bare a ratchet key to its ratchet. Note that this is also locked by
/// _ratchetCriticalSectionLock.
Map<RatchetMapKey, OmemoDoubleRatchet> _ratchetMap = {};
final Map<RatchetMapKey, OmemoDoubleRatchet> _ratchetMap = {};
/// Map bare JID to whether we already tried to subscribe to the device list node.
final Map<String, bool> _subscriptionMap = {};
/// For preventing a race condition in encryption/decryption
final Map<String, Queue<Completer<void>>> _ratchetCriticalSectionQueue = {};
final Lock _ratchetCriticalSectionLock = Lock();
final RatchetAccessQueue _ratchetQueue = RatchetAccessQueue();
/// The OmemoManager's trust management
final TrustManager _trustManager;
@ -101,39 +101,6 @@ class OmemoManager {
StreamController<OmemoEvent>.broadcast();
Stream<OmemoEvent> get eventStream => _eventStreamController.stream;
/// Enter the critical section for performing cryptographic operations on the ratchets
Future<void> _enterRatchetCriticalSection(String jid) async {
return;
final completer = await _ratchetCriticalSectionLock.synchronized(() {
if (_ratchetCriticalSectionQueue.containsKey(jid)) {
final c = Completer<void>();
_ratchetCriticalSectionQueue[jid]!.addLast(c);
return c;
}
_ratchetCriticalSectionQueue[jid] = Queue();
return null;
});
if (completer != null) {
await completer.future;
}
}
/// Leave the critical section for the ratchets.
Future<void> _leaveRatchetCriticalSection(String jid) async {
return;
await _ratchetCriticalSectionLock.synchronized(() {
if (_ratchetCriticalSectionQueue.containsKey(jid)) {
if (_ratchetCriticalSectionQueue[jid]!.isEmpty) {
_ratchetCriticalSectionQueue.remove(jid);
} else {
_ratchetCriticalSectionQueue[jid]!.removeFirst().complete();
}
}
});
}
Future<Result<OmemoError, String?>> _decryptAndVerifyHmac(
List<int>? ciphertext,
List<int> keyAndHmac,
@ -242,7 +209,7 @@ class OmemoManager {
} else {
// Ratchet is not acknowledged
_log.finest('Sending acknowledgement heartbeat to ${key.jid}');
await ratchetAcknowledged(key.jid, key.deviceId);
await _ratchetAcknowledged(key.jid, key.deviceId);
await sendEmptyOmemoMessageImpl(
await _onOutgoingStanzaImpl(
OmemoOutgoingStanza(
@ -257,13 +224,10 @@ class OmemoManager {
///
Future<DecryptionResult> onIncomingStanza(OmemoIncomingStanza stanza) async {
// NOTE: We do this so that we cannot forget to acquire and free the critical
// section.
await _enterRatchetCriticalSection(stanza.bareSenderJid);
final result = await _onIncomingStanzaImpl(stanza);
await _leaveRatchetCriticalSection(stanza.bareSenderJid);
return result;
return _ratchetQueue.synchronized(
[stanza.bareSenderJid],
() => _onIncomingStanzaImpl(stanza),
);
}
Future<DecryptionResult> _onIncomingStanzaImpl(OmemoIncomingStanza stanza) async {
@ -419,7 +383,7 @@ class OmemoManager {
if (!_deviceList[stanza.bareSenderJid]!.contains(stanza.senderDeviceId)) {
_deviceList[stanza.bareSenderJid]!.add(stanza.senderDeviceId);
}
await sendOmemoHeartbeat(stanza.bareSenderJid);
await _sendOmemoHeartbeat(stanza.bareSenderJid);
return DecryptionResult(
null,
@ -489,13 +453,10 @@ class OmemoManager {
}
Future<EncryptionResult> onOutgoingStanza(OmemoOutgoingStanza stanza) async {
// TODO: Be more smart about the locking
// TODO: Do we even need to lock?
await _enterRatchetCriticalSection(stanza.recipientJids.first);
final result = await _onOutgoingStanzaImpl(stanza);
await _leaveRatchetCriticalSection(stanza.recipientJids.first);
return result;
return _ratchetQueue.synchronized(
stanza.recipientJids,
() => _onOutgoingStanzaImpl(stanza),
);
}
Future<EncryptionResult> _onOutgoingStanzaImpl(OmemoOutgoingStanza stanza) async {
@ -697,9 +658,17 @@ class OmemoManager {
);
}
// Sends an empty OMEMO message (heartbeat) to [jid].
/// Sends an empty OMEMO message (heartbeat) to [jid].
Future<void> sendOmemoHeartbeat(String jid) async {
final result = await onOutgoingStanza(
await _ratchetQueue.synchronized(
[jid],
() => _sendOmemoHeartbeat(jid),
);
}
/// Like [sendOmemoHeartbeat], but does not acquire the lock for [jid].
Future<void> _sendOmemoHeartbeat(String jid) async {
final result = await _onOutgoingStanzaImpl(
OmemoOutgoingStanza(
[jid],
null,
@ -710,8 +679,9 @@ class OmemoManager {
/// Removes all ratchets associated with [jid].
Future<void> removeAllRatchets(String jid) async {
await _enterRatchetCriticalSection(jid);
await _ratchetQueue.synchronized(
[jid],
() async {
for (final device in _deviceList[jid] ?? <int>[]) {
// Remove the ratchet and commit
_ratchetMap.remove(RatchetMapKey(jid, device));
@ -722,8 +692,8 @@ class OmemoManager {
_deviceList.remove(jid);
_deviceListRequested.remove(jid);
_eventStreamController.add(DeviceListModifiedEvent(jid, []));
await _leaveRatchetCriticalSection(jid);
},
);
}
/// To be called when a update to the device list of [jid] is returned.
@ -748,8 +718,14 @@ class OmemoManager {
// Mark the ratchet [jid]:[device] as acknowledged.
Future<void> ratchetAcknowledged(String jid, int device) async {
await _enterRatchetCriticalSection(jid);
await _ratchetQueue.synchronized(
[jid],
() => _ratchetAcknowledged(jid, device),
);
}
/// Like [ratchetAcknowledged], but does not acquire the lock for [jid].
Future<void> _ratchetAcknowledged(String jid, int device) async {
final ratchetKey = RatchetMapKey(jid, device);
if (!_ratchetMap.containsKey(ratchetKey)) {
_log.warning('Cannot mark $jid:$device as acknowledged as the ratchet does not exist');
@ -760,8 +736,6 @@ class OmemoManager {
RatchetModifiedEvent(jid, device, ratchet, false, false),
);
}
await _leaveRatchetCriticalSection(jid);
}
/// If ratchets with [jid] exists, returns a list of fingerprints for each
@ -769,13 +743,13 @@ class OmemoManager {
///
/// If not ratchets exists, returns null.
Future<List<DeviceFingerprint>?> getFingerprintsForJid(String jid) async {
await _getFingerprintsForJidImpl(jid);
final result = await _getFingerprintsForJidImpl(jid);
await _leaveRatchetCriticalSection(jid);
return result;
return _ratchetQueue.synchronized(
[jid],
() => _getFingerprintsForJidImpl(jid),
);
}
/// Same as [getFingerprintsForJid], but without acquiring the lock for [jid].
Future<List<DeviceFingerprint>?> _getFingerprintsForJidImpl(String jid) async {
// Check if we know of the JID.
if (!_deviceList.containsKey(jid)) {

97
lib/src/omemo/queue.dart Normal file
View File

@ -0,0 +1,97 @@
import 'dart:async';
import 'dart:collection';
import 'package:meta/meta.dart';
import 'package:synchronized/synchronized.dart';
extension UtilAllMethodsList<T> on List<T> {
void removeAll(List<T> values) {
for (final value in values) {
remove(value);
}
}
bool containsAll(List<T> values) {
for (final value in values) {
if (!contains(value)) {
return false;
}
}
return true;
}
}
class _RatchetAccessQueueEntry {
_RatchetAccessQueueEntry(
this.jids,
this.completer,
);
final List<String> jids;
final Completer<void> completer;
}
class RatchetAccessQueue {
final Queue<_RatchetAccessQueueEntry> _queue = Queue();
@visibleForTesting
final List<String> runningOperations = List<String>.empty(growable: true);
final Lock lock = Lock();
bool canBypass(List<String> jids) {
for (final jid in jids) {
if (runningOperations.contains(jid)) {
return false;
}
}
return true;
}
Future<void> enterCriticalSection(List<String> jids) async {
final completer = await lock.synchronized<Completer<void>?>(() {
if (canBypass(jids)) {
runningOperations.addAll(jids);
return null;
}
final completer = Completer<void>();
_queue.add(
_RatchetAccessQueueEntry(
jids,
completer,
),
);
return completer;
});
await completer?.future;
}
Future<void> leaveCriticalSection(List<String> jids) async {
await lock.synchronized(() {
runningOperations.removeAll(jids);
while (_queue.isNotEmpty) {
if (canBypass(_queue.first.jids)) {
final head = _queue.removeFirst();
runningOperations.addAll(head.jids);
head.completer.complete();
} else {
break;
}
}
});
}
Future<T> synchronized<T>(List<String> jids, Future<T> Function() function) async {
await enterCriticalSection(jids);
final result = await function();
await leaveCriticalSection(jids);
return result;
}
}

56
test/queue_test.dart Normal file
View File

@ -0,0 +1,56 @@
import 'dart:async';
import 'package:omemo_dart/src/omemo/queue.dart';
import 'package:test/test.dart';
Future<void> testMethod(RatchetAccessQueue queue, List<String> data, int duration) async {
await queue.enterCriticalSection(data);
await Future<void>.delayed(Duration(seconds: duration));
await queue.leaveCriticalSection(data);
}
void main() {
test('Test blocking due to conflicts', () async {
final queue = RatchetAccessQueue();
unawaited(testMethod(queue, ['a', 'b', 'c'], 5));
unawaited(testMethod(queue, ['a'], 4));
await Future<void>.delayed(const Duration(seconds: 1));
expect(
queue.runningOperations.containsAll(['a', 'b', 'c']),
isTrue,
);
expect(queue.runningOperations.length, 3);
await Future<void>.delayed(const Duration(seconds: 4));
expect(
queue.runningOperations.containsAll(['a']),
isTrue,
);
expect(queue.runningOperations.length, 1);
await Future<void>.delayed(const Duration(seconds: 4));
expect(queue.runningOperations.length, 0);
});
test('Test not blocking due to no conflicts', () async {
final queue = RatchetAccessQueue();
unawaited(testMethod(queue, ['a', 'b'], 5));
unawaited(testMethod(queue, ['c'], 5));
unawaited(testMethod(queue, ['d'], 5));
await Future<void>.delayed(const Duration(seconds: 1));
expect(queue.runningOperations.length, 4);
expect(
queue.runningOperations.containsAll([
'a', 'b', 'c', 'd',
]),
isTrue,
);
});
}