feat: Re-introduce locking the ratchet map/device list
This makes the locking much more intelligent, allowing us to encrypt to/decrypt from groups while still being able to bypass the lock for unaffiliated JIDs.
This commit is contained in:
parent
6e734ec0c3
commit
da11e60f79
@ -20,6 +20,7 @@ import 'package:omemo_dart/src/omemo/encryption_result.dart';
|
||||
import 'package:omemo_dart/src/omemo/errors.dart';
|
||||
import 'package:omemo_dart/src/omemo/events.dart';
|
||||
import 'package:omemo_dart/src/omemo/fingerprint.dart';
|
||||
import 'package:omemo_dart/src/omemo/queue.dart';
|
||||
import 'package:omemo_dart/src/omemo/ratchet_map_key.dart';
|
||||
import 'package:omemo_dart/src/omemo/stanza.dart';
|
||||
import 'package:omemo_dart/src/protobuf/schema.pb.dart';
|
||||
@ -71,21 +72,20 @@ class OmemoManager {
|
||||
final Future<void> Function(String jid) subscribeToDeviceListNodeImpl;
|
||||
|
||||
/// Map bare JID to its known devices
|
||||
Map<String, List<int>> _deviceList = {};
|
||||
final Map<String, List<int>> _deviceList = {};
|
||||
|
||||
/// Map bare JIDs to whether we already requested the device list once
|
||||
final Map<String, bool> _deviceListRequested = {};
|
||||
|
||||
/// Map bare a ratchet key to its ratchet. Note that this is also locked by
|
||||
/// _ratchetCriticalSectionLock.
|
||||
Map<RatchetMapKey, OmemoDoubleRatchet> _ratchetMap = {};
|
||||
final Map<RatchetMapKey, OmemoDoubleRatchet> _ratchetMap = {};
|
||||
|
||||
/// Map bare JID to whether we already tried to subscribe to the device list node.
|
||||
final Map<String, bool> _subscriptionMap = {};
|
||||
|
||||
/// For preventing a race condition in encryption/decryption
|
||||
final Map<String, Queue<Completer<void>>> _ratchetCriticalSectionQueue = {};
|
||||
final Lock _ratchetCriticalSectionLock = Lock();
|
||||
final RatchetAccessQueue _ratchetQueue = RatchetAccessQueue();
|
||||
|
||||
/// The OmemoManager's trust management
|
||||
final TrustManager _trustManager;
|
||||
@ -101,39 +101,6 @@ class OmemoManager {
|
||||
StreamController<OmemoEvent>.broadcast();
|
||||
Stream<OmemoEvent> get eventStream => _eventStreamController.stream;
|
||||
|
||||
/// Enter the critical section for performing cryptographic operations on the ratchets
|
||||
Future<void> _enterRatchetCriticalSection(String jid) async {
|
||||
return;
|
||||
final completer = await _ratchetCriticalSectionLock.synchronized(() {
|
||||
if (_ratchetCriticalSectionQueue.containsKey(jid)) {
|
||||
final c = Completer<void>();
|
||||
_ratchetCriticalSectionQueue[jid]!.addLast(c);
|
||||
return c;
|
||||
}
|
||||
|
||||
_ratchetCriticalSectionQueue[jid] = Queue();
|
||||
return null;
|
||||
});
|
||||
|
||||
if (completer != null) {
|
||||
await completer.future;
|
||||
}
|
||||
}
|
||||
|
||||
/// Leave the critical section for the ratchets.
|
||||
Future<void> _leaveRatchetCriticalSection(String jid) async {
|
||||
return;
|
||||
await _ratchetCriticalSectionLock.synchronized(() {
|
||||
if (_ratchetCriticalSectionQueue.containsKey(jid)) {
|
||||
if (_ratchetCriticalSectionQueue[jid]!.isEmpty) {
|
||||
_ratchetCriticalSectionQueue.remove(jid);
|
||||
} else {
|
||||
_ratchetCriticalSectionQueue[jid]!.removeFirst().complete();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Future<Result<OmemoError, String?>> _decryptAndVerifyHmac(
|
||||
List<int>? ciphertext,
|
||||
List<int> keyAndHmac,
|
||||
@ -242,7 +209,7 @@ class OmemoManager {
|
||||
} else {
|
||||
// Ratchet is not acknowledged
|
||||
_log.finest('Sending acknowledgement heartbeat to ${key.jid}');
|
||||
await ratchetAcknowledged(key.jid, key.deviceId);
|
||||
await _ratchetAcknowledged(key.jid, key.deviceId);
|
||||
await sendEmptyOmemoMessageImpl(
|
||||
await _onOutgoingStanzaImpl(
|
||||
OmemoOutgoingStanza(
|
||||
@ -257,13 +224,10 @@ class OmemoManager {
|
||||
|
||||
///
|
||||
Future<DecryptionResult> onIncomingStanza(OmemoIncomingStanza stanza) async {
|
||||
// NOTE: We do this so that we cannot forget to acquire and free the critical
|
||||
// section.
|
||||
await _enterRatchetCriticalSection(stanza.bareSenderJid);
|
||||
final result = await _onIncomingStanzaImpl(stanza);
|
||||
await _leaveRatchetCriticalSection(stanza.bareSenderJid);
|
||||
|
||||
return result;
|
||||
return _ratchetQueue.synchronized(
|
||||
[stanza.bareSenderJid],
|
||||
() => _onIncomingStanzaImpl(stanza),
|
||||
);
|
||||
}
|
||||
|
||||
Future<DecryptionResult> _onIncomingStanzaImpl(OmemoIncomingStanza stanza) async {
|
||||
@ -419,7 +383,7 @@ class OmemoManager {
|
||||
if (!_deviceList[stanza.bareSenderJid]!.contains(stanza.senderDeviceId)) {
|
||||
_deviceList[stanza.bareSenderJid]!.add(stanza.senderDeviceId);
|
||||
}
|
||||
await sendOmemoHeartbeat(stanza.bareSenderJid);
|
||||
await _sendOmemoHeartbeat(stanza.bareSenderJid);
|
||||
|
||||
return DecryptionResult(
|
||||
null,
|
||||
@ -489,13 +453,10 @@ class OmemoManager {
|
||||
}
|
||||
|
||||
Future<EncryptionResult> onOutgoingStanza(OmemoOutgoingStanza stanza) async {
|
||||
// TODO: Be more smart about the locking
|
||||
// TODO: Do we even need to lock?
|
||||
await _enterRatchetCriticalSection(stanza.recipientJids.first);
|
||||
final result = await _onOutgoingStanzaImpl(stanza);
|
||||
await _leaveRatchetCriticalSection(stanza.recipientJids.first);
|
||||
|
||||
return result;
|
||||
return _ratchetQueue.synchronized(
|
||||
stanza.recipientJids,
|
||||
() => _onOutgoingStanzaImpl(stanza),
|
||||
);
|
||||
}
|
||||
|
||||
Future<EncryptionResult> _onOutgoingStanzaImpl(OmemoOutgoingStanza stanza) async {
|
||||
@ -697,9 +658,17 @@ class OmemoManager {
|
||||
);
|
||||
}
|
||||
|
||||
// Sends an empty OMEMO message (heartbeat) to [jid].
|
||||
/// Sends an empty OMEMO message (heartbeat) to [jid].
|
||||
Future<void> sendOmemoHeartbeat(String jid) async {
|
||||
final result = await onOutgoingStanza(
|
||||
await _ratchetQueue.synchronized(
|
||||
[jid],
|
||||
() => _sendOmemoHeartbeat(jid),
|
||||
);
|
||||
}
|
||||
|
||||
/// Like [sendOmemoHeartbeat], but does not acquire the lock for [jid].
|
||||
Future<void> _sendOmemoHeartbeat(String jid) async {
|
||||
final result = await _onOutgoingStanzaImpl(
|
||||
OmemoOutgoingStanza(
|
||||
[jid],
|
||||
null,
|
||||
@ -710,20 +679,21 @@ class OmemoManager {
|
||||
|
||||
/// Removes all ratchets associated with [jid].
|
||||
Future<void> removeAllRatchets(String jid) async {
|
||||
await _enterRatchetCriticalSection(jid);
|
||||
await _ratchetQueue.synchronized(
|
||||
[jid],
|
||||
() async {
|
||||
for (final device in _deviceList[jid] ?? <int>[]) {
|
||||
// Remove the ratchet and commit
|
||||
_ratchetMap.remove(RatchetMapKey(jid, device));
|
||||
_eventStreamController.add(RatchetRemovedEvent(jid, device));
|
||||
}
|
||||
|
||||
for (final device in _deviceList[jid] ?? <int>[]) {
|
||||
// Remove the ratchet and commit
|
||||
_ratchetMap.remove(RatchetMapKey(jid, device));
|
||||
_eventStreamController.add(RatchetRemovedEvent(jid, device));
|
||||
}
|
||||
|
||||
// Clear the device list
|
||||
_deviceList.remove(jid);
|
||||
_deviceListRequested.remove(jid);
|
||||
_eventStreamController.add(DeviceListModifiedEvent(jid, []));
|
||||
|
||||
await _leaveRatchetCriticalSection(jid);
|
||||
// Clear the device list
|
||||
_deviceList.remove(jid);
|
||||
_deviceListRequested.remove(jid);
|
||||
_eventStreamController.add(DeviceListModifiedEvent(jid, []));
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// To be called when a update to the device list of [jid] is returned.
|
||||
@ -748,8 +718,14 @@ class OmemoManager {
|
||||
|
||||
// Mark the ratchet [jid]:[device] as acknowledged.
|
||||
Future<void> ratchetAcknowledged(String jid, int device) async {
|
||||
await _enterRatchetCriticalSection(jid);
|
||||
await _ratchetQueue.synchronized(
|
||||
[jid],
|
||||
() => _ratchetAcknowledged(jid, device),
|
||||
);
|
||||
}
|
||||
|
||||
/// Like [ratchetAcknowledged], but does not acquire the lock for [jid].
|
||||
Future<void> _ratchetAcknowledged(String jid, int device) async {
|
||||
final ratchetKey = RatchetMapKey(jid, device);
|
||||
if (!_ratchetMap.containsKey(ratchetKey)) {
|
||||
_log.warning('Cannot mark $jid:$device as acknowledged as the ratchet does not exist');
|
||||
@ -760,8 +736,6 @@ class OmemoManager {
|
||||
RatchetModifiedEvent(jid, device, ratchet, false, false),
|
||||
);
|
||||
}
|
||||
|
||||
await _leaveRatchetCriticalSection(jid);
|
||||
}
|
||||
|
||||
/// If ratchets with [jid] exists, returns a list of fingerprints for each
|
||||
@ -769,13 +743,13 @@ class OmemoManager {
|
||||
///
|
||||
/// If not ratchets exists, returns null.
|
||||
Future<List<DeviceFingerprint>?> getFingerprintsForJid(String jid) async {
|
||||
await _getFingerprintsForJidImpl(jid);
|
||||
final result = await _getFingerprintsForJidImpl(jid);
|
||||
await _leaveRatchetCriticalSection(jid);
|
||||
|
||||
return result;
|
||||
return _ratchetQueue.synchronized(
|
||||
[jid],
|
||||
() => _getFingerprintsForJidImpl(jid),
|
||||
);
|
||||
}
|
||||
|
||||
/// Same as [getFingerprintsForJid], but without acquiring the lock for [jid].
|
||||
Future<List<DeviceFingerprint>?> _getFingerprintsForJidImpl(String jid) async {
|
||||
// Check if we know of the JID.
|
||||
if (!_deviceList.containsKey(jid)) {
|
||||
|
97
lib/src/omemo/queue.dart
Normal file
97
lib/src/omemo/queue.dart
Normal file
@ -0,0 +1,97 @@
|
||||
import 'dart:async';
|
||||
import 'dart:collection';
|
||||
|
||||
import 'package:meta/meta.dart';
|
||||
import 'package:synchronized/synchronized.dart';
|
||||
|
||||
extension UtilAllMethodsList<T> on List<T> {
|
||||
void removeAll(List<T> values) {
|
||||
for (final value in values) {
|
||||
remove(value);
|
||||
}
|
||||
}
|
||||
|
||||
bool containsAll(List<T> values) {
|
||||
for (final value in values) {
|
||||
if (!contains(value)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
class _RatchetAccessQueueEntry {
|
||||
_RatchetAccessQueueEntry(
|
||||
this.jids,
|
||||
this.completer,
|
||||
);
|
||||
|
||||
final List<String> jids;
|
||||
final Completer<void> completer;
|
||||
}
|
||||
|
||||
class RatchetAccessQueue {
|
||||
final Queue<_RatchetAccessQueueEntry> _queue = Queue();
|
||||
|
||||
@visibleForTesting
|
||||
final List<String> runningOperations = List<String>.empty(growable: true);
|
||||
|
||||
final Lock lock = Lock();
|
||||
|
||||
bool canBypass(List<String> jids) {
|
||||
for (final jid in jids) {
|
||||
if (runningOperations.contains(jid)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Future<void> enterCriticalSection(List<String> jids) async {
|
||||
final completer = await lock.synchronized<Completer<void>?>(() {
|
||||
if (canBypass(jids)) {
|
||||
runningOperations.addAll(jids);
|
||||
return null;
|
||||
}
|
||||
|
||||
final completer = Completer<void>();
|
||||
_queue.add(
|
||||
_RatchetAccessQueueEntry(
|
||||
jids,
|
||||
completer,
|
||||
),
|
||||
);
|
||||
|
||||
return completer;
|
||||
});
|
||||
|
||||
await completer?.future;
|
||||
}
|
||||
|
||||
Future<void> leaveCriticalSection(List<String> jids) async {
|
||||
await lock.synchronized(() {
|
||||
runningOperations.removeAll(jids);
|
||||
|
||||
while (_queue.isNotEmpty) {
|
||||
if (canBypass(_queue.first.jids)) {
|
||||
final head = _queue.removeFirst();
|
||||
runningOperations.addAll(head.jids);
|
||||
head.completer.complete();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Future<T> synchronized<T>(List<String> jids, Future<T> Function() function) async {
|
||||
await enterCriticalSection(jids);
|
||||
final result = await function();
|
||||
await leaveCriticalSection(jids);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
56
test/queue_test.dart
Normal file
56
test/queue_test.dart
Normal file
@ -0,0 +1,56 @@
|
||||
import 'dart:async';
|
||||
|
||||
import 'package:omemo_dart/src/omemo/queue.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
Future<void> testMethod(RatchetAccessQueue queue, List<String> data, int duration) async {
|
||||
await queue.enterCriticalSection(data);
|
||||
|
||||
await Future<void>.delayed(Duration(seconds: duration));
|
||||
|
||||
await queue.leaveCriticalSection(data);
|
||||
}
|
||||
|
||||
void main() {
|
||||
test('Test blocking due to conflicts', () async {
|
||||
final queue = RatchetAccessQueue();
|
||||
|
||||
unawaited(testMethod(queue, ['a', 'b', 'c'], 5));
|
||||
unawaited(testMethod(queue, ['a'], 4));
|
||||
|
||||
await Future<void>.delayed(const Duration(seconds: 1));
|
||||
expect(
|
||||
queue.runningOperations.containsAll(['a', 'b', 'c']),
|
||||
isTrue,
|
||||
);
|
||||
expect(queue.runningOperations.length, 3);
|
||||
|
||||
await Future<void>.delayed(const Duration(seconds: 4));
|
||||
|
||||
expect(
|
||||
queue.runningOperations.containsAll(['a']),
|
||||
isTrue,
|
||||
);
|
||||
expect(queue.runningOperations.length, 1);
|
||||
|
||||
await Future<void>.delayed(const Duration(seconds: 4));
|
||||
expect(queue.runningOperations.length, 0);
|
||||
});
|
||||
|
||||
test('Test not blocking due to no conflicts', () async {
|
||||
final queue = RatchetAccessQueue();
|
||||
|
||||
unawaited(testMethod(queue, ['a', 'b'], 5));
|
||||
unawaited(testMethod(queue, ['c'], 5));
|
||||
unawaited(testMethod(queue, ['d'], 5));
|
||||
|
||||
await Future<void>.delayed(const Duration(seconds: 1));
|
||||
expect(queue.runningOperations.length, 4);
|
||||
expect(
|
||||
queue.runningOperations.containsAll([
|
||||
'a', 'b', 'c', 'd',
|
||||
]),
|
||||
isTrue,
|
||||
);
|
||||
});
|
||||
}
|
Loading…
Reference in New Issue
Block a user