diff --git a/packages/moxxmpp/lib/src/util/wait.dart b/packages/moxxmpp/lib/src/util/wait.dart new file mode 100644 index 0000000..1b92a1a --- /dev/null +++ b/packages/moxxmpp/lib/src/util/wait.dart @@ -0,0 +1,67 @@ +import 'dart:async'; +import 'package:meta/meta.dart'; +import 'package:synchronized/synchronized.dart'; + +/// This class allows for multiple asynchronous code places to wait on the +/// same computation of type [V], indentified by a key of type [K]. +class WaitForTracker { + /// The mapping of key -> Completer for the pending tasks. + final Map>> _tracker = {}; + + /// The lock for accessing _tracker. + final Lock _lock = Lock(); + + /// Wait for a task with key [key]. If there was no such task already + /// present, returns null. If one or more tasks were already present, returns + /// a future that will resolve to the result of the first task. + Future?> waitFor(K key) async { + final result = await _lock.synchronized(() { + if (_tracker.containsKey(key)) { + // The task already exists. Just append outselves + final completer = Completer(); + _tracker[key]!.add(completer); + return completer; + } + + // The task does not exist yet + _tracker[key] = List>.empty(growable: true); + return null; + }); + + return result?.future; + } + + /// Resolve a task with key [key] to [value]. + Future resolve(K key, V value) async { + await _lock.synchronized(() { + if (!_tracker.containsKey(key)) return; + + for (final completer in _tracker[key]!) { + completer.complete(value); + } + + _tracker.remove(key); + }); + } + + Future resolveAll(V value) async { + await _lock.synchronized(() { + for (final key in _tracker.keys) { + for (final completer in _tracker[key]!) { + completer.complete(value); + } + } + }); + } + + /// Remove all tasks from the tracker. + Future clear() async { + await _lock.synchronized(_tracker.clear); + } + + @visibleForTesting + bool hasTasksRunning() => _tracker.isNotEmpty; + + @visibleForTesting + List> getRunningTasks(K key) => _tracker[key]!; +} diff --git a/packages/moxxmpp/lib/src/xeps/xep_0030/cache.dart b/packages/moxxmpp/lib/src/xeps/xep_0030/cache.dart new file mode 100644 index 0000000..f1df4d6 --- /dev/null +++ b/packages/moxxmpp/lib/src/xeps/xep_0030/cache.dart @@ -0,0 +1,23 @@ +import 'package:meta/meta.dart'; + +@internal +@immutable +class DiscoCacheKey { + const DiscoCacheKey(this.jid, this.node); + + /// The JID we're requesting disco data from. + final String jid; + + /// Optionally the node we are requesting from. + final String? node; + + @override + bool operator ==(Object other) { + return other is DiscoCacheKey && + jid == other.jid && + node == other.node; + } + + @override + int get hashCode => jid.hashCode ^ node.hashCode; +} diff --git a/packages/moxxmpp/lib/src/xeps/xep_0030/xep_0030.dart b/packages/moxxmpp/lib/src/xeps/xep_0030/xep_0030.dart index 696c50c..e8405ac 100644 --- a/packages/moxxmpp/lib/src/xeps/xep_0030/xep_0030.dart +++ b/packages/moxxmpp/lib/src/xeps/xep_0030/xep_0030.dart @@ -1,5 +1,6 @@ import 'dart:async'; import 'package:meta/meta.dart'; +import 'package:moxxmpp/src/connection.dart'; import 'package:moxxmpp/src/events.dart'; import 'package:moxxmpp/src/jid.dart'; import 'package:moxxmpp/src/managers/base.dart'; @@ -10,6 +11,8 @@ import 'package:moxxmpp/src/namespaces.dart'; import 'package:moxxmpp/src/stanza.dart'; import 'package:moxxmpp/src/stringxml.dart'; import 'package:moxxmpp/src/types/result.dart'; +import 'package:moxxmpp/src/util/wait.dart'; +import 'package:moxxmpp/src/xeps/xep_0030/cache.dart'; import 'package:moxxmpp/src/xeps/xep_0030/errors.dart'; import 'package:moxxmpp/src/xeps/xep_0030/helpers.dart'; import 'package:moxxmpp/src/xeps/xep_0030/types.dart'; @@ -22,21 +25,6 @@ typedef DiscoInfoRequestCallback = Future Function(); /// Callback that is called when a disco#items requests is received on a given node. typedef DiscoItemsRequestCallback = Future> Function(); -@immutable -class DiscoCacheKey { - const DiscoCacheKey(this.jid, this.node); - final String jid; - final String? node; - - @override - bool operator ==(Object other) { - return other is DiscoCacheKey && jid == other.jid && node == other.node; - } - - @override - int get hashCode => jid.hashCode ^ node.hashCode; -} - /// This manager implements XEP-0030 by providing a way of performing disco#info and /// disco#items requests and answering those requests. /// A caching mechanism is also provided. @@ -62,8 +50,11 @@ class DiscoManager extends XmppManagerBase { /// Map full JID to Disco Info final Map _discoInfoCache = {}; - /// Mapping the full JID to a list of running requests - final Map>>> _runningInfoQueries = {}; + /// The tracker for tracking disco#info queries that are in flight. + final WaitForTracker> _discoInfoTracker = WaitForTracker(); + + /// The tracker for tracking disco#info queries that are in flight. + final WaitForTracker>> _discoItemsTracker = WaitForTracker(); /// Cache lock final Lock _cacheLock = Lock(); @@ -79,12 +70,9 @@ class DiscoManager extends XmppManagerBase { /// The list of disco features that are registered. List get features => _features; - - @visibleForTesting - bool hasInfoQueriesRunning() => _runningInfoQueries.isNotEmpty; @visibleForTesting - List>> getRunningInfoQueries(DiscoCacheKey key) => _runningInfoQueries[key]!; + WaitForTracker> get infoTracker => _discoInfoTracker; @override List getIncomingStanzaHandlers() => [ @@ -118,7 +106,21 @@ class DiscoManager extends XmppManagerBase { Future onXmppEvent(XmppEvent event) async { if (event is PresenceReceivedEvent) { await _onPresence(event.jid, event.presence); - } else if (event is StreamResumeFailedEvent) { + } else if (event is ConnectionStateChangedEvent) { + // TODO(Unknown): This handling is stupid. We should have an event that is + // triggered when we cannot guarantee that everything is as + // it was before. + if (event.state != XmppConnectionState.connected) return; + if (event.resumed) return; + + // Cancel all waiting requests + await _discoInfoTracker.resolveAll( + Result(UnknownDiscoError()), + ); + await _discoItemsTracker.resolveAll( + Result>(UnknownDiscoError()), + ); + await _cacheLock.synchronized(() async { // Clear the cache _discoInfoCache.clear(); @@ -259,46 +261,37 @@ class DiscoManager extends XmppManagerBase { } Future _exitDiscoInfoCriticalSection(DiscoCacheKey key, Result result) async { - return _cacheLock.synchronized(() async { - // Complete all futures - for (final completer in _runningInfoQueries[key]!) { - completer.complete(result); - } - + await _cacheLock.synchronized(() async { // Add to cache if it is a result if (result.isType()) { _discoInfoCache[key] = result.get(); } - - // Remove from the request cache - _runningInfoQueries.remove(key); }); + + await _discoInfoTracker.resolve(key, result); } /// Sends a disco info query to the (full) jid [entity], optionally with node=[node]. Future> discoInfoQuery(String entity, { String? node, bool shouldEncrypt = true }) async { final cacheKey = DiscoCacheKey(entity, node); DiscoInfo? info; - Completer>? completer; - await _cacheLock.synchronized(() async { + final ffuture = await _cacheLock.synchronized>?>?>(() async { // Check if we already know what the JID supports if (_discoInfoCache.containsKey(cacheKey)) { info = _discoInfoCache[cacheKey]; + return null; } else { - // Is a request running? - if (_runningInfoQueries.containsKey(cacheKey)) { - completer = Completer(); - _runningInfoQueries[cacheKey]!.add(completer!); - } else { - _runningInfoQueries[cacheKey] = List.from(>[]); - } + return _discoInfoTracker.waitFor(cacheKey); } }); if (info != null) { return Result(info); - } else if (completer != null) { - return completer!.future; + } else { + final future = await ffuture; + if (future != null) { + return future; + } } final stanza = await getAttributes().sendStanza( @@ -331,6 +324,12 @@ class DiscoManager extends XmppManagerBase { /// Sends a disco items query to the (full) jid [entity], optionally with node=[node]. Future>> discoItemsQuery(String entity, { String? node, bool shouldEncrypt = true }) async { + final key = DiscoCacheKey(entity, node); + final future = await _discoItemsTracker.waitFor(key); + if (future != null) { + return future; + } + final stanza = await getAttributes() .sendStanza( buildDiscoItemsQueryStanza(entity, node: node), @@ -338,12 +337,18 @@ class DiscoManager extends XmppManagerBase { ) as Stanza; final query = stanza.firstTag('query'); - if (query == null) return Result(InvalidResponseDiscoError()); + if (query == null) { + final result = Result>(InvalidResponseDiscoError()); + await _discoItemsTracker.resolve(key, result); + return result; + } if (stanza.type == 'error') { //final error = stanza.firstTag('error'); //print("Disco Items error: " + error.toXml()); - return Result(ErrorResponseDiscoError()); + final result = Result>(ErrorResponseDiscoError()); + await _discoItemsTracker.resolve(key, result); + return result; } final items = query.findTags('item').map((node) => DiscoItem( @@ -352,7 +357,9 @@ class DiscoManager extends XmppManagerBase { name: node.attributes['name'] as String?, ),).toList(); - return Result(items); + final result = Result>(items); + await _discoItemsTracker.resolve(key, result); + return result; } /// Queries information about a jid based on its node and capability hash. diff --git a/packages/moxxmpp/test/wait_test.dart b/packages/moxxmpp/test/wait_test.dart new file mode 100644 index 0000000..d7d5d84 --- /dev/null +++ b/packages/moxxmpp/test/wait_test.dart @@ -0,0 +1,40 @@ +import 'package:test/test.dart'; +import 'package:moxxmpp/src/util/wait.dart'; + +void main() { + test('Test adding and resolving', () async { + // ID -> Milliseconds since epoch + final tracker = WaitForTracker(); + + int r2 = 0; + int r3 = 0; + + // Queue some jobs + final r1 = await tracker.waitFor(0); + expect(r1, null); + + tracker + .waitFor(0) + .then((result) async { + expect(result != null, true); + r2 = await result!; + }); + tracker + .waitFor(0) + .then((result) async { + expect(result != null, true); + r3 = await result!; + }); + + final c = await tracker.waitFor(1); + expect(c, null); + + // Resolve jobs + await tracker.resolve(0, 42); + await tracker.resolve(1, 25); + await tracker.resolve(2, -1); + + expect(r2, 42); + expect(r3, 42); + }); +} diff --git a/packages/moxxmpp/test/xeps/xep_0030_test.dart b/packages/moxxmpp/test/xeps/xep_0030_test.dart index bc08635..c34cead 100644 --- a/packages/moxxmpp/test/xeps/xep_0030_test.dart +++ b/packages/moxxmpp/test/xeps/xep_0030_test.dart @@ -1,9 +1,12 @@ import 'package:moxxmpp/moxxmpp.dart'; import 'package:test/test.dart'; +import '../helpers/logging.dart'; import '../helpers/xmpp.dart'; void main() { + initLogger(); + test('Test having multiple disco requests for the same JID', () async { final fakeSocket = StubTCPSocket( play: [ @@ -102,7 +105,7 @@ void main() { await Future.delayed(const Duration(seconds: 1)); expect( - disco.getRunningInfoQueries(DiscoCacheKey(jid.toString(), null)).length, + disco.infoTracker.getRunningTasks(DiscoCacheKey(jid.toString(), null)).length, 1, ); fakeSocket.injectRawXml(""); @@ -111,6 +114,6 @@ void main() { expect(fakeSocket.getState(), 6); expect(await result1, await result2); - expect(disco.hasInfoQueriesRunning(), false); + expect(disco.infoTracker.hasTasksRunning(), false); }); }