diff --git a/lib/src/omemo/sessionmanager.dart b/lib/src/omemo/sessionmanager.dart index 5d7bbc7..8953ace 100644 --- a/lib/src/omemo/sessionmanager.dart +++ b/lib/src/omemo/sessionmanager.dart @@ -297,6 +297,16 @@ class OmemoSessionManager { ); } + /// In case a decryption error occurs, the Double Ratchet spec says to just restore + /// the ratchet to its old state. As such, this function restores the ratchet at + /// [mapKey] with [oldRatchet]. + Future _restoreRatchet(RatchetMapKey mapKey, OmemoDoubleRatchet oldRatchet) async { + await _lock.synchronized(() { + print('RESTORING RATCHETS'); + _ratchetMap[mapKey] = oldRatchet; + }); + } + /// Attempt to decrypt [ciphertext]. [keys] refers to the elements inside the /// element with a "jid" attribute matching our own. [senderJid] refers to the /// bare Jid of the sender. [senderDeviceId] refers to the "sid" attribute of the @@ -314,9 +324,15 @@ class OmemoSessionManager { throw NotEncryptedForDeviceException(); } + final ratchetKey = RatchetMapKey(senderJid, senderDeviceId); final decodedRawKey = base64.decode(rawKey.value); OmemoAuthenticatedMessage authMessage; + OmemoDoubleRatchet? oldRatchet; if (rawKey.kex) { + // If the ratchet already existed, we store it. If it didn't, oldRatchet will stay + // null. + oldRatchet = await _getRatchet(ratchetKey); + // TODO(PapaTutuWawa): Only do this when we should final kex = OmemoKeyExchange.fromBuffer(decodedRawKey); await _addSessionFromKeyExchange( @@ -347,31 +363,38 @@ class OmemoSessionManager { } final message = OmemoMessage.fromBuffer(authMessage.message!); - final ratchetKey = RatchetMapKey(senderJid, senderDeviceId); List? keyAndHmac; - await _lock.synchronized(() async { - final ratchet = _ratchetMap[ratchetKey]!; + // We can guarantee that the ratchet exists at this point in time + final ratchet = (await _getRatchet(ratchetKey))!; + oldRatchet ??= ratchet ; + + try { if (rawKey.kex) { keyAndHmac = await ratchet.ratchetDecrypt(message, authMessage.writeToBuffer()); } else { keyAndHmac = await ratchet.ratchetDecrypt(message, decodedRawKey); } + } on InvalidMessageHMACException { + await _restoreRatchet(ratchetKey, oldRatchet); + rethrow; + } - // Commit the ratchet - _eventStreamController.add(RatchetModifiedEvent(senderJid, senderDeviceId, ratchet)); - }); + // Commit the ratchet + _eventStreamController.add(RatchetModifiedEvent(senderJid, senderDeviceId, ratchet)); // Empty OMEMO messages should just have the key decrypted and/or session set up. if (ciphertext == null) { return null; } - final key = keyAndHmac!.sublist(0, 32); - final hmac = keyAndHmac!.sublist(32, 48); + final key = keyAndHmac.sublist(0, 32); + final hmac = keyAndHmac.sublist(32, 48); final derivedKeys = await deriveEncryptionKeys(key, omemoPayloadInfoString); final computedHmac = await truncatedHmac(ciphertext, derivedKeys.authenticationKey); if (!listsEqual(hmac, computedHmac)) { + // TODO(PapaTutuWawa): I am unsure if we should restore the ratchet here + await _restoreRatchet(ratchetKey, oldRatchet); throw InvalidMessageHMACException(); } @@ -514,6 +537,12 @@ class OmemoSessionManager { _eventStreamController.add(DeviceModifiedEvent(_device)); }); } + + Future _getRatchet(RatchetMapKey key) async { + return _lock.synchronized(() async { + return _ratchetMap[key]; + }); + } @visibleForTesting OmemoDoubleRatchet getRatchet(String jid, int deviceId) => _ratchetMap[RatchetMapKey(jid, deviceId)]!; diff --git a/test/omemo_test.dart b/test/omemo_test.dart index c755b4b..457473a 100644 --- a/test/omemo_test.dart +++ b/test/omemo_test.dart @@ -614,4 +614,144 @@ void main() { expect(await aliceRatchet1.equals(aliceRatchet2), false); expect(await bobRatchet1.equals(bobRatchet2), false); }); + + test('Test receiving an old message that contains a KEX', () async { + const aliceJid = 'alice@server.example'; + const bobJid = 'bob@other.server.example'; + // Alice and Bob generate their sessions + final aliceSession = await OmemoSessionManager.generateNewIdentity( + aliceJid, + AlwaysTrustingTrustManager(), + opkAmount: 1, + ); + final bobSession = await OmemoSessionManager.generateNewIdentity( + bobJid, + AlwaysTrustingTrustManager(), + opkAmount: 2, + ); + + // Alice sends Bob a message + final msg1 = await aliceSession.encryptToJid( + bobJid, + 'Hallo Welt', + newSessions: [ + await bobSession.getDeviceBundle(), + ], + ); + await bobSession.decryptMessage( + msg1.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg1.encryptedKeys, + ); + + // Bob responds + final msg2 = await bobSession.encryptToJid( + aliceJid, + 'Hello!', + ); + await aliceSession.decryptMessage( + msg2.ciphertext, + bobJid, + await bobSession.getDeviceId(), + msg2.encryptedKeys, + ); + + // Due to some issue with the transport protocol, the first message Bob received is + // received again + try { + await bobSession.decryptMessage( + msg1.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg1.encryptedKeys, + ); + expect(true, false); + } on InvalidMessageHMACException { + // NOOP + } + + final msg3 = await aliceSession.encryptToJid( + bobJid, + 'Are you okay?', + ); + final result = await bobSession.decryptMessage( + msg3.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg3.encryptedKeys, + ); + + expect(result, 'Are you okay?'); + }); + + test('Test receiving an old message that does not contain a KEX', () async { + const aliceJid = 'alice@server.example'; + const bobJid = 'bob@other.server.example'; + // Alice and Bob generate their sessions + final aliceSession = await OmemoSessionManager.generateNewIdentity( + aliceJid, + AlwaysTrustingTrustManager(), + opkAmount: 1, + ); + final bobSession = await OmemoSessionManager.generateNewIdentity( + bobJid, + AlwaysTrustingTrustManager(), + opkAmount: 2, + ); + + // Alice sends Bob a message + final msg1 = await aliceSession.encryptToJid( + bobJid, + 'Hallo Welt', + newSessions: [ + await bobSession.getDeviceBundle(), + ], + ); + await bobSession.decryptMessage( + msg1.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg1.encryptedKeys, + ); + + // Bob responds + final msg2 = await bobSession.encryptToJid( + aliceJid, + 'Hello!', + ); + await aliceSession.decryptMessage( + msg2.ciphertext, + bobJid, + await bobSession.getDeviceId(), + msg2.encryptedKeys, + ); + + // Due to some issue with the transport protocol, the first message Alice received is + // received again. + try { + await aliceSession.decryptMessage( + msg2.ciphertext, + bobJid, + await bobSession.getDeviceId(), + msg2.encryptedKeys, + ); + expect(true, false); + } catch (_) { + // NOOP + } + + final msg3 = await aliceSession.encryptToJid( + bobJid, + 'Are you okay?', + ); + final result = await bobSession.decryptMessage( + msg3.ciphertext, + aliceJid, + await aliceSession.getDeviceId(), + msg3.encryptedKeys, + ); + + expect(result, 'Are you okay?'); + }); }