#include "session_record.h"

#include <stdlib.h>
#include <string.h>
#include <assert.h>

#include "session_state.h"
#include "utlist.h"
#include "LocalStorageProtocol.pb-c.h"
#include "signal_protocol_internal.h"

/* Upper bound on the number of archived (previous) states kept per record. */
#define ARCHIVED_STATES_MAX_LENGTH 40

/* Node of the doubly-linked list (utlist DL_* macros) of previous states. */
struct session_record_state_node
{
    session_state *state;
    struct session_record_state_node *prev, *next;
};

/*
 * A session record: the current session state, a list of previous states,
 * an optional user-supplied buffer, and the context used to create states.
 */
struct session_record
{
    signal_type_base base;
    session_state *state;
    session_record_state_node *previous_states_head;
    int is_fresh;
    signal_buffer *user_record;
    signal_context *global_context;
};

static void session_record_free_previous_states(session_record *record);

/*
 * Allocate a new session record.
 *
 * When `state` is NULL a new state is created with session_state_create()
 * and the record is marked fresh (is_fresh = 1). Otherwise the record takes
 * its own reference on `state` (SIGNAL_REF) and is marked not fresh.
 *
 * Returns 0 on success and stores the record in *record; returns
 * SG_ERR_NOMEM if the record cannot be allocated, or the negative result of
 * session_state_create(). *record is not written on failure.
 */
int session_record_create(session_record **record, session_state *state, signal_context *global_context)
{
    session_record *result = malloc(sizeof(session_record));
    if(!result) {
        return SG_ERR_NOMEM;
    }
    memset(result, 0, sizeof(session_record));
    SIGNAL_INIT(result, session_record_destroy);

    if(!state) {
        int ret = session_state_create(&result->state, global_context);
        if(ret < 0) {
            SIGNAL_UNREF(result);
            return ret;
        }
        result->is_fresh = 1;
    }
    else {
        SIGNAL_REF(state);
        result->state = state;
        result->is_fresh = 0;
    }
    result->global_context = global_context;

    *record = result;
    return 0;
}

/*
 * Serialize the record (current state plus all previous states) into a
 * newly allocated signal_buffer.
 *
 * Returns a non-negative value on success and stores the buffer in *buffer.
 * Returns SG_ERR_INVAL when `record` is NULL, SG_ERR_NOMEM on allocation
 * failure, SG_ERR_INVALID_PROTO_BUF when the packed size does not match the
 * computed size, or the negative result of session_state_serialize_prepare().
 * *buffer is not written on failure.
 */
int session_record_serialize(signal_buffer **buffer, const session_record *record)
{
    int result = 0;
    size_t result_size = 0;
    unsigned int i = 0;
    Textsecure__RecordStructure record_structure = TEXTSECURE__RECORD_STRUCTURE__INIT;
    session_record_state_node *cur_node = 0;
    signal_buffer *result_buf = 0;
    size_t len = 0;
    uint8_t *data = 0;

    if(!record) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    if(record->state) {
        record_structure.currentsession = malloc(sizeof(Textsecure__SessionStructure));
        if(!record_structure.currentsession) {
            result = SG_ERR_NOMEM;
            goto complete;
        }
        textsecure__session_structure__init(record_structure.currentsession);
        result = session_state_serialize_prepare(record->state, record_structure.currentsession);
        if(result < 0) {
            goto complete;
        }
    }

    if(record->previous_states_head) {
        size_t count;
        DL_COUNT(record->previous_states_head, cur_node, count);

        // Guard the pointer-array size computation against overflow
        if(count > SIZE_MAX / sizeof(Textsecure__SessionStructure *)) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        record_structure.previoussessions = malloc(sizeof(Textsecure__SessionStructure *) * count);
        if(!record_structure.previoussessions) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        i = 0;
        DL_FOREACH(record->previous_states_head, cur_node) {
            Textsecure__SessionStructure *session_structure = malloc(sizeof(Textsecure__SessionStructure));
            record_structure.previoussessions[i] = session_structure;
            if(!session_structure) {
                result = SG_ERR_NOMEM;
                break;
            }
            textsecure__session_structure__init(session_structure);
            result = session_state_serialize_prepare(cur_node->state, session_structure);

            // Count this element even when preparing it failed, so that the
            // cleanup loop below releases it instead of leaking it. This is
            // the same treatment the current session gets on failure.
            i++;
            if(result < 0) {
                break;
            }
        }
        record_structure.n_previoussessions = i;
        if(result < 0) {
            goto complete;
        }
    }

    len = textsecure__record_structure__get_packed_size(&record_structure);

    result_buf = signal_buffer_alloc(len);
    if(!result_buf) {
        result = SG_ERR_NOMEM;
        goto complete;
    }

    data = signal_buffer_data(result_buf);
    result_size = textsecure__record_structure__pack(&record_structure, data);
    if(result_size != len) {
        signal_buffer_free(result_buf);
        result = SG_ERR_INVALID_PROTO_BUF;
        result_buf = 0;
        goto complete;
    }

complete:
    if(record_structure.currentsession) {
        session_state_serialize_prepare_free(record_structure.currentsession);
    }
    if(record_structure.previoussessions) {
        for(i = 0; i < record_structure.n_previoussessions; i++) {
            if(record_structure.previoussessions[i]) {
                session_state_serialize_prepare_free(record_structure.previoussessions[i]);
            }
        }
        free(record_structure.previoussessions);
    }
    if(result >= 0) {
        *buffer = result_buf;
    }
    return result;
}

