diff --git a/lib/src/omemo/omemo.dart b/lib/src/omemo/omemo.dart index 70fa49c..3dd4b7e 100644 --- a/lib/src/omemo/omemo.dart +++ b/lib/src/omemo/omemo.dart @@ -20,6 +20,7 @@ import 'package:omemo_dart/src/omemo/encryption_result.dart'; import 'package:omemo_dart/src/omemo/errors.dart'; import 'package:omemo_dart/src/omemo/events.dart'; import 'package:omemo_dart/src/omemo/fingerprint.dart'; +import 'package:omemo_dart/src/omemo/queue.dart'; import 'package:omemo_dart/src/omemo/ratchet_map_key.dart'; import 'package:omemo_dart/src/omemo/stanza.dart'; import 'package:omemo_dart/src/protobuf/schema.pb.dart'; @@ -71,21 +72,20 @@ class OmemoManager { final Future Function(String jid) subscribeToDeviceListNodeImpl; /// Map bare JID to its known devices - Map> _deviceList = {}; + final Map> _deviceList = {}; /// Map bare JIDs to whether we already requested the device list once final Map _deviceListRequested = {}; /// Map bare a ratchet key to its ratchet. Note that this is also locked by /// _ratchetCriticalSectionLock. - Map _ratchetMap = {}; + final Map _ratchetMap = {}; /// Map bare JID to whether we already tried to subscribe to the device list node. final Map _subscriptionMap = {}; /// For preventing a race condition in encryption/decryption - final Map>> _ratchetCriticalSectionQueue = {}; - final Lock _ratchetCriticalSectionLock = Lock(); + final RatchetAccessQueue _ratchetQueue = RatchetAccessQueue(); /// The OmemoManager's trust management final TrustManager _trustManager; @@ -101,39 +101,6 @@ class OmemoManager { StreamController.broadcast(); Stream get eventStream => _eventStreamController.stream; - /// Enter the critical section for performing cryptographic operations on the ratchets - Future _enterRatchetCriticalSection(String jid) async { - return; - final completer = await _ratchetCriticalSectionLock.synchronized(() { - if (_ratchetCriticalSectionQueue.containsKey(jid)) { - final c = Completer(); - _ratchetCriticalSectionQueue[jid]!.addLast(c); - return c; - } - - _ratchetCriticalSectionQueue[jid] = Queue(); - return null; - }); - - if (completer != null) { - await completer.future; - } - } - - /// Leave the critical section for the ratchets. - Future _leaveRatchetCriticalSection(String jid) async { - return; - await _ratchetCriticalSectionLock.synchronized(() { - if (_ratchetCriticalSectionQueue.containsKey(jid)) { - if (_ratchetCriticalSectionQueue[jid]!.isEmpty) { - _ratchetCriticalSectionQueue.remove(jid); - } else { - _ratchetCriticalSectionQueue[jid]!.removeFirst().complete(); - } - } - }); - } - Future> _decryptAndVerifyHmac( List? ciphertext, List keyAndHmac, @@ -242,7 +209,7 @@ class OmemoManager { } else { // Ratchet is not acknowledged _log.finest('Sending acknowledgement heartbeat to ${key.jid}'); - await ratchetAcknowledged(key.jid, key.deviceId); + await _ratchetAcknowledged(key.jid, key.deviceId); await sendEmptyOmemoMessageImpl( await _onOutgoingStanzaImpl( OmemoOutgoingStanza( @@ -257,13 +224,10 @@ class OmemoManager { /// Future onIncomingStanza(OmemoIncomingStanza stanza) async { - // NOTE: We do this so that we cannot forget to acquire and free the critical - // section. - await _enterRatchetCriticalSection(stanza.bareSenderJid); - final result = await _onIncomingStanzaImpl(stanza); - await _leaveRatchetCriticalSection(stanza.bareSenderJid); - - return result; + return _ratchetQueue.synchronized( + [stanza.bareSenderJid], + () => _onIncomingStanzaImpl(stanza), + ); } Future _onIncomingStanzaImpl(OmemoIncomingStanza stanza) async { @@ -419,7 +383,7 @@ class OmemoManager { if (!_deviceList[stanza.bareSenderJid]!.contains(stanza.senderDeviceId)) { _deviceList[stanza.bareSenderJid]!.add(stanza.senderDeviceId); } - await sendOmemoHeartbeat(stanza.bareSenderJid); + await _sendOmemoHeartbeat(stanza.bareSenderJid); return DecryptionResult( null, @@ -489,13 +453,10 @@ class OmemoManager { } Future onOutgoingStanza(OmemoOutgoingStanza stanza) async { - // TODO: Be more smart about the locking - // TODO: Do we even need to lock? - await _enterRatchetCriticalSection(stanza.recipientJids.first); - final result = await _onOutgoingStanzaImpl(stanza); - await _leaveRatchetCriticalSection(stanza.recipientJids.first); - - return result; + return _ratchetQueue.synchronized( + stanza.recipientJids, + () => _onOutgoingStanzaImpl(stanza), + ); } Future _onOutgoingStanzaImpl(OmemoOutgoingStanza stanza) async { @@ -697,9 +658,17 @@ class OmemoManager { ); } - // Sends an empty OMEMO message (heartbeat) to [jid]. + /// Sends an empty OMEMO message (heartbeat) to [jid]. Future sendOmemoHeartbeat(String jid) async { - final result = await onOutgoingStanza( + await _ratchetQueue.synchronized( + [jid], + () => _sendOmemoHeartbeat(jid), + ); + } + + /// Like [sendOmemoHeartbeat], but does not acquire the lock for [jid]. + Future _sendOmemoHeartbeat(String jid) async { + final result = await _onOutgoingStanzaImpl( OmemoOutgoingStanza( [jid], null, @@ -710,20 +679,21 @@ class OmemoManager { /// Removes all ratchets associated with [jid]. Future removeAllRatchets(String jid) async { - await _enterRatchetCriticalSection(jid); + await _ratchetQueue.synchronized( + [jid], + () async { + for (final device in _deviceList[jid] ?? []) { + // Remove the ratchet and commit + _ratchetMap.remove(RatchetMapKey(jid, device)); + _eventStreamController.add(RatchetRemovedEvent(jid, device)); + } - for (final device in _deviceList[jid] ?? []) { - // Remove the ratchet and commit - _ratchetMap.remove(RatchetMapKey(jid, device)); - _eventStreamController.add(RatchetRemovedEvent(jid, device)); - } - - // Clear the device list - _deviceList.remove(jid); - _deviceListRequested.remove(jid); - _eventStreamController.add(DeviceListModifiedEvent(jid, [])); - - await _leaveRatchetCriticalSection(jid); + // Clear the device list + _deviceList.remove(jid); + _deviceListRequested.remove(jid); + _eventStreamController.add(DeviceListModifiedEvent(jid, [])); + }, + ); } /// To be called when a update to the device list of [jid] is returned. @@ -748,8 +718,14 @@ class OmemoManager { // Mark the ratchet [jid]:[device] as acknowledged. Future ratchetAcknowledged(String jid, int device) async { - await _enterRatchetCriticalSection(jid); + await _ratchetQueue.synchronized( + [jid], + () => _ratchetAcknowledged(jid, device), + ); + } + /// Like [ratchetAcknowledged], but does not acquire the lock for [jid]. + Future _ratchetAcknowledged(String jid, int device) async { final ratchetKey = RatchetMapKey(jid, device); if (!_ratchetMap.containsKey(ratchetKey)) { _log.warning('Cannot mark $jid:$device as acknowledged as the ratchet does not exist'); @@ -760,8 +736,6 @@ class OmemoManager { RatchetModifiedEvent(jid, device, ratchet, false, false), ); } - - await _leaveRatchetCriticalSection(jid); } /// If ratchets with [jid] exists, returns a list of fingerprints for each @@ -769,13 +743,13 @@ class OmemoManager { /// /// If not ratchets exists, returns null. Future?> getFingerprintsForJid(String jid) async { - await _getFingerprintsForJidImpl(jid); - final result = await _getFingerprintsForJidImpl(jid); - await _leaveRatchetCriticalSection(jid); - - return result; + return _ratchetQueue.synchronized( + [jid], + () => _getFingerprintsForJidImpl(jid), + ); } + /// Same as [getFingerprintsForJid], but without acquiring the lock for [jid]. Future?> _getFingerprintsForJidImpl(String jid) async { // Check if we know of the JID. if (!_deviceList.containsKey(jid)) { diff --git a/lib/src/omemo/queue.dart b/lib/src/omemo/queue.dart new file mode 100644 index 0000000..a63f02e --- /dev/null +++ b/lib/src/omemo/queue.dart @@ -0,0 +1,97 @@ +import 'dart:async'; +import 'dart:collection'; + +import 'package:meta/meta.dart'; +import 'package:synchronized/synchronized.dart'; + +extension UtilAllMethodsList on List { + void removeAll(List values) { + for (final value in values) { + remove(value); + } + } + + bool containsAll(List values) { + for (final value in values) { + if (!contains(value)) { + return false; + } + } + + return true; + } +} + +class _RatchetAccessQueueEntry { + _RatchetAccessQueueEntry( + this.jids, + this.completer, + ); + + final List jids; + final Completer completer; +} + +class RatchetAccessQueue { + final Queue<_RatchetAccessQueueEntry> _queue = Queue(); + + @visibleForTesting + final List runningOperations = List.empty(growable: true); + + final Lock lock = Lock(); + + bool canBypass(List jids) { + for (final jid in jids) { + if (runningOperations.contains(jid)) { + return false; + } + } + + return true; + } + + Future enterCriticalSection(List jids) async { + final completer = await lock.synchronized?>(() { + if (canBypass(jids)) { + runningOperations.addAll(jids); + return null; + } + + final completer = Completer(); + _queue.add( + _RatchetAccessQueueEntry( + jids, + completer, + ), + ); + + return completer; + }); + + await completer?.future; + } + + Future leaveCriticalSection(List jids) async { + await lock.synchronized(() { + runningOperations.removeAll(jids); + + while (_queue.isNotEmpty) { + if (canBypass(_queue.first.jids)) { + final head = _queue.removeFirst(); + runningOperations.addAll(head.jids); + head.completer.complete(); + } else { + break; + } + } + }); + } + + Future synchronized(List jids, Future Function() function) async { + await enterCriticalSection(jids); + final result = await function(); + await leaveCriticalSection(jids); + + return result; + } +} diff --git a/test/queue_test.dart b/test/queue_test.dart new file mode 100644 index 0000000..2cd19d3 --- /dev/null +++ b/test/queue_test.dart @@ -0,0 +1,56 @@ +import 'dart:async'; + +import 'package:omemo_dart/src/omemo/queue.dart'; +import 'package:test/test.dart'; + +Future testMethod(RatchetAccessQueue queue, List data, int duration) async { + await queue.enterCriticalSection(data); + + await Future.delayed(Duration(seconds: duration)); + + await queue.leaveCriticalSection(data); +} + +void main() { + test('Test blocking due to conflicts', () async { + final queue = RatchetAccessQueue(); + + unawaited(testMethod(queue, ['a', 'b', 'c'], 5)); + unawaited(testMethod(queue, ['a'], 4)); + + await Future.delayed(const Duration(seconds: 1)); + expect( + queue.runningOperations.containsAll(['a', 'b', 'c']), + isTrue, + ); + expect(queue.runningOperations.length, 3); + + await Future.delayed(const Duration(seconds: 4)); + + expect( + queue.runningOperations.containsAll(['a']), + isTrue, + ); + expect(queue.runningOperations.length, 1); + + await Future.delayed(const Duration(seconds: 4)); + expect(queue.runningOperations.length, 0); + }); + + test('Test not blocking due to no conflicts', () async { + final queue = RatchetAccessQueue(); + + unawaited(testMethod(queue, ['a', 'b'], 5)); + unawaited(testMethod(queue, ['c'], 5)); + unawaited(testMethod(queue, ['d'], 5)); + + await Future.delayed(const Duration(seconds: 1)); + expect(queue.runningOperations.length, 4); + expect( + queue.runningOperations.containsAll([ + 'a', 'b', 'c', 'd', + ]), + isTrue, + ); + }); +}