#include "session_cipher.h"

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "session_builder.h"
#include "session_builder_internal.h"
#include "session_record.h"
#include "session_state.h"
#include "ratchet.h"
#include "protocol.h"
#include "signal_protocol_internal.h"

/*
 * Encrypts and decrypts messages using the session record stored for a
 * single remote address.
 */
struct session_cipher {
    /* Store used to load and save session records (not owned by the cipher). */
    signal_protocol_store_context *store;
    /* Address whose session record is used; the pointer is kept, not copied. */
    const signal_protocol_address *remote_address;
    /* Builder used to process incoming pre-key messages (owned; freed in session_cipher_free). */
    session_builder *builder;
    /* Shared context used for locking, logging and crypto calls (not owned). */
    signal_context *global_context;
    /* Optional hook invoked with each decrypted plaintext before the session is stored. */
    int (*decrypt_callback)(session_cipher *cipher, signal_buffer *plaintext, void *decrypt_context);
    /* Set to 1 while decrypt_callback runs; encrypt/decrypt refuse to run while it is set. */
    int inside_callback;
    /* Opaque caller data; never interpreted by this module. */
    void *user_data;
};

static int session_cipher_decrypt_from_record_and_signal_message(session_cipher *cipher,
        session_record *record, signal_message *ciphertext, signal_buffer **plaintext);
static int session_cipher_decrypt_from_state_and_signal_message(session_cipher *cipher,
        session_state *state, signal_message *ciphertext, signal_buffer **plaintext);
static int session_cipher_get_or_create_chain_key(session_cipher *cipher,
        ratchet_chain_key **chain_key,
        session_state *state, ec_public_key *their_ephemeral);
static int session_cipher_get_or_create_message_keys(ratchet_message_keys *message_keys,
        session_state *state, ec_public_key *their_ephemeral,
        ratchet_chain_key *chain_key, uint32_t counter, signal_context *global_context);
static int session_cipher_get_ciphertext(session_cipher *cipher,
        signal_buffer **ciphertext,
        uint32_t version, ratchet_message_keys *message_keys,
        const uint8_t *plaintext, size_t plaintext_len);
static int session_cipher_get_plaintext(session_cipher *cipher,
        signal_buffer **plaintext,
        uint32_t version, ratchet_message_keys *message_keys,
        const uint8_t *ciphertext, size_t ciphertext_len);
static int session_cipher_decrypt_callback(session_cipher *cipher,
        signal_buffer *plaintext, void *decrypt_context);

/*
 * Create a session cipher for remote_address.
 *
 * store, remote_address and global_context are stored as pointers (not
 * copied), so they must outlive the returned cipher. A session_builder is
 * created internally and owned by the cipher.
 *
 * Returns 0 on success with *cipher set, the negative error from
 * session_builder_create, or SG_ERR_NOMEM if the cipher cannot be allocated.
 */
int session_cipher_create(session_cipher **cipher,
        signal_protocol_store_context *store, const signal_protocol_address *remote_address,
        signal_context *global_context)
{
    int result = 0;
    session_builder *builder = 0;
    session_cipher *result_cipher = 0;

    assert(store);
    assert(global_context);

    result = session_builder_create(&builder, store, remote_address, global_context);
    if(result < 0) {
        return result;
    }

    /* calloc zero-fills: no callback, not inside a callback, no user data. */
    result_cipher = calloc(1, sizeof *result_cipher);
    if(!result_cipher) {
        /* Release the builder created above; previously it leaked here. */
        session_builder_free(builder);
        return SG_ERR_NOMEM;
    }

    result_cipher->store = store;
    result_cipher->remote_address = remote_address;
    result_cipher->builder = builder;
    result_cipher->global_context = global_context;

    *cipher = result_cipher;
    return 0;
}

/* Attach an opaque caller pointer to the cipher. */
void session_cipher_set_user_data(session_cipher *cipher, void *user_data)
{
    assert(cipher);
    cipher->user_data = user_data;
}

