feat: Notify the user of a modified device

This commit is contained in:
PapaTutuWawa 2022-08-05 16:52:02 +02:00
parent 30e3bd78cd
commit 62fdf568aa
6 changed files with 76 additions and 21 deletions

View File

@ -2,6 +2,7 @@ library omemo_dart;
export 'src/double_ratchet/double_ratchet.dart'; export 'src/double_ratchet/double_ratchet.dart';
export 'src/errors.dart'; export 'src/errors.dart';
export 'src/events.dart';
export 'src/helpers.dart'; export 'src/helpers.dart';
export 'src/keys.dart'; export 'src/keys.dart';
export 'src/omemo/bundle.dart'; export 'src/omemo/bundle.dart';

11
lib/src/events.dart Normal file
View File

@ -0,0 +1,11 @@
import 'package:omemo_dart/src/omemo/device.dart';
abstract class OmemoEvent {}
/// Triggered by the OmemoSessionManager when our own device bundle was modified
/// and thus should be republished.
class DeviceBundleModifiedEvent extends OmemoEvent {
DeviceBundleModifiedEvent(this.device);
final Device device;
}

View File

@ -12,7 +12,7 @@ List<int> concat(List<List<int>> inputs) {
/// Compares the two lists [a] and [b] and return true if [a] and [b] are index-by-index /// Compares the two lists [a] and [b] and return true if [a] and [b] are index-by-index
/// equal. Returns false, if they are not "equal"; /// equal. Returns false, if they are not "equal";
bool listsEqual(List<int> a, List<int> b) { bool listsEqual<T>(List<T> a, List<T> b) {
// TODO(Unknown): Do we need to use a constant time comparison? // TODO(Unknown): Do we need to use a constant time comparison?
if (a.length != b.length) return false; if (a.length != b.length) return false;

View File

@ -47,7 +47,7 @@ class Device {
/// This replaces the Onetime-Prekey with id [id] with a completely new one. Returns /// This replaces the Onetime-Prekey with id [id] with a completely new one. Returns
/// a new Device object that copies over everything but replaces said key. /// a new Device object that copies over everything but replaces said key.
Future<Device> replaceOnetimePrekey(int id) async { Future<Device> replaceOnetimePrekey(int id) async {
final newOpk = await OmemoKeyPair.generateNewPair(KeyPairType.x25519); opks[id] = await OmemoKeyPair.generateNewPair(KeyPairType.x25519);
return Device( return Device(
id, id,
@ -55,13 +55,7 @@ class Device {
spk, spk,
spkId, spkId,
spkSignature, spkSignature,
opks.map((keyId, opk) { opks,
if (keyId == id) {
return MapEntry(id, newOpk);
}
return MapEntry(id, opk);
}),
); );
} }

View File

@ -1,9 +1,11 @@
import 'dart:async';
import 'dart:convert'; import 'dart:convert';
import 'package:collection/collection.dart'; import 'package:collection/collection.dart';
import 'package:cryptography/cryptography.dart'; import 'package:cryptography/cryptography.dart';
import 'package:omemo_dart/src/crypto.dart'; import 'package:omemo_dart/src/crypto.dart';
import 'package:omemo_dart/src/double_ratchet/double_ratchet.dart'; import 'package:omemo_dart/src/double_ratchet/double_ratchet.dart';
import 'package:omemo_dart/src/errors.dart'; import 'package:omemo_dart/src/errors.dart';
import 'package:omemo_dart/src/events.dart';
import 'package:omemo_dart/src/helpers.dart'; import 'package:omemo_dart/src/helpers.dart';
import 'package:omemo_dart/src/keys.dart'; import 'package:omemo_dart/src/keys.dart';
import 'package:omemo_dart/src/omemo/bundle.dart'; import 'package:omemo_dart/src/omemo/bundle.dart';
@ -39,7 +41,12 @@ class EncryptedKey {
class OmemoSessionManager { class OmemoSessionManager {
OmemoSessionManager(this.device) : _ratchetMap = {}, _deviceMap = {}, _lock = Lock(); OmemoSessionManager(this._device)
: _ratchetMap = {},
_deviceMap = {},
_lock = Lock(),
_deviceLock = Lock(),
_eventStreamController = StreamController<OmemoEvent>.broadcast();
/// Generate a new cryptographic identity. /// Generate a new cryptographic identity.
static Future<OmemoSessionManager> generateNewIdentity({ int opkAmount = 100 }) async { static Future<OmemoSessionManager> generateNewIdentity({ int opkAmount = 100 }) async {
@ -58,8 +65,26 @@ class OmemoSessionManager {
/// Mapping of a bare Jid to its Device Ids /// Mapping of a bare Jid to its Device Ids
final Map<String, List<int>> _deviceMap; final Map<String, List<int>> _deviceMap;
/// Our own keys /// The event bus of the session manager
Device device; final StreamController<OmemoEvent> _eventStreamController;
/// Our own keys...
// ignore: prefer_final_fields
Device _device;
/// and its lock
final Lock _deviceLock;
/// A stream that receives events regarding the session
Stream<OmemoEvent> get eventStream => _eventStreamController.stream;
Future<Device> getDevice() async {
Device? dev;
await _deviceLock.synchronized(() async {
dev = _device;
});
return dev!;
}
/// Add a session [ratchet] with the [deviceId] to the internal tracking state. /// Add a session [ratchet] with the [deviceId] to the internal tracking state.
Future<void> addSession(String jid, int deviceId, OmemoDoubleRatchet ratchet) async { Future<void> addSession(String jid, int deviceId, OmemoDoubleRatchet ratchet) async {
@ -84,6 +109,7 @@ class OmemoSessionManager {
/// Create a ratchet session initiated by Alice to the user with Jid [jid] and the device /// Create a ratchet session initiated by Alice to the user with Jid [jid] and the device
/// [deviceId] from the bundle [bundle]. /// [deviceId] from the bundle [bundle].
Future<OmemoKeyExchange> addSessionFromBundle(String jid, int deviceId, OmemoBundle bundle) async { Future<OmemoKeyExchange> addSessionFromBundle(String jid, int deviceId, OmemoBundle bundle) async {
final device = await getDevice();
final kexResult = await x3dhFromBundle( final kexResult = await x3dhFromBundle(
bundle, bundle,
device.ik, device.ik,
@ -105,8 +131,8 @@ class OmemoSessionManager {
/// Build a new session with the user at [jid] with the device [deviceId] using data /// Build a new session with the user at [jid] with the device [deviceId] using data
/// from the key exchange [kex]. /// from the key exchange [kex].
// TODO(PapaTutuWawa): Replace the OPK
Future<void> addSessionFromKeyExchange(String jid, int deviceId, OmemoKeyExchange kex) async { Future<void> addSessionFromKeyExchange(String jid, int deviceId, OmemoKeyExchange kex) async {
final device = await getDevice();
final kexResult = await x3dhFromInitialMessage( final kexResult = await x3dhFromInitialMessage(
X3DHMessage( X3DHMessage(
OmemoPublicKey.fromBytes(kex.ik!, KeyPairType.ed25519), OmemoPublicKey.fromBytes(kex.ik!, KeyPairType.ed25519),
@ -189,6 +215,7 @@ class OmemoSessionManager {
/// <encrypted /> element. /// <encrypted /> element.
Future<String> decryptMessage(List<int> ciphertext, String senderJid, int senderDeviceId, List<EncryptedKey> keys) async { Future<String> decryptMessage(List<int> ciphertext, String senderJid, int senderDeviceId, List<EncryptedKey> keys) async {
// Try to find a session we can decrypt with. // Try to find a session we can decrypt with.
var device = await getDevice();
final rawKey = keys.firstWhereOrNull((key) => key.rid == device.id); final rawKey = keys.firstWhereOrNull((key) => key.rid == device.id);
if (rawKey == null) { if (rawKey == null) {
throw NotEncryptedForDeviceException(); throw NotEncryptedForDeviceException();
@ -206,6 +233,14 @@ class OmemoSessionManager {
); );
authMessage = kex.message!; authMessage = kex.message!;
// Replace the OPK
await _deviceLock.synchronized(() async {
device = await device.replaceOnetimePrekey(kex.pkId!);
_eventStreamController.add(
DeviceBundleModifiedEvent(device),
);
});
} else { } else {
authMessage = OmemoAuthenticatedMessage.fromBuffer(decodedRawKey); authMessage = OmemoAuthenticatedMessage.fromBuffer(decodedRawKey);
} }

View File

@ -7,8 +7,15 @@ void main() {
const bobJid = 'bob@other.server.example'; const bobJid = 'bob@other.server.example';
// Alice and Bob generate their sessions // Alice and Bob generate their sessions
var deviceModified = false;
final aliceSession = await OmemoSessionManager.generateNewIdentity(opkAmount: 1); final aliceSession = await OmemoSessionManager.generateNewIdentity(opkAmount: 1);
final bobSession = await OmemoSessionManager.generateNewIdentity(opkAmount: 1); final bobSession = await OmemoSessionManager.generateNewIdentity(opkAmount: 1);
final bobOpks = (await bobSession.getDevice()).opks.values.toList();
bobSession.eventStream.listen((event) {
if (event is DeviceBundleModifiedEvent) {
deviceModified = true;
}
});
// Alice encrypts a message for Bob // Alice encrypts a message for Bob
const messagePlaintext = 'Hello Bob!'; const messagePlaintext = 'Hello Bob!';
@ -16,7 +23,7 @@ void main() {
bobJid, bobJid,
messagePlaintext, messagePlaintext,
newSessions: [ newSessions: [
await bobSession.device.toBundle(), await (await bobSession.getDevice()).toBundle(),
], ],
); );
expect(aliceMessage.encryptedKeys.length, 1); expect(aliceMessage.encryptedKeys.length, 1);
@ -28,10 +35,17 @@ void main() {
final bobMessage = await bobSession.decryptMessage( final bobMessage = await bobSession.decryptMessage(
aliceMessage.ciphertext, aliceMessage.ciphertext,
aliceJid, aliceJid,
aliceSession.device.id, (await aliceSession.getDevice()).id,
aliceMessage.encryptedKeys, aliceMessage.encryptedKeys,
); );
expect(messagePlaintext, bobMessage); expect(messagePlaintext, bobMessage);
// The event should be triggered
expect(deviceModified, true);
// Bob should have replaced his OPK
expect(
listsEqual(bobOpks, (await bobSession.getDevice()).opks.values.toList()),
false,
);
// Bob responds to Alice // Bob responds to Alice
const bobResponseText = 'Oh, hello Alice!'; const bobResponseText = 'Oh, hello Alice!';
@ -47,7 +61,7 @@ void main() {
final aliceReceivedMessage = await aliceSession.decryptMessage( final aliceReceivedMessage = await aliceSession.decryptMessage(
bobResponseMessage.ciphertext, bobResponseMessage.ciphertext,
bobJid, bobJid,
bobSession.device.id, (await bobSession.getDevice()).id,
bobResponseMessage.encryptedKeys, bobResponseMessage.encryptedKeys,
); );
expect(bobResponseText, aliceReceivedMessage); expect(bobResponseText, aliceReceivedMessage);
@ -69,8 +83,8 @@ void main() {
bobJid, bobJid,
messagePlaintext, messagePlaintext,
newSessions: [ newSessions: [
await bobSession.device.toBundle(), await (await bobSession.getDevice()).toBundle(),
await bobSession2.device.toBundle(), await (await bobSession2.getDevice()).toBundle(),
], ],
); );
expect(aliceMessage.encryptedKeys.length, 2); expect(aliceMessage.encryptedKeys.length, 2);
@ -84,7 +98,7 @@ void main() {
final bobMessage = await bobSession.decryptMessage( final bobMessage = await bobSession.decryptMessage(
aliceMessage.ciphertext, aliceMessage.ciphertext,
aliceJid, aliceJid,
aliceSession.device.id, (await aliceSession.getDevice()).id,
aliceMessage.encryptedKeys, aliceMessage.encryptedKeys,
); );
expect(messagePlaintext, bobMessage); expect(messagePlaintext, bobMessage);
@ -103,7 +117,7 @@ void main() {
final aliceReceivedMessage = await aliceSession.decryptMessage( final aliceReceivedMessage = await aliceSession.decryptMessage(
bobResponseMessage.ciphertext, bobResponseMessage.ciphertext,
bobJid, bobJid,
bobSession.device.id, (await bobSession.getDevice()).id,
bobResponseMessage.encryptedKeys, bobResponseMessage.encryptedKeys,
); );
expect(bobResponseText, aliceReceivedMessage); expect(bobResponseText, aliceReceivedMessage);