/*
 * Rebuild a session record from `len` bytes of serialized data.
 *
 * The resulting record is always marked not fresh. Returns a non-negative
 * value on success and stores the record in *record. Returns
 * SG_ERR_INVALID_PROTO_BUF when the data cannot be unpacked, SG_ERR_NOMEM on
 * allocation failure, or the negative result of
 * session_state_deserialize_protobuf() / session_record_create().
 * *record is not written on failure.
 */
int session_record_deserialize(session_record **record, const uint8_t *data, size_t len, signal_context *global_context)
{
    int result = 0;
    session_record *result_record = 0;
    session_state *current_state = 0;
    session_record_state_node *previous_states_head = 0;
    Textsecure__RecordStructure *record_structure = 0;

    record_structure = textsecure__record_structure__unpack(0, len, data);
    if(!record_structure) {
        result = SG_ERR_INVALID_PROTO_BUF;
        goto complete;
    }

    if(record_structure->currentsession) {
        result = session_state_deserialize_protobuf(&current_state, record_structure->currentsession, global_context);
        if(result < 0) {
            goto complete;
        }
    }

    result = session_record_create(&result_record, current_state, global_context);
    if(result < 0) {
        goto complete;
    }
    // The record now holds its own reference to the state (or created one)
    SIGNAL_UNREF(current_state);
    current_state = 0;

    result_record->is_fresh = 0;

    if(record_structure->n_previoussessions > 0) {
        unsigned int i;
        for(i = 0; i < record_structure->n_previoussessions; i++) {
            Textsecure__SessionStructure *session_structure = record_structure->previoussessions[i];

            session_record_state_node *node = malloc(sizeof(session_record_state_node));
            if(!node) {
                result = SG_ERR_NOMEM;
                goto complete;
            }

            result = session_state_deserialize_protobuf(&node->state, session_structure, global_context);
            if(result < 0) {
                free(node);
                goto complete;
            }

            DL_APPEND(previous_states_head, node);
        }
    }
    // Hand the completed list over to the record
    result_record->previous_states_head = previous_states_head;
    previous_states_head = 0;

complete:
    if(record_structure) {
        textsecure__record_structure__free_unpacked(record_structure, 0);
    }
    if(current_state) {
        SIGNAL_UNREF(current_state);
    }
    if(previous_states_head) {
        // Only reached on failure: release every state deserialized so far
        // together with its list node, otherwise the states would leak.
        session_record_state_node *cur_node;
        session_record_state_node *tmp_node;
        DL_FOREACH_SAFE(previous_states_head, cur_node, tmp_node) {
            DL_DELETE(previous_states_head, cur_node);
            if(cur_node->state) {
                SIGNAL_UNREF(cur_node->state);
            }
            free(cur_node);
        }
    }
    if(result_record) {
        if(result < 0) {
            SIGNAL_UNREF(result_record);
        }
        else {
            *record = result_record;
        }
    }
    return result;
}