/* Return the pointer previously set with session_cipher_set_user_data (or 0). */
void *session_cipher_get_user_data(session_cipher *cipher)
{
    assert(cipher);
    return cipher->user_data;
}

/*
 * Install (or clear, with 0) the hook invoked with each decrypted plaintext.
 * A negative return from the hook aborts the decrypt before the session is stored.
 */
void session_cipher_set_decryption_callback(session_cipher *cipher,
        int (*callback)(session_cipher *cipher, signal_buffer *plaintext, void *decrypt_context))
{
    assert(cipher);
    cipher->decrypt_callback = callback;
}

/*
 * Encrypt padded_message with the current session for the remote address.
 *
 * Derives message keys from the current sender chain key, builds a
 * signal_message, and wraps it in a pre_key_signal_message while the state
 * still has an unacknowledged pre-key message. Then advances the sender
 * chain key and stores the updated session record.
 *
 * On success *encrypted_message receives the new message (the caller owns
 * the reference). Returns a negative error on failure, including
 * SG_ERR_INVAL when called from inside a decryption callback.
 */
int session_cipher_encrypt(session_cipher *cipher,
        const uint8_t *padded_message, size_t padded_message_len,
        ciphertext_message **encrypted_message)
{
    int result = 0;
    session_record *record = 0;
    session_state *state = 0;
    ratchet_chain_key *chain_key = 0;
    ratchet_chain_key *next_chain_key = 0;
    ratchet_message_keys message_keys;
    ec_public_key *sender_ephemeral = 0;
    uint32_t previous_counter = 0;
    uint32_t session_version = 0;
    signal_buffer *ciphertext = 0;
    uint32_t chain_key_index = 0;
    ec_public_key *local_identity_key = 0;
    ec_public_key *remote_identity_key = 0;
    signal_message *message = 0;
    pre_key_signal_message *pre_key_message = 0;
    uint8_t *ciphertext_data = 0;
    size_t ciphertext_len = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_session_load_session(cipher->store, &record, cipher->remote_address);
    if(result < 0) {
        goto complete;
    }

    state = session_record_get_state(record);
    if(!state) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    /* Borrowed from state; not released below. */
    chain_key = session_state_get_sender_chain_key(state);
    if(!chain_key) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = ratchet_chain_key_get_message_keys(chain_key, &message_keys);
    if(result < 0) {
        goto complete;
    }

    sender_ephemeral = session_state_get_sender_ratchet_key(state);
    if(!sender_ephemeral) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    previous_counter = session_state_get_previous_counter(state);
    session_version = session_state_get_session_version(state);

    result = session_cipher_get_ciphertext(cipher, &ciphertext, session_version, &message_keys,
            padded_message, padded_message_len);
    if(result < 0) {
        goto complete;
    }
    ciphertext_data = signal_buffer_data(ciphertext);
    ciphertext_len = signal_buffer_len(ciphertext);

    chain_key_index = ratchet_chain_key_get_index(chain_key);

    local_identity_key = session_state_get_local_identity_key(state);
    if(!local_identity_key) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    remote_identity_key = session_state_get_remote_identity_key(state);
    if(!remote_identity_key) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = signal_message_create(&message,
            session_version,
            message_keys.mac_key, sizeof(message_keys.mac_key),
            sender_ephemeral,
            chain_key_index, previous_counter,
            ciphertext_data, ciphertext_len,
            local_identity_key, remote_identity_key,
            cipher->global_context);
    if(result < 0) {
        goto complete;
    }

    /* Until the peer responds, keep sending the pre-key bundle with each message. */
    if(session_state_has_unacknowledged_pre_key_message(state) == 1) {
        uint32_t local_registration_id = session_state_get_local_registration_id(state);
        int has_pre_key_id = 0;
        uint32_t pre_key_id = 0;
        uint32_t signed_pre_key_id;
        ec_public_key *base_key;

        if(session_state_unacknowledged_pre_key_message_has_pre_key_id(state)) {
            has_pre_key_id = 1;
            pre_key_id = session_state_unacknowledged_pre_key_message_get_pre_key_id(state);
        }
        signed_pre_key_id = session_state_unacknowledged_pre_key_message_get_signed_pre_key_id(state);
        base_key = session_state_unacknowledged_pre_key_message_get_base_key(state);

        if(!base_key) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }

        result = pre_key_signal_message_create(&pre_key_message,
                session_version, local_registration_id, (has_pre_key_id ? &pre_key_id : 0),
                signed_pre_key_id, base_key, local_identity_key,
                message,
                cipher->global_context);
        if(result < 0) {
            goto complete;
        }
        /* Drop our reference; the pre-key message is returned instead. */
        SIGNAL_UNREF(message);
        message = 0;
    }

    result = ratchet_chain_key_create_next(chain_key, &next_chain_key);
    if(result < 0) {
        goto complete;
    }

    result = session_state_set_sender_chain_key(state, next_chain_key);
    if(result < 0) {
        goto complete;
    }

    result = signal_protocol_session_store_session(cipher->store, cipher->remote_address, record);

