diff --git a/lib/src/omemo/sessionmanager.dart b/lib/src/omemo/sessionmanager.dart index d6ca264..15ae4dc 100644 --- a/lib/src/omemo/sessionmanager.dart +++ b/lib/src/omemo/sessionmanager.dart @@ -20,6 +20,7 @@ import 'package:synchronized/synchronized.dart'; /// The info used for when encrypting the AES key for the actual payload. const omemoPayloadInfoString = 'OMEMO Payload'; +@immutable class EncryptionResult { const EncryptionResult(this.ciphertext, this.encryptedKeys); @@ -32,6 +33,7 @@ class EncryptionResult { final List encryptedKeys; } +@immutable class EncryptedKey { const EncryptedKey(this.rid, this.value, this.kex); @@ -40,6 +42,22 @@ class EncryptedKey { final bool kex; } +@immutable +class RatchetMapKey { + + const RatchetMapKey(this.jid, this.deviceId); + final String jid; + final int deviceId; + + @override + bool operator ==(Object other) { + return other is RatchetMapKey && jid == other.jid && deviceId == other.deviceId; + } + + @override + int get hashCode => jid.hashCode ^ deviceId.hashCode; +} + class OmemoSessionManager { OmemoSessionManager(this._device) @@ -61,8 +79,7 @@ class OmemoSessionManager { final Lock _lock; /// Mapping of the Device Id to its OMEMO session - // TODO(PapaTutuWawa): Make this map use a tuple (Jid, Id) as a key - final Map _ratchetMap; + final Map _ratchetMap; /// Mapping of a bare Jid to its Device Ids final Map> _deviceMap; @@ -99,8 +116,9 @@ class OmemoSessionManager { } // Add the ratchet session - if (!_ratchetMap.containsKey(deviceId)) { - _ratchetMap[deviceId] = ratchet; + final key = RatchetMapKey(jid, deviceId); + if (!_ratchetMap.containsKey(key)) { + _ratchetMap[key] = ratchet; } else { // TODO(PapaTutuWawa): What do we do now? throw Exception(); @@ -180,7 +198,8 @@ class OmemoSessionManager { await _lock.synchronized(() async { // We assume that the user already checked if the session exists for (final deviceId in _deviceMap[jid]!) { - final ratchet = _ratchetMap[deviceId]!; + final ratchetKey = RatchetMapKey(jid, deviceId); + final ratchet = _ratchetMap[ratchetKey]!; final ciphertext = (await ratchet.ratchetEncrypt(concatKey)).ciphertext; if (kex.isNotEmpty && kex.containsKey(deviceId)) { @@ -256,8 +275,9 @@ class OmemoSessionManager { } final message = OmemoMessage.fromBuffer(authMessage.message!); - - final ratchet = _ratchetMap[senderDeviceId]!; + + final ratchetKey = RatchetMapKey(senderJid, senderDeviceId); + final ratchet = _ratchetMap[ratchetKey]!; List keyAndHmac; if (rawKey.kex) { keyAndHmac = await ratchet.ratchetDecrypt(message, authMessage.writeToBuffer()); @@ -278,5 +298,5 @@ class OmemoSessionManager { } @visibleForTesting - OmemoDoubleRatchet getRatchet(int deviceId) => _ratchetMap[deviceId]!; + OmemoDoubleRatchet getRatchet(String jid, int deviceId) => _ratchetMap[RatchetMapKey(jid, deviceId)]!; } diff --git a/test/serialisation_test.dart b/test/serialisation_test.dart index 986aafc..503b1b3 100644 --- a/test/serialisation_test.dart +++ b/test/serialisation_test.dart @@ -41,7 +41,7 @@ void main() { (await aliceSession.getDevice()).id, aliceMessage.encryptedKeys, ); - final aliceOld = aliceSession.getRatchet((await bobSession.getDevice()).id); + final aliceOld = aliceSession.getRatchet(bobJid, (await bobSession.getDevice()).id); final aliceSerialised = await aliceOld.toJson(); final aliceNew = OmemoDoubleRatchet.fromJson(aliceSerialised);