feat: Factor out "multiple-waiting" into its own thing

This commit is contained in:
PapaTutuWawa 2023-01-27 21:14:35 +01:00
parent 7e588f01b0
commit 47337540f5
5 changed files with 187 additions and 47 deletions

View File

@ -0,0 +1,67 @@
import 'dart:async';
import 'package:meta/meta.dart';
import 'package:synchronized/synchronized.dart';
/// This class allows for multiple asynchronous code places to wait on the
/// same computation of type [V], indentified by a key of type [K].
class WaitForTracker<K, V> {
/// The mapping of key -> Completer for the pending tasks.
final Map<K, List<Completer<V>>> _tracker = {};
/// The lock for accessing _tracker.
final Lock _lock = Lock();
/// Wait for a task with key [key]. If there was no such task already
/// present, returns null. If one or more tasks were already present, returns
/// a future that will resolve to the result of the first task.
Future<Future<V>?> waitFor(K key) async {
final result = await _lock.synchronized(() {
if (_tracker.containsKey(key)) {
// The task already exists. Just append outselves
final completer = Completer<V>();
_tracker[key]!.add(completer);
return completer;
}
// The task does not exist yet
_tracker[key] = List<Completer<V>>.empty(growable: true);
return null;
});
return result?.future;
}
/// Resolve a task with key [key] to [value].
Future<void> resolve(K key, V value) async {
await _lock.synchronized(() {
if (!_tracker.containsKey(key)) return;
for (final completer in _tracker[key]!) {
completer.complete(value);
}
_tracker.remove(key);
});
}
Future<void> resolveAll(V value) async {
await _lock.synchronized(() {
for (final key in _tracker.keys) {
for (final completer in _tracker[key]!) {
completer.complete(value);
}
}
});
}
/// Remove all tasks from the tracker.
Future<void> clear() async {
await _lock.synchronized(_tracker.clear);
}
@visibleForTesting
bool hasTasksRunning() => _tracker.isNotEmpty;
@visibleForTesting
List<Completer<V>> getRunningTasks(K key) => _tracker[key]!;
}

View File

@ -0,0 +1,23 @@
import 'package:meta/meta.dart';
@internal
@immutable
class DiscoCacheKey {
const DiscoCacheKey(this.jid, this.node);
/// The JID we're requesting disco data from.
final String jid;
/// Optionally the node we are requesting from.
final String? node;
@override
bool operator ==(Object other) {
return other is DiscoCacheKey &&
jid == other.jid &&
node == other.node;
}
@override
int get hashCode => jid.hashCode ^ node.hashCode;
}

View File