complete:
    if(result >= 0) {
        if(pre_key_message) {
            *encrypted_message = (ciphertext_message *)pre_key_message;
        }
        else {
            *encrypted_message = (ciphertext_message *)message;
        }
    }
    else {
        SIGNAL_UNREF(pre_key_message);
        SIGNAL_UNREF(message);
    }
    signal_buffer_free(ciphertext);
    SIGNAL_UNREF(next_chain_key);
    SIGNAL_UNREF(record);
    /* Wipe key material from the stack. */
    signal_explicit_bzero(&message_keys, sizeof(ratchet_message_keys));
    signal_unlock(cipher->global_context);
    return result;
}

/*
 * Decrypt a pre_key_signal_message.
 *
 * The session builder first processes the pre-key message into the loaded
 * session record. The embedded signal_message is then decrypted and the
 * decryption callback (if any) is run. After that the record is stored and,
 * if the builder reported a one-time pre key id, that pre key is removed
 * from the store.
 *
 * On success *plaintext receives a buffer owned by the caller.
 */
int session_cipher_decrypt_pre_key_signal_message(session_cipher *cipher,
        pre_key_signal_message *ciphertext, void *decrypt_context,
        signal_buffer **plaintext)
{
    int result = 0;
    signal_buffer *result_buf = 0;
    session_record *record = 0;
    int has_unsigned_pre_key_id = 0;
    uint32_t unsigned_pre_key_id = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_session_load_session(cipher->store, &record, cipher->remote_address);
    if(result < 0) {
        goto complete;
    }

    result = session_builder_process_pre_key_signal_message(cipher->builder, record, ciphertext, &unsigned_pre_key_id);
    if(result < 0) {
        goto complete;
    }
    /* A positive result means unsigned_pre_key_id was filled in. */
    has_unsigned_pre_key_id = result;

    result = session_cipher_decrypt_from_record_and_signal_message(cipher, record,
            pre_key_signal_message_get_signal_message(ciphertext),
            &result_buf);
    if(result < 0) {
        goto complete;
    }

    /* The callback runs before the session is persisted; a failure aborts the store. */
    result = session_cipher_decrypt_callback(cipher, result_buf, decrypt_context);
    if(result < 0) {
        goto complete;
    }

    result = signal_protocol_session_store_session(cipher->store, cipher->remote_address, record);
    if(result < 0) {
        goto complete;
    }

    if(has_unsigned_pre_key_id) {
        result = signal_protocol_pre_key_remove_key(cipher->store, unsigned_pre_key_id);
        if(result < 0) {
            goto complete;
        }
    }

complete:
    SIGNAL_UNREF(record);
    if(result >= 0) {
        *plaintext = result_buf;
    }
    else {
        signal_buffer_free(result_buf);
    }
    signal_unlock(cipher->global_context);
    return result;
}