/*
 * Create a copy of `other_record` by serializing it and deserializing the
 * result, then duplicating the user record buffer if one is set.
 *
 * Returns a non-negative value on success and stores the copy in *record;
 * otherwise returns the negative error from serialization/deserialization,
 * or SG_ERR_NOMEM when the user record cannot be copied.
 */
int session_record_copy(session_record **record, session_record *other_record, signal_context *global_context)
{
    int result = 0;
    session_record *result_record = 0;
    signal_buffer *buffer = 0;
    size_t len = 0;
    uint8_t *data = 0;

    assert(other_record);
    assert(global_context);

    result = session_record_serialize(&buffer, other_record);
    if(result < 0) {
        goto complete;
    }

    data = signal_buffer_data(buffer);
    len = signal_buffer_len(buffer);

    result = session_record_deserialize(&result_record, data, len, global_context);
    if(result < 0) {
        goto complete;
    }

    if(other_record->user_record) {
        result_record->user_record = signal_buffer_copy(other_record->user_record);
        if(!result_record->user_record) {
            result = SG_ERR_NOMEM;
            goto complete;
        }
    }

complete:
    if(buffer) {
        signal_buffer_free(buffer);
    }
    if(result >= 0) {
        *record = result_record;
    }
    else {
        SIGNAL_UNREF(result_record);
    }
    return result;
}

/*
 * Return 1 when the current state or any previous state has the given
 * session version and an Alice base key that compares equal to
 * `alice_base_key` (ec_public_key_compare() == 0); otherwise return 0.
 */
int session_record_has_session_state(session_record *record, uint32_t version, const ec_public_key *alice_base_key)
{
    session_record_state_node *cur_node = 0;
    assert(record);
    assert(record->state);

    if(session_state_get_session_version(record->state) == version &&
            ec_public_key_compare(
                    session_state_get_alice_base_key(record->state),
                    alice_base_key) == 0) {
        return 1;
    }

    DL_FOREACH(record->previous_states_head, cur_node) {
        if(session_state_get_session_version(cur_node->state) == version &&
                ec_public_key_compare(
                        session_state_get_alice_base_key(cur_node->state),
                        alice_base_key) == 0) {
            return 1;
        }
    }

    return 0;
}

/* Return the record's current state (no new reference is taken). */
session_state *session_record_get_state(session_record *record)
{
    assert(record);
    return record->state;
}

/*
 * Replace the record's current state with `state`, taking a reference on the
 * new state and releasing the reference on the old one.
 */
void session_record_set_state(session_record *record, session_state *state)
{
    assert(record);
    assert(state);

    // Reference the incoming state before releasing the old one, so that
    // passing the state the record already holds cannot drop its last
    // reference and leave the record pointing at a destroyed object.
    SIGNAL_REF(state);
    if(record->state) {
        SIGNAL_UNREF(record->state);
    }
    record->state = state;
}

/* Return the first node of the previous-states list, or NULL if empty. */
session_record_state_node *session_record_get_previous_states_head(const session_record *record)
{
    assert(record);
    return record->previous_states_head;
}

/* Return the state stored in a previous-states list node. */
session_state *session_record_get_previous_states_element(const session_record_state_node *node)
{
    assert(node);
    return node->state;
}

/* Return the node following `node` in the previous-states list. */
session_record_state_node *session_record_get_previous_states_next(const session_record_state_node *node)
{
    assert(node);
    return node->next;
}

/*
 * Unlink `node` from the record's previous-states list, release its state,
 * free the node, and return the node that followed it.
 */
session_record_state_node *session_record_get_previous_states_remove(session_record *record, session_record_state_node *node)
{
    session_record_state_node *next_node = 0;
    assert(record);
    assert(node);

    // Remember the successor before the node is unlinked and freed
    next_node = node->next;
    DL_DELETE(record->previous_states_head, node);
    SIGNAL_UNREF(node->state);
    free(node);
    return next_node;
}

