diff --git a/lib/src/omemo/sessionmanager.dart b/lib/src/omemo/sessionmanager.dart index 3f4c7e3..2eca1c4 100644 --- a/lib/src/omemo/sessionmanager.dart +++ b/lib/src/omemo/sessionmanager.dart @@ -26,14 +26,15 @@ class EncryptionResult { /// Mapping of the device Id to the key for decrypting ciphertext, encrypted /// for the ratchet with said device Id - final Map> encryptedKeys; + final List encryptedKeys; } class EncryptedKey { - const EncryptedKey(this.rid, this.value); + const EncryptedKey(this.rid, this.value, this.kex); final int rid; final String value; + final bool kex; } class OmemoSessionManager { @@ -127,8 +128,8 @@ class OmemoSessionManager { /// Encrypt the key [plaintext] for all known bundles of [jid]. Returns a map that /// maps the Bundle Id to the ciphertext of [plaintext]. - Future encryptToJid(String jid, String plaintext) async { - final encryptedKeys = >{}; + Future encryptToJid(String jid, String plaintext, { OmemoBundle? newSession }) async { + final encryptedKeys = List.empty(growable: true); // Generate the key and encrypt the plaintext final key = generateRandomBytes(32); @@ -141,11 +142,35 @@ class OmemoSessionManager { final hmac = await truncatedHmac(ciphertext, keys.authenticationKey); final concatKey = concat([key, hmac]); + OmemoKeyExchange? kex; + if (newSession != null) { + kex = await addSessionFromBundle(jid, newSession.id, newSession); + } + await _lock.synchronized(() async { // We assume that the user already checked if the session exists for (final deviceId in _deviceMap[jid]!) { final ratchet = _ratchetMap[deviceId]!; - encryptedKeys[deviceId] = (await ratchet.ratchetEncrypt(concatKey)).ciphertext; + final ciphertext = (await ratchet.ratchetEncrypt(concatKey)).ciphertext; + + if (kex != null && deviceId == newSession?.id) { + kex.message = OmemoAuthenticatedMessage.fromBuffer(ciphertext); + encryptedKeys.add( + EncryptedKey( + deviceId, + base64.encode(kex.writeToBuffer()), + true, + ), + ); + } else { + encryptedKeys.add( + EncryptedKey( + deviceId, + base64.encode(ciphertext), + false, + ), + ); + } } }); @@ -166,6 +191,22 @@ class OmemoSessionManager { throw NotEncryptedForDeviceException(); } + final decodedRawKey = base64.decode(rawKey.value); + OmemoAuthenticatedMessage authMessage; + if (rawKey.kex) { + // TODO(PapaTutuWawa): Only do this when we should + final kex = OmemoKeyExchange.fromBuffer(decodedRawKey); + await addSessionFromKeyExchange( + senderJid, + senderDeviceId, + kex, + ); + + authMessage = kex.message!; + } else { + authMessage = OmemoAuthenticatedMessage.fromBuffer(decodedRawKey); + } + final devices = _deviceMap[senderJid]; if (devices == null) { throw NoDecryptionKeyException(); @@ -173,13 +214,16 @@ class OmemoSessionManager { if (!devices.contains(senderDeviceId)) { throw NoDecryptionKeyException(); } - - final decodedRawKey = base64.decode(rawKey.value); - final authMessage = OmemoAuthenticatedMessage.fromBuffer(decodedRawKey); + final message = OmemoMessage.fromBuffer(authMessage.message!); final ratchet = _ratchetMap[senderDeviceId]!; - final keyAndHmac = await ratchet.ratchetDecrypt(message, decodedRawKey); + List keyAndHmac; + if (rawKey.kex) { + keyAndHmac = await ratchet.ratchetDecrypt(message, authMessage.writeToBuffer()); + } else { + keyAndHmac = await ratchet.ratchetDecrypt(message, decodedRawKey); + } final key = keyAndHmac.sublist(0, 32); final hmac = keyAndHmac.sublist(32, 48); final derivedKeys = await deriveEncryptionKeys(key, omemoPayloadInfoString); diff --git a/test/omemo_test.dart b/test/omemo_test.dart index 14013f1..537e62f 100644 --- a/test/omemo_test.dart +++ b/test/omemo_test.dart @@ -1,4 +1,3 @@ -import 'dart:convert'; import 'package:omemo_dart/omemo_dart.dart'; import 'package:test/test.dart'; @@ -11,21 +10,13 @@ void main() { final aliceSession = await OmemoSessionManager.generateNewIdentity(opkAmount: 1); final bobSession = await OmemoSessionManager.generateNewIdentity(opkAmount: 1); - // Perform the X3DH - final kex = await aliceSession.addSessionFromBundle( - bobJid, - bobSession.device.id, - await bobSession.device.toBundle(), - ); - await bobSession.addSessionFromKeyExchange( - aliceJid, - aliceSession.device.id, - kex, - ); - // Alice encrypts a message for Bob const messagePlaintext = 'Hello Bob!'; - final aliceMessage = await aliceSession.encryptToJid(bobJid, messagePlaintext); + final aliceMessage = await aliceSession.encryptToJid( + bobJid, + messagePlaintext, + newSession: await bobSession.device.toBundle(), + ); expect(aliceMessage.encryptedKeys.length, 1); // Alice sends the message to Bob @@ -36,12 +27,7 @@ void main() { aliceMessage.ciphertext, aliceJid, aliceSession.device.id, - [ - EncryptedKey( - bobSession.device.id, - base64.encode(aliceMessage.encryptedKeys[bobSession.device.id]!), - ), - ], + aliceMessage.encryptedKeys, ); expect(messagePlaintext, bobMessage);