/*
 * Decrypt a signal_message using the existing session for the remote address.
 *
 * Fails with SG_ERR_NO_SESSION if the store has no session for the address.
 * Runs the decryption callback (if any) before storing the updated record.
 * On success *plaintext receives a buffer owned by the caller.
 */
int session_cipher_decrypt_signal_message(session_cipher *cipher,
        signal_message *ciphertext, void *decrypt_context,
        signal_buffer **plaintext)
{
    int result = 0;
    signal_buffer *result_buf = 0;
    session_record *record = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_session_contains_session(cipher->store, cipher->remote_address);
    if(result == 0) {
        signal_log(cipher->global_context, SG_LOG_WARNING, "No session for: %s:%d", cipher->remote_address->name, cipher->remote_address->device_id);
        result = SG_ERR_NO_SESSION;
        goto complete;
    }
    else if(result < 0) {
        goto complete;
    }

    result = signal_protocol_session_load_session(cipher->store, &record,
            cipher->remote_address);
    if(result < 0) {
        goto complete;
    }

    result = session_cipher_decrypt_from_record_and_signal_message(
            cipher, record, ciphertext, &result_buf);
    if(result < 0) {
        goto complete;
    }

    /* The callback runs before the session is persisted; a failure aborts the store. */
    result = session_cipher_decrypt_callback(cipher, result_buf, decrypt_context);
    if(result < 0) {
        goto complete;
    }

    result = signal_protocol_session_store_session(cipher->store, cipher->remote_address, record);

complete:
    SIGNAL_UNREF(record);
    if(result >= 0) {
        *plaintext = result_buf;
    }
    else {
        signal_buffer_free(result_buf);
    }
    signal_unlock(cipher->global_context);
    return result;
}
static int session_cipher_decrypt_from_record_and_signal_message(session_cipher *cipher, session_record *record, signal_message *ciphertext, signal_buffer **plaintext) { int result = 0; signal_buffer *result_buf = 0; session_state *state = 0; session_state *state_copy = 0; session_record_state_node *previous_states_node = 0; assert(cipher); signal_lock(cipher->global_context); state = session_record_get_state(record); if(state) { result = session_state_copy(&state_copy, state, cipher->global_context); if(result < 0) { goto complete; } //TODO Collect and log invalid message errors if totally unsuccessful result = session_cipher_decrypt_from_state_and_signal_message(cipher, state_copy, ciphertext, &result_buf); if(result < 0 && result != SG_ERR_INVALID_MESSAGE) { goto complete; } if(result >= SG_SUCCESS) { session_record_set_state(record, state_copy); goto complete; } SIGNAL_UNREF(state_copy); } previous_states_node = session_record_get_previous_states_head(record); while(previous_states_node) { state = session_record_get_previous_states_element(previous_states_node); result = session_state_copy(&state_copy, state, cipher->global_context); if(result < 0) { goto complete; } result = session_cipher_decrypt_from_state_and_signal_message(cipher, state_copy, ciphertext, &result_buf); if(result < 0 && result != SG_ERR_INVALID_MESSAGE) { goto complete; } if(result >= SG_SUCCESS) { session_record_get_previous_states_remove(record, previous_states_node); result = session_record_promote_state(record, state_copy); goto complete; } SIGNAL_UNREF(state_copy); previous_states_node = session_record_get_previous_states_next(previous_states_node); } signal_log(cipher->global_context, SG_LOG_WARNING, "No valid sessions"); result = SG_ERR_INVALID_MESSAGE; complete: SIGNAL_UNREF(state_copy); if(result >= 0) { *plaintext = result_buf; } else { signal_buffer_free(result_buf); } signal_unlock(cipher->global_context); return result; } static int 
session_cipher_decrypt_from_state_and_signal_message(session_cipher *cipher, session_state *state, signal_message *ciphertext, signal_buffer **plaintext) { int result = 0; signal_buffer *result_buf = 0; ec_public_key *their_ephemeral = 0; uint32_t counter = 0; ratchet_chain_key *chain_key = 0; ratchet_message_keys message_keys; uint8_t message_version = 0; uint32_t session_version = 0; ec_public_key *remote_identity_key = 0; ec_public_key *local_identity_key = 0; signal_buffer *ciphertext_body = 0; if(!session_state_has_sender_chain(state)) { signal_log(cipher->global_context, SG_LOG_WARNING, "Uninitialized session!"); result = SG_ERR_INVALID_MESSAGE; goto complete; } message_version = signal_message_get_message_version(ciphertext); session_version = session_state_get_session_version(state); if(message_version != session_version) { signal_log(cipher->global_context, SG_LOG_WARNING, "Message version %d, but session version %d", message_version, session_version); result = SG_ERR_INVALID_MESSAGE; goto complete; } their_ephemeral = signal_message_get_sender_ratchet_key(ciphertext); if(!their_ephemeral) { result = SG_ERR_UNKNOWN; goto complete; } counter = signal_message_get_counter(ciphertext); result = session_cipher_get_or_create_chain_key(cipher, &chain_key, state, their_ephemeral); if(result < 0) { goto complete; } result = session_cipher_get_or_create_message_keys(&message_keys, state, their_ephemeral, chain_key, counter, cipher->global_context); if(result < 0) { goto complete; } remote_identity_key = session_state_get_remote_identity_key(state); if(!remote_identity_key) { result = SG_ERR_UNKNOWN; goto complete; } local_identity_key = session_state_get_local_identity_key(state); if(!local_identity_key) { result = SG_ERR_UNKNOWN; goto complete; } result = signal_message_verify_mac(ciphertext, message_version, remote_identity_key, local_identity_key, message_keys.mac_key, sizeof(message_keys.mac_key), cipher->global_context); if(result != 1) { if(result == 0) { 
signal_log(cipher->global_context, SG_LOG_WARNING, "Message mac not verified"); result = SG_ERR_INVALID_MESSAGE; } else if(result < 0) { signal_log(cipher->global_context, SG_LOG_WARNING, "Error attempting to verify message mac"); } goto complete; } ciphertext_body = signal_message_get_body(ciphertext); if(!ciphertext_body) { signal_log(cipher->global_context, SG_LOG_WARNING, "Message body does not exist"); result = SG_ERR_INVALID_MESSAGE; goto complete; } result = session_cipher_get_plaintext(cipher, &result_buf, message_version, &message_keys, signal_buffer_data(ciphertext_body), signal_buffer_len(ciphertext_body)); if(result < 0) { goto complete; } session_state_clear_unacknowledged_pre_key_message(state); complete: SIGNAL_UNREF(chain_key); if(result >= 0) { *plaintext = result_buf; } else { signal_buffer_free(result_buf); } signal_explicit_bzero(&message_keys, sizeof(ratchet_message_keys)); return result; } static int session_cipher_get_or_create_chain_key(session_cipher *cipher, ratchet_chain_key **chain_key, session_state *state, ec_public_key *their_ephemeral) { int result = 0; ratchet_chain_key *result_key = 0; ratchet_root_key *receiver_root_key = 0; ratchet_chain_key *receiver_chain_key = 0; ratchet_root_key *sender_root_key = 0; ratchet_chain_key *sender_chain_key = 0; ec_key_pair *our_new_ephemeral = 0; ratchet_root_key *root_key = 0; ec_key_pair *our_ephemeral = 0; ratchet_chain_key *previous_sender_chain_key = 0; uint32_t index = 0; result_key = session_state_get_receiver_chain_key(state, their_ephemeral); if(result_key) { SIGNAL_REF(result_key); goto complete; } root_key = session_state_get_root_key(state); if(!root_key) { result = SG_ERR_UNKNOWN; goto complete; } our_ephemeral = session_state_get_sender_ratchet_key_pair(state); if(!our_ephemeral) { result = SG_ERR_UNKNOWN; goto complete; } result = ratchet_root_key_create_chain(root_key, &receiver_root_key, &receiver_chain_key, their_ephemeral, ec_key_pair_get_private(our_ephemeral)); if(result < 0) 
{ goto complete; } result = curve_generate_key_pair(cipher->global_context, &our_new_ephemeral); if(result < 0) { goto complete; } result = ratchet_root_key_create_chain(receiver_root_key, &sender_root_key, &sender_chain_key, their_ephemeral, ec_key_pair_get_private(our_new_ephemeral)); if(result < 0) { goto complete; } session_state_set_root_key(state, sender_root_key); result = session_state_add_receiver_chain(state, their_ephemeral, receiver_chain_key); if(result < 0) { goto complete; } previous_sender_chain_key = session_state_get_sender_chain_key(state); if(!previous_sender_chain_key) { result = SG_ERR_UNKNOWN; goto complete; } index = ratchet_chain_key_get_index(previous_sender_chain_key); if(index > 0) { --index; } session_state_set_previous_counter(state, index); session_state_set_sender_chain(state, our_new_ephemeral, sender_chain_key); result_key = receiver_chain_key; SIGNAL_REF(result_key); complete: SIGNAL_UNREF(receiver_root_key); SIGNAL_UNREF(receiver_chain_key); SIGNAL_UNREF(sender_root_key); SIGNAL_UNREF(sender_chain_key); SIGNAL_UNREF(our_new_ephemeral); if(result >= 0) { *chain_key = result_key; } else { SIGNAL_UNREF(result_key); } return result; } static int session_cipher_get_or_create_message_keys(ratchet_message_keys *message_keys, session_state *state, ec_public_key *their_ephemeral, ratchet_chain_key *chain_key, uint32_t counter, signal_context *global_context) { int result = 0; ratchet_chain_key *cur_chain_key = 0; ratchet_chain_key *next_chain_key = 0; ratchet_message_keys message_keys_result; if(ratchet_chain_key_get_index(chain_key) > counter) { result = session_state_remove_message_keys(state, &message_keys_result, their_ephemeral, counter); if(result == 1) { result = 0; goto complete; } signal_log(global_context, SG_LOG_WARNING, "Received message with old counter: %d, %d", ratchet_chain_key_get_index(chain_key), counter); result = SG_ERR_DUPLICATE_MESSAGE; goto complete; } if(counter - ratchet_chain_key_get_index(chain_key) > 2000) { 
signal_log(global_context, SG_LOG_WARNING, "Over 2000 messages into the future!"); result = SG_ERR_INVALID_MESSAGE; goto complete; } cur_chain_key = chain_key; SIGNAL_REF(cur_chain_key); while(ratchet_chain_key_get_index(cur_chain_key) < counter) { result = ratchet_chain_key_get_message_keys(cur_chain_key, &message_keys_result); if(result < 0) { goto complete; } result = session_state_set_message_keys(state, their_ephemeral, &message_keys_result); if(result < 0) { goto complete; } result = ratchet_chain_key_create_next(cur_chain_key, &next_chain_key); if(result < 0) { goto complete; } SIGNAL_UNREF(cur_chain_key); cur_chain_key = next_chain_key; next_chain_key = 0; } result = ratchet_chain_key_create_next(cur_chain_key, &next_chain_key); if(result < 0) { goto complete; } result = session_state_set_receiver_chain_key(state, their_ephemeral, next_chain_key); if(result < 0) { goto complete; } result = ratchet_chain_key_get_message_keys(cur_chain_key, &message_keys_result); if(result < 0) { goto complete; } complete: if(result >= 0) { memcpy(message_keys, &message_keys_result, sizeof(ratchet_message_keys)); } SIGNAL_UNREF(cur_chain_key); SIGNAL_UNREF(next_chain_key); signal_explicit_bzero(&message_keys_result, sizeof(ratchet_message_keys)); return result; } int session_cipher_get_remote_registration_id(session_cipher *cipher, uint32_t *remote_id) { int result = 0; uint32_t id_result = 0; session_record *record = 0; session_state *state = 0; assert(cipher); signal_lock(cipher->global_context); result = signal_protocol_session_load_session(cipher->store, &record, cipher->remote_address); if(result < 0) { goto complete; } state = session_record_get_state(record); if(!state) { result = SG_ERR_UNKNOWN; goto complete; } id_result = session_state_get_remote_registration_id(state); complete: if(result >= 0) { *remote_id = id_result; } signal_unlock(cipher->global_context); return result; } int session_cipher_get_session_version(session_cipher *cipher, uint32_t *version) { int 
result = 0; uint32_t version_result = 0; session_record *record = 0; session_state *state = 0; assert(cipher); signal_lock(cipher->global_context); result = signal_protocol_session_contains_session(cipher->store, cipher->remote_address); if(result != 1) { if(result == 0) { signal_log(cipher->global_context, SG_LOG_WARNING, "No session for: %s:%d", cipher->remote_address->name, cipher->remote_address->device_id); result = SG_ERR_NO_SESSION; } goto complete; } result = signal_protocol_session_load_session(cipher->store, &record, cipher->remote_address); if(result < 0) { goto complete; } state = session_record_get_state(record); if(!state) { result = SG_ERR_UNKNOWN; goto complete; } version_result = session_state_get_session_version(state); complete: if(result >= 0) { *version = version_result; } signal_unlock(cipher->global_context); return result; } static int session_cipher_get_ciphertext(session_cipher *cipher, signal_buffer **ciphertext, uint32_t version, ratchet_message_keys *message_keys, const uint8_t *plaintext, size_t plaintext_len) { int result = 0; signal_buffer *output = 0; if(version >= 3) { result = signal_encrypt(cipher->global_context, &output, SG_CIPHER_AES_CBC_PKCS5, message_keys->cipher_key, sizeof(message_keys->cipher_key), message_keys->iv, sizeof(message_keys->iv), plaintext, plaintext_len); } else { uint8_t iv[16]; memset(iv, 0, sizeof(iv)); iv[3] = (uint8_t)(message_keys->counter); iv[2] = (uint8_t)(message_keys->counter >> 8); iv[1] = (uint8_t)(message_keys->counter >> 16); iv[0] = (uint8_t)(message_keys->counter >> 24); result = signal_encrypt(cipher->global_context, &output, SG_CIPHER_AES_CTR_NOPADDING, message_keys->cipher_key, sizeof(message_keys->cipher_key), iv, sizeof(iv), plaintext, plaintext_len); } if(result >= 0) { *ciphertext = output; } return result; } static int session_cipher_get_plaintext(session_cipher *cipher, signal_buffer **plaintext, uint32_t version, ratchet_message_keys *message_keys, const uint8_t *ciphertext, size_t 
ciphertext_len) { int result = 0; signal_buffer *output = 0; if(version >= 3) { result = signal_decrypt(cipher->global_context, &output, SG_CIPHER_AES_CBC_PKCS5, message_keys->cipher_key, sizeof(message_keys->cipher_key), message_keys->iv, sizeof(message_keys->iv), ciphertext, ciphertext_len); } else { uint8_t iv[16]; memset(iv, 0, sizeof(iv)); iv[3] = (uint8_t)(message_keys->counter); iv[2] = (uint8_t)(message_keys->counter >> 8); iv[1] = (uint8_t)(message_keys->counter >> 16); iv[0] = (uint8_t)(message_keys->counter >> 24); result = signal_decrypt(cipher->global_context, &output, SG_CIPHER_AES_CTR_NOPADDING, message_keys->cipher_key, sizeof(message_keys->cipher_key), iv, sizeof(iv), ciphertext, ciphertext_len); } if(result >= 0) { *plaintext = output; } return result; } static int session_cipher_decrypt_callback(session_cipher *cipher, signal_buffer *plaintext, void *decrypt_context) { int result = 0; if(cipher->decrypt_callback) { cipher->inside_callback = 1; result = cipher->decrypt_callback(cipher, plaintext, decrypt_context); cipher->inside_callback = 0; } return result; } void session_cipher_free(session_cipher *cipher) { if(cipher) { session_builder_free(cipher->builder); free(cipher); } }