/* Return the record's is_fresh flag. */
int session_record_is_fresh(session_record *record)
{
    assert(record);
    return record->is_fresh;
}

/*
 * Archive the current state by promoting a newly created, empty state.
 * Returns the result of session_state_create() or
 * session_record_promote_state().
 */
int session_record_archive_current_state(session_record *record)
{
    int result = 0;
    session_state *new_state = 0;

    assert(record);

    result = session_state_create(&new_state, record->global_context);
    if(result < 0) {
        goto complete;
    }

    result = session_record_promote_state(record, new_state);

complete:
    // The record took its own reference during promotion; drop the local one
    SIGNAL_UNREF(new_state);
    return result;
}

/*
 * Make `promoted_state` the current state, moving the previously current
 * state (if any) to the front of the previous-states list and trimming that
 * list to ARCHIVED_STATES_MAX_LENGTH entries.
 *
 * Returns 0 on success, or SG_ERR_NOMEM if the list node cannot be
 * allocated (in which case the record is left unchanged).
 */
int session_record_promote_state(session_record *record, session_state *promoted_state)
{
    int count = 0;
    session_record_state_node *cur_node = 0;
    session_record_state_node *tmp_node = 0;

    assert(record);
    assert(promoted_state);

    // Move the previously current state to the list of previous states
    if(record->state) {
        session_record_state_node *node = malloc(sizeof(session_record_state_node));
        if(!node) {
            return SG_ERR_NOMEM;
        }
        // The list node inherits the record's reference to the old state
        node->state = record->state;
        DL_PREPEND(record->previous_states_head, node);
        record->state = 0;
    }

    // Make the promoted state the current state
    SIGNAL_REF(promoted_state);
    record->state = promoted_state;

    // Remove any previous nodes beyond the maximum length
    DL_FOREACH_SAFE(record->previous_states_head, cur_node, tmp_node) {
        count++;
        if(count > ARCHIVED_STATES_MAX_LENGTH) {
            DL_DELETE(record->previous_states_head, cur_node);
            if(cur_node->state) {
                SIGNAL_UNREF(cur_node->state);
            }
            free(cur_node);
        }
    }

    return 0;
}

/* Release every previous state and its list node, leaving the list empty. */
static void session_record_free_previous_states(session_record *record)
{
    session_record_state_node *cur_node;
    session_record_state_node *tmp_node;
    DL_FOREACH_SAFE(record->previous_states_head, cur_node, tmp_node) {
        DL_DELETE(record->previous_states_head, cur_node);
        if(cur_node->state) {
            SIGNAL_UNREF(cur_node->state);
        }
        free(cur_node);
    }
    record->previous_states_head = 0;
}
/* Return the user record buffer attached to the record, or NULL if none. */
signal_buffer *session_record_get_user_record(const session_record *record)
{
    assert(record);
    return record->user_record;
}

/*
 * Attach `user_record` to the record, freeing any buffer previously attached.
 * The record stores the pointer itself (no copy is made) and frees it in
 * session_record_destroy(). `user_record` may be NULL to clear the buffer.
 */
void session_record_set_user_record(session_record *record, signal_buffer *user_record)
{
    assert(record);

    // Re-attaching the buffer the record already holds must be a no-op;
    // freeing it first would leave the record with a dangling pointer.
    if(record->user_record == user_record) {
        return;
    }

    if(record->user_record) {
        signal_buffer_free(record->user_record);
    }
    record->user_record = user_record;
}

/*
 * Destroy callback for a session record: releases the current state, all
 * previous states, and the user record buffer, then frees the record itself.
 */
void session_record_destroy(signal_type_base *type)
{
    session_record *record = (session_record *)type;

    if(record->state) {
        SIGNAL_UNREF(record->state);
    }
    session_record_free_previous_states(record);

    if(record->user_record) {
        signal_buffer_free(record->user_record);
    }

    free(record);
}