diff options
Diffstat (limited to 'libs/libaxolotl/src/signal_protocol.c')
-rw-r--r-- | libs/libaxolotl/src/signal_protocol.c | 55 |
1 files changed, 51 insertions, 4 deletions
diff --git a/libs/libaxolotl/src/signal_protocol.c b/libs/libaxolotl/src/signal_protocol.c index 11b0c51645..d9ea5b5921 100644 --- a/libs/libaxolotl/src/signal_protocol.c +++ b/libs/libaxolotl/src/signal_protocol.c @@ -20,6 +20,8 @@ int type_ref_count = 0; int type_unref_count = 0; #endif +#define MIN(a,b) (((a)<(b))?(a):(b)) + struct signal_protocol_store_context { signal_context *global_context; signal_protocol_session_store session_store; @@ -107,6 +109,12 @@ signal_buffer *signal_buffer_copy(const signal_buffer *buffer) return signal_buffer_create(buffer->data, buffer->len); } +signal_buffer *signal_buffer_n_copy(const signal_buffer *buffer, size_t n) +{ + size_t len = MIN(buffer->len, n); + return signal_buffer_create(buffer->data, len); +} + signal_buffer *signal_buffer_append(signal_buffer *buffer, const uint8_t *data, size_t len) { signal_buffer *tmp_buffer; @@ -132,7 +140,12 @@ uint8_t *signal_buffer_data(signal_buffer *buffer) return buffer->data; } -size_t signal_buffer_len(signal_buffer *buffer) +const uint8_t *signal_buffer_const_data(const signal_buffer *buffer) +{ + return buffer->data; +} + +size_t signal_buffer_len(const signal_buffer *buffer) { return buffer->len; } @@ -229,7 +242,7 @@ signal_buffer_list *signal_buffer_list_copy(const signal_buffer_list *list) for(i = 0; i < list_size; i++) { signal_buffer **buffer = (signal_buffer**)utarray_eltptr(list->values, i); buffer_copy = signal_buffer_copy(*buffer); - utarray_push_back(list->values, &buffer_copy); + utarray_push_back(result_list->values, &buffer_copy); buffer_copy = 0; } @@ -700,13 +713,14 @@ int signal_protocol_session_load_session(signal_protocol_store_context *context, { int result = 0; signal_buffer *buffer = 0; + signal_buffer *user_buffer = 0; session_record *result_record = 0; assert(context); assert(context->session_store.load_session_func); result = context->session_store.load_session_func( - &buffer, address, + &buffer, &user_buffer, address, context->session_store.user_data); if(result < 0) { goto complete; @@ -736,8 +750,14 @@ complete: signal_buffer_free(buffer); } if(result >= 0) { + if(user_buffer) { + session_record_set_user_record(result_record, user_buffer); + } *record = result_record; } + else { + signal_buffer_free(user_buffer); + } return result; } @@ -755,6 +775,9 @@ int signal_protocol_session_store_session(signal_protocol_store_context *context { int result = 0; signal_buffer *buffer = 0; + signal_buffer *user_buffer = 0; + uint8_t *user_buffer_data = 0; + size_t user_buffer_len = 0; assert(context); assert(context->session_store.store_session_func); @@ -765,9 +788,16 @@ int signal_protocol_session_store_session(signal_protocol_store_context *context goto complete; } + user_buffer = session_record_get_user_record(record); + if(user_buffer) { + user_buffer_data = signal_buffer_data(user_buffer); + user_buffer_len = signal_buffer_len(user_buffer); + } + result = context->session_store.store_session_func( address, signal_buffer_data(buffer), signal_buffer_len(buffer), + user_buffer_data, user_buffer_len, context->session_store.user_data); complete: @@ -1114,6 +1144,9 @@ int signal_protocol_sender_key_store_key(signal_protocol_store_context *context, { int result = 0; signal_buffer *buffer = 0; + signal_buffer *user_buffer = 0; + uint8_t *user_buffer_data = 0; + size_t user_buffer_len = 0; assert(context); assert(context->sender_key_store.store_sender_key); @@ -1124,9 +1157,16 @@ int signal_protocol_sender_key_store_key(signal_protocol_store_context *context, goto complete; } + user_buffer = sender_key_record_get_user_record(record); + if(user_buffer) { + user_buffer_data = signal_buffer_data(user_buffer); + user_buffer_len = signal_buffer_len(user_buffer); + } + result = context->sender_key_store.store_sender_key( sender_key_name, signal_buffer_data(buffer), signal_buffer_len(buffer), + user_buffer_data, user_buffer_len, context->sender_key_store.user_data); complete: @@ -1141,13 +1181,14 @@ int signal_protocol_sender_key_load_key(signal_protocol_store_context *context, { int result = 0; signal_buffer *buffer = 0; + signal_buffer *user_buffer = 0; sender_key_record *result_record = 0; assert(context); assert(context->sender_key_store.load_sender_key); result = context->sender_key_store.load_sender_key( - &buffer, sender_key_name, + &buffer, &user_buffer, sender_key_name, context->sender_key_store.user_data); if(result < 0) { goto complete; @@ -1177,7 +1218,13 @@ complete: signal_buffer_free(buffer); } if(result >= 0) { + if(user_buffer) { + sender_key_record_set_user_record(result_record, user_buffer); + } *record = result_record; } + else { + signal_buffer_free(user_buffer); + } return result; } |