diff --git a/lib/src/double_ratchet/double_ratchet.dart b/lib/src/double_ratchet/double_ratchet.dart index b5978bf..0f5e0d2 100644 --- a/lib/src/double_ratchet/double_ratchet.dart +++ b/lib/src/double_ratchet/double_ratchet.dart @@ -69,6 +69,7 @@ class OmemoDoubleRatchet { this.mkSkipped, // MKSKIPPED this.acknowledged, this.kexTimestamp, + this.kex, ); factory OmemoDoubleRatchet.fromJson(Map data) { @@ -83,10 +84,11 @@ class OmemoDoubleRatchet { 'ns': 0, 'nr': 0, 'pn': 0, - 'ik_pub': 'base/64/encoded', + 'ik_pub': null | 'base/64/encoded', 'session_ad': 'base/64/encoded', 'acknowledged': true | false, 'kex_timestamp': int, + 'kex': 'base/64/encoded', 'mkskipped': [ { 'key': 'base/64/encoded', @@ -132,6 +134,7 @@ class OmemoDoubleRatchet { mkSkipped, data['acknowledged']! as bool, data['kex_timestamp']! as int, + data['kex'] as String?, ); } @@ -167,6 +170,9 @@ class OmemoDoubleRatchet { /// Precision is milliseconds since epoch. int kexTimestamp; + /// The key exchange that was used for initiating the session. + final String? kex; + /// Indicates whether we received an empty OMEMO message after building a session with /// the device. bool acknowledged; @@ -194,6 +200,7 @@ class OmemoDoubleRatchet { {}, false, timestamp, + '', ); } @@ -216,6 +223,7 @@ class OmemoDoubleRatchet { {}, false, kexTimestamp, + null, ); } @@ -243,6 +251,7 @@ class OmemoDoubleRatchet { 'mkskipped': mkSkippedSerialised, 'acknowledged': acknowledged, 'kex_timestamp': kexTimestamp, + 'kex': kex, }; } @@ -359,6 +368,30 @@ class OmemoDoubleRatchet { Map>.from(mkSkipped), acknowledged, kexTimestamp, + kex, + ); + } + + OmemoDoubleRatchet cloneWithKex(String kex) { + return OmemoDoubleRatchet( + dhs, + dhr, + rk, + cks != null ? + List.from(cks!) : + null, + ckr != null ? + List.from(ckr!) : + null, + ns, + nr, + pn, + ik, + sessionAd, + Map>.from(mkSkipped), + acknowledged, + kexTimestamp, + kex, ); } diff --git a/lib/src/omemo/sessionmanager.dart b/lib/src/omemo/sessionmanager.dart index 84bad34..5ddaf29 100644 --- a/lib/src/omemo/sessionmanager.dart +++ b/lib/src/omemo/sessionmanager.dart @@ -170,7 +170,7 @@ class OmemoSessionManager { /// Build a new session with the user at [jid] with the device [deviceId] using data /// from the key exchange [kex]. In case [kex] contains an unknown Signed Prekey /// identifier an UnknownSignedPrekeyException will be thrown. - Future _addSessionFromKeyExchange(String jid, int deviceId, OmemoKeyExchange kex) async { + Future _addSessionFromKeyExchange(String jid, int deviceId, OmemoKeyExchange kex) async { // Pick the correct SPK final device = await getDevice(); final spk = await _lock.synchronized(() async { @@ -204,8 +204,7 @@ class OmemoSessionManager { getTimestamp(), ); - await _trustManager.onNewSession(jid, deviceId); - await _addSession(jid, deviceId, ratchet); + return ratchet; } /// Like [encryptToJids] but only for one Jid [jid]. @@ -264,24 +263,53 @@ class OmemoSessionManager { } final ratchetKey = RatchetMapKey(jid, deviceId); - final ratchet = _ratchetMap[ratchetKey]!; + var ratchet = _ratchetMap[ratchetKey]!; final ciphertext = (await ratchet.ratchetEncrypt(keyPayload)).ciphertext; - - // Commit the ratchet - _eventStreamController.add(RatchetModifiedEvent(jid, deviceId, ratchet)); - + if (kex.isNotEmpty && kex.containsKey(deviceId)) { + // The ratchet did not exist final k = kex[deviceId]! ..message = OmemoAuthenticatedMessage.fromBuffer(ciphertext); + final buffer = base64.encode(k.writeToBuffer()); encryptedKeys.add( EncryptedKey( jid, deviceId, - base64.encode(k.writeToBuffer()), + buffer, true, ), ); + + ratchet = ratchet.cloneWithKex(buffer); + _ratchetMap[ratchetKey] = ratchet; + } else if (!ratchet.acknowledged) { + // The ratchet exists but is not acked + if (ratchet.kex != null) { + final oldKex = OmemoKeyExchange.fromBuffer(base64.decode(ratchet.kex!)) + ..message = OmemoAuthenticatedMessage.fromBuffer(ciphertext); + + encryptedKeys.add( + EncryptedKey( + jid, + deviceId, + base64.encode(oldKex.writeToBuffer()), + true, + ), + ); + } else { + // The ratchet is not acked but we don't have the old key exchange + _log.warning('Ratchet for $jid:$deviceId is not acked but the kex attribute is null'); + encryptedKeys.add( + EncryptedKey( + jid, + deviceId, + base64.encode(ciphertext), + false, + ), + ); + } } else { + // The ratchet exists and is acked encryptedKeys.add( EncryptedKey( jid, @@ -291,6 +319,9 @@ class OmemoSessionManager { ), ); } + + // Commit the ratchet + _eventStreamController.add(RatchetModifiedEvent(jid, deviceId, ratchet)); } } }); @@ -319,6 +350,25 @@ class OmemoSessionManager { ); }); } + + Future _decryptAndVerifyHmac(List? ciphertext, List keyAndHmac) async { + // Empty OMEMO messages should just have the key decrypted and/or session set up. + if (ciphertext == null) { + return null; + } + + final key = keyAndHmac.sublist(0, 32); + final hmac = keyAndHmac.sublist(32, 48); + final derivedKeys = await deriveEncryptionKeys(key, omemoPayloadInfoString); + final computedHmac = await truncatedHmac(ciphertext, derivedKeys.authenticationKey); + if (!listsEqual(hmac, computedHmac)) { + throw InvalidMessageHMACException(); + } + + return utf8.decode( + await aes256CbcDecrypt(ciphertext, derivedKeys.encryptionKey, derivedKeys.iv), + ); + } /// Attempt to decrypt [ciphertext]. [keys] refers to the elements inside the /// element with a "jid" attribute matching our own. [senderJid] refers to the @@ -342,6 +392,7 @@ class OmemoSessionManager { final ratchetKey = RatchetMapKey(senderJid, senderDeviceId); final decodedRawKey = base64.decode(rawKey.value); + List? keyAndHmac; OmemoAuthenticatedMessage authMessage; OmemoDoubleRatchet? oldRatchet; OmemoMessage? message; @@ -359,14 +410,34 @@ class OmemoSessionManager { if (oldRatchet.kexTimestamp > timestamp) { throw InvalidKeyExchangeException(); } + + // Try to decrypt it + try { + final decrypted = await oldRatchet.ratchetDecrypt(message, authMessage.writeToBuffer()); + + // Commit the ratchet + _eventStreamController.add( + RatchetModifiedEvent( + senderJid, + senderDeviceId, + oldRatchet, + ), + ); + + final plaintext = await _decryptAndVerifyHmac( + ciphertext, + decrypted, + ); + await _addSession(senderJid, senderDeviceId, oldRatchet); + return plaintext; + } catch (_) { + _log.finest('Failed to use old ratchet with KEX for existing ratchet'); + } } - - // TODO(PapaTutuWawa): Only do this when we should - await _addSessionFromKeyExchange( - senderJid, - senderDeviceId, - kex, - ); + + final r = await _addSessionFromKeyExchange(senderJid, senderDeviceId, kex); + await _trustManager.onNewSession(senderJid, senderDeviceId); + await _addSession(senderJid, senderDeviceId, r); // Replace the OPK // TODO(PapaTutuWawa): Replace the OPK when we know that the KEX worked @@ -389,7 +460,6 @@ class OmemoSessionManager { throw NoDecryptionKeyException(); } - List? keyAndHmac; // We can guarantee that the ratchet exists at this point in time final ratchet = (await _getRatchet(ratchetKey))!; oldRatchet ??= ratchet.clone(); @@ -408,24 +478,12 @@ class OmemoSessionManager { // Commit the ratchet _eventStreamController.add(RatchetModifiedEvent(senderJid, senderDeviceId, ratchet)); - // Empty OMEMO messages should just have the key decrypted and/or session set up. - if (ciphertext == null) { - return null; - } - - final key = keyAndHmac.sublist(0, 32); - final hmac = keyAndHmac.sublist(32, 48); - final derivedKeys = await deriveEncryptionKeys(key, omemoPayloadInfoString); - - final computedHmac = await truncatedHmac(ciphertext, derivedKeys.authenticationKey); - if (!listsEqual(hmac, computedHmac)) { - // TODO(PapaTutuWawa): I am unsure if we should restore the ratchet here + try { + return _decryptAndVerifyHmac(ciphertext, keyAndHmac); + } catch (_) { await _restoreRatchet(ratchetKey, oldRatchet); - throw InvalidMessageHMACException(); + rethrow; } - - final plaintext = await aes256CbcDecrypt(ciphertext, derivedKeys.encryptionKey, derivedKeys.iv); - return utf8.decode(plaintext); } /// Returns the list of hex-encoded fingerprints we have for sessions with [jid]. diff --git a/test/omemo_test.dart b/test/omemo_test.dart index c4f98ca..bc2b400 100644 --- a/test/omemo_test.dart +++ b/test/omemo_test.dart @@ -80,7 +80,7 @@ void main() { ], ); expect(aliceMessage.encryptedKeys.length, 1); - + // Alice sends the message to Bob // ... @@ -106,6 +106,10 @@ void main() { false, ); + // Ratchets are acked + await aliceSession.ratchetAcknowledged(bobJid, await bobSession.getDeviceId()); + await bobSession.ratchetAcknowledged(aliceJid, await aliceSession.getDeviceId()); + // Bob responds to Alice const bobResponseText = 'Oh, hello Alice!'; final bobResponseMessage = await bobSession.encryptToJid( @@ -176,6 +180,10 @@ void main() { ); expect(messagePlaintext, bobMessage); + // Ratchets are acked + await aliceSession.ratchetAcknowledged(bobJid, await bobSession.getDeviceId()); + await bobSession.ratchetAcknowledged(aliceJid, await aliceSession.getDeviceId()); + // Bob responds to Alice const bobResponseText = 'Oh, hello Alice!'; final bobResponseMessage = await bobSession.encryptToJid( @@ -448,6 +456,10 @@ void main() { 0, ); + // Ratchets are acked + await aliceSession.ratchetAcknowledged(bobJid, await bobSession.getDeviceId()); + await bobSession.ratchetAcknowledged(aliceJid, await aliceSession.getDeviceId()); + for (var i = 0; i < 100; i++) { final messageText = 'Test Message #$i'; // Bob responds to Alice @@ -665,6 +677,58 @@ void main() { expect(await bobRatchet1.equals(bobRatchet2), false); }); + test('Test resending key exchanges', () async { + const aliceJid = 'alice@server.example'; + const bobJid = 'bob@other.server.example'; + // Alice and Bob generate their sessions + final aliceSession = await OmemoSessionManager.generateNewIdentity( + aliceJid, + AlwaysTrustingTrustManager(), + opkAmount: 1, + ); + final bobSession = await OmemoSessionManager.generateNewIdentity( + bobJid, + AlwaysTrustingTrustManager(), + opkAmount: 2, + ); + + // Alice sends Bob a message + final msg1 = await aliceSession.encryptToJid( + bobJid, + 'Hallo Welt', + newSessions: [ + await bobSession.getDeviceBundle(), + ], + ); + // The first message should be a kex message + expect(msg1.encryptedKeys.first.kex, true); + + await bobSession.decryptMessage( + msg1.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg1.encryptedKeys, + 0, + ); + + // Alice is impatient and immediately sends another message before the original one + // can be acknowledged by Bob + final msg2 = await aliceSession.encryptToJid( + bobJid, + "Why don't you answer?", + ); + expect(msg2.encryptedKeys.first.kex, true); + + await bobSession.decryptMessage( + msg2.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg2.encryptedKeys, + getTimestamp(), + ); + + }); + test('Test receiving old messages including a KEX', () async { const aliceJid = 'alice@server.example'; const bobJid = 'bob@other.server.example'; @@ -703,6 +767,10 @@ void main() { t1, ); + // Ratchets are acked + await aliceSession.ratchetAcknowledged(bobJid, await bobSession.getDeviceId()); + await bobSession.ratchetAcknowledged(aliceJid, await aliceSession.getDeviceId()); + // Bob responds final msg2 = await bobSession.encryptToJid( aliceJid, @@ -785,4 +853,94 @@ void main() { expect(result, 'Are you okay?'); }); + + test("Test ignoring a new KEX when we haven't acket it yet", () async { + const aliceJid = 'alice@server.example'; + const bobJid = 'bob@other.server.example'; + // Alice and Bob generate their sessions + final aliceSession = await OmemoSessionManager.generateNewIdentity( + aliceJid, + AlwaysTrustingTrustManager(), + opkAmount: 1, + ); + final bobSession = await OmemoSessionManager.generateNewIdentity( + bobJid, + AlwaysTrustingTrustManager(), + opkAmount: 1, + ); + + // Alice sends Bob a message + final msg1 = await aliceSession.encryptToJid( + bobJid, + 'Hallo Welt', + newSessions: [ + await bobSession.getDeviceBundle(), + ], + ); + expect(msg1.encryptedKeys.first.kex, true); + + await bobSession.decryptMessage( + msg1.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg1.encryptedKeys, + getTimestamp(), + ); + + // Alice sends another message before the ack can reach us + final msg2 = await aliceSession.encryptToJid( + bobJid, + 'ANSWER ME!', + ); + expect(msg2.encryptedKeys.first.kex, true); + + await bobSession.decryptMessage( + msg2.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg2.encryptedKeys, + getTimestamp(), + ); + + // Now the acks reach us + await aliceSession.ratchetAcknowledged(bobJid, await bobSession.getDeviceId()); + await bobSession.ratchetAcknowledged(aliceJid, await aliceSession.getDeviceId()); + + // Alice sends another message + final msg3 = await aliceSession.encryptToJid( + bobJid, + "You read the message, didn't you?", + ); + expect(msg3.encryptedKeys.first.kex, false); + + await bobSession.decryptMessage( + msg3.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg3.encryptedKeys, + getTimestamp(), + ); + + for (var i = 0; i < 100; i++) { + final messageText = 'Test Message #$i'; + // Bob responds to Alice + final bobResponseMessage = await bobSession.encryptToJid( + aliceJid, + messageText, + ); + + // Bob sends the message to Alice + // ... + + // Alice decrypts it + final aliceReceivedMessage = await aliceSession.decryptMessage( + bobResponseMessage.ciphertext, + bobJid, + await bobSession.getDeviceId(), + bobResponseMessage.encryptedKeys, + 0, + ); + expect(messageText, aliceReceivedMessage); + } + }); }