@ -1,5 +1,6 @@
import 'dart:async';
import 'package:meta/meta.dart';
import 'package:moxxmpp/src/connection.dart';
import 'package:moxxmpp/src/events.dart';
import 'package:moxxmpp/src/jid.dart';
import 'package:moxxmpp/src/managers/base.dart';
@ -10,6 +11,8 @@ import 'package:moxxmpp/src/namespaces.dart';
import 'package:moxxmpp/src/stanza.dart';
import 'package:moxxmpp/src/stringxml.dart';
import 'package:moxxmpp/src/types/result.dart';
import 'package:moxxmpp/src/util/wait.dart';
import 'package:moxxmpp/src/xeps/xep_0030/cache.dart';
import 'package:moxxmpp/src/xeps/xep_0030/errors.dart';
import 'package:moxxmpp/src/xeps/xep_0030/helpers.dart';
import 'package:moxxmpp/src/xeps/xep_0030/types.dart';
@ -22,21 +25,6 @@ typedef DiscoInfoRequestCallback = Future<DiscoInfo> Function();
/// Callback that is called when a disco#items requests is received on a given node.
typedef DiscoItemsRequestCallback = Future<List<DiscoItem>> Function();
@immutable
class DiscoCacheKey {
const DiscoCacheKey(this.jid, this.node);
final String jid;
final String? node;
@override
bool operator ==(Object other) {
return other is DiscoCacheKey && jid == other.jid && node == other.node;
}
@override
int get hashCode => jid.hashCode ^ node.hashCode;
}
/// This manager implements XEP-0030 by providing a way of performing disco#info and
/// disco#items requests and answering those requests.
/// A caching mechanism is also provided.
@ -62,8 +50,11 @@ class DiscoManager extends XmppManagerBase {
/// Map full JID to Disco Info
final Map<DiscoCacheKey, DiscoInfo> _discoInfoCache = {};
/// Mapping the full JID to a list of running requests
final Map<DiscoCacheKey, List<Completer<Result<DiscoError, DiscoInfo>>>> _runningInfoQueries = {};
/// The tracker for tracking disco#info queries that are in flight.
final WaitForTracker<DiscoCacheKey, Result<DiscoError, DiscoInfo>> _discoInfoTracker = WaitForTracker();
/// The tracker for tracking disco#info queries that are in flight.
final WaitForTracker<DiscoCacheKey, Result<DiscoError, List<DiscoItem>>> _discoItemsTracker = WaitForTracker();
/// Cache lock
final Lock _cacheLock = Lock();
@ -79,12 +70,9 @@ class DiscoManager extends XmppManagerBase {
/// The list of disco features that are registered.
List<String> get features => _features;
@visibleForTesting
bool hasInfoQueriesRunning() => _runningInfoQueries.isNotEmpty;
@visibleForTesting
List<Completer<Result<DiscoError, DiscoInfo>>> getRunningInfoQueries(DiscoCacheKey key) => _runningInfoQueries[key]!;
WaitForTracker<DiscoCacheKey, Result<DiscoError, DiscoInfo>> get infoTracker => _discoInfoTracker;
@override
List<StanzaHandler> getIncomingStanzaHandlers() => [
@ -118,7 +106,21 @@ class DiscoManager extends XmppManagerBase {
Future<void> onXmppEvent(XmppEvent event) async {
if (event is PresenceReceivedEvent) {
await _onPresence(event.jid, event.presence);
} else if (event is StreamResumeFailedEvent) {
} else if (event is ConnectionStateChangedEvent) {
// TODO(Unknown): This handling is stupid. We should have an event that is
// triggered when we cannot guarantee that everything is as
// it was before.
if (event.state != XmppConnectionState.connected) return;
if (event.resumed) return;
// Cancel all waiting requests
await _discoInfoTracker.resolveAll(
Result<DiscoError, DiscoInfo>(UnknownDiscoError()),
);
await _discoItemsTracker.resolveAll(
Result<DiscoError, List<DiscoItem>>(UnknownDiscoError()),
);
await _cacheLock.synchronized(() async {
// Clear the cache
_discoInfoCache.clear();
@ -259,46 +261,37 @@ class DiscoManager extends XmppManagerBase {
}
Future<void> _exitDiscoInfoCriticalSection(DiscoCacheKey key, Result<DiscoError, DiscoInfo> result) async {
return _cacheLock.synchronized(() async {
// Complete all futures
for (final completer in _runningInfoQueries[key]!) {
completer.complete(result);
}
await _cacheLock.synchronized(() async {
// Add to cache if it is a result
if (result.isType<DiscoInfo>()) {
_discoInfoCache[key] = result.get<DiscoInfo>();
}
// Remove from the request cache
_runningInfoQueries.remove(key);
});
await _discoInfoTracker.resolve(key, result);
}
/// Sends a disco info query to the (full) jid [entity], optionally with node=[node].
Future<Result<DiscoError, DiscoInfo>> discoInfoQuery(String entity, { String? node, bool shouldEncrypt = true }) async {
final cacheKey = DiscoCacheKey(entity, node);
DiscoInfo? info;
Completer<Result<DiscoError, DiscoInfo>>? completer;
await _cacheLock.synchronized(() async {
final ffuture = await _cacheLock.synchronized<Future<Future<Result<DiscoError, DiscoInfo>>?>?>(() async {
// Check if we already know what the JID supports
if (_discoInfoCache.containsKey(cacheKey)) {
info = _discoInfoCache[cacheKey];
return null;
} else {
// Is a request running?
if (_runningInfoQueries.containsKey(cacheKey)) {
completer = Completer();
_runningInfoQueries[cacheKey]!.add(completer!);
} else {
_runningInfoQueries[cacheKey] = List.from(<Completer<DiscoInfo?>>[]);
}
return _discoInfoTracker.waitFor(cacheKey);
}
});
if (info != null) {
return Result<DiscoError, DiscoInfo>(info);
} else if (completer != null) {
return completer!.future;
} else {
final future = await ffuture;
if (future != null) {
return future;
}
}
final stanza = await getAttributes().sendStanza(
@ -331,6 +324,12 @@ class DiscoManager extends XmppManagerBase {
/// Sends a disco items query to the (full) jid [entity], optionally with node=[node].
Future<Result<DiscoError, List<DiscoItem>>> discoItemsQuery(String entity, { String? node, bool shouldEncrypt = true }) async {
final key = DiscoCacheKey(entity, node);
final future = await _discoItemsTracker.waitFor(key);
if (future != null) {
return future;
}
final stanza = await getAttributes()
.sendStanza(
buildDiscoItemsQueryStanza(entity, node: node),
@ -338,12 +337,18 @@ class DiscoManager extends XmppManagerBase {
) as Stanza;
final query = stanza.firstTag('query');
if (query == null) return Result(InvalidResponseDiscoError());
if (query == null) {
final result = Result<DiscoError, List<DiscoItem>>(InvalidResponseDiscoError());
await _discoItemsTracker.resolve(key, result);
return result;
}
if (stanza.type == 'error') {
//final error = stanza.firstTag('error');
//print("Disco Items error: " + error.toXml());
return Result(ErrorResponseDiscoError());
final result = Result<DiscoError, List<DiscoItem>>(ErrorResponseDiscoError());
await _discoItemsTracker.resolve(key, result);
return result;
}
final items = query.findTags('item').map((node) => DiscoItem(
@ -352,7 +357,9 @@ class DiscoManager extends XmppManagerBase {
name: node.attributes['name'] as String?,
),).toList();
return Result(items);
final result = Result<DiscoError, List<DiscoItem>>(items);
await _discoItemsTracker.resolve(key, result);
return result;
}
/// Queries information about a jid based on its node and capability hash.

View File

@ -0,0 +1,40 @@
import 'package:test/test.dart';
import 'package:moxxmpp/src/util/wait.dart';
void main() {
test('Test adding and resolving', () async {
// ID -> Milliseconds since epoch
final tracker = WaitForTracker<int, int>();
int r2 = 0;
int r3 = 0;
// Queue some jobs
final r1 = await tracker.waitFor(0);
expect(r1, null);
tracker
.waitFor(0)
.then((result) async {
expect(result != null, true);
r2 = await result!;
});
tracker
.waitFor(0)
.then((result) async {
expect(result != null, true);
r3 = await result!;
});
final c = await tracker.waitFor(1);
expect(c, null);
// Resolve jobs
await tracker.resolve(0, 42);
await tracker.resolve(1, 25);
await tracker.resolve(2, -1);
expect(r2, 42);
expect(r3, 42);
});
}

View File

@ -1,9 +1,12 @@
import 'package:moxxmpp/moxxmpp.dart';
import 'package:test/test.dart';
import '../helpers/logging.dart';
import '../helpers/xmpp.dart';
void main() {
initLogger();
test('Test having multiple disco requests for the same JID', () async {
final fakeSocket = StubTCPSocket(
play: [
@ -102,7 +105,7 @@ void main() {
await Future.delayed(const Duration(seconds: 1));
expect(
disco.getRunningInfoQueries(DiscoCacheKey(jid.toString(), null)).length,
disco.infoTracker.getRunningTasks(DiscoCacheKey(jid.toString(), null)).length,
1,
);
fakeSocket.injectRawXml("<iq type='result' id='${fakeSocket.lastId!}' from='romeo@montague.lit/orchard' to='polynomdivision@test.server/MU29eEZn' xmlns='jabber:client'><query xmlns='http://jabber.org/protocol/disco#info' /></iq>");
@ -111,6 +114,6 @@ void main() {
expect(fakeSocket.getState(), 6);
expect(await result1, await result2);
expect(disco.hasInfoQueriesRunning(), false);
expect(disco.infoTracker.hasTasksRunning(), false);
});
}