summaryrefslogtreecommitdiff
path: root/libs/libaxolotl/src/signal_protocol.c
diff options
context:
space:
mode:
Diffstat (limited to 'libs/libaxolotl/src/signal_protocol.c')
-rw-r--r--libs/libaxolotl/src/signal_protocol.c55
1 files changed, 51 insertions, 4 deletions
diff --git a/libs/libaxolotl/src/signal_protocol.c b/libs/libaxolotl/src/signal_protocol.c
index 11b0c51645..d9ea5b5921 100644
--- a/libs/libaxolotl/src/signal_protocol.c
+++ b/libs/libaxolotl/src/signal_protocol.c
@@ -20,6 +20,8 @@ int type_ref_count = 0;
int type_unref_count = 0;
#endif
+#define MIN(a,b) (((a)<(b))?(a):(b))
+
struct signal_protocol_store_context {
signal_context *global_context;
signal_protocol_session_store session_store;
@@ -107,6 +109,12 @@ signal_buffer *signal_buffer_copy(const signal_buffer *buffer)
return signal_buffer_create(buffer->data, buffer->len);
}
+signal_buffer *signal_buffer_n_copy(const signal_buffer *buffer, size_t n)
+{
+ size_t len = MIN(buffer->len, n);
+ return signal_buffer_create(buffer->data, len);
+}
+
signal_buffer *signal_buffer_append(signal_buffer *buffer, const uint8_t *data, size_t len)
{
signal_buffer *tmp_buffer;
@@ -132,7 +140,12 @@ uint8_t *signal_buffer_data(signal_buffer *buffer)
return buffer->data;
}
-size_t signal_buffer_len(signal_buffer *buffer)
+const uint8_t *signal_buffer_const_data(const signal_buffer *buffer)
+{
+ return buffer->data;
+}
+
+size_t signal_buffer_len(const signal_buffer *buffer)
{
return buffer->len;
}
@@ -229,7 +242,7 @@ signal_buffer_list *signal_buffer_list_copy(const signal_buffer_list *list)
for(i = 0; i < list_size; i++) {
signal_buffer **buffer = (signal_buffer**)utarray_eltptr(list->values, i);
buffer_copy = signal_buffer_copy(*buffer);
- utarray_push_back(list->values, &buffer_copy);
+ utarray_push_back(result_list->values, &buffer_copy);
buffer_copy = 0;
}
@@ -700,13 +713,14 @@ int signal_protocol_session_load_session(signal_protocol_store_context *context,
{
int result = 0;
signal_buffer *buffer = 0;
+ signal_buffer *user_buffer = 0;
session_record *result_record = 0;
assert(context);
assert(context->session_store.load_session_func);
result = context->session_store.load_session_func(
- &buffer, address,
+ &buffer, &user_buffer, address,
context->session_store.user_data);
if(result < 0) {
goto complete;
@@ -736,8 +750,14 @@ complete:
signal_buffer_free(buffer);
}
if(result >= 0) {
+ if(user_buffer) {
+ session_record_set_user_record(result_record, user_buffer);
+ }
*record = result_record;
}
+ else {
+ signal_buffer_free(user_buffer);
+ }
return result;
}
@@ -755,6 +775,9 @@ int signal_protocol_session_store_session(signal_protocol_store_context *context
{
int result = 0;
signal_buffer *buffer = 0;
+ signal_buffer *user_buffer = 0;
+ uint8_t *user_buffer_data = 0;
+ size_t user_buffer_len = 0;
assert(context);
assert(context->session_store.store_session_func);
@@ -765,9 +788,16 @@ int signal_protocol_session_store_session(signal_protocol_store_context *context
goto complete;
}
+ user_buffer = session_record_get_user_record(record);
+ if(user_buffer) {
+ user_buffer_data = signal_buffer_data(user_buffer);
+ user_buffer_len = signal_buffer_len(user_buffer);
+ }
+
result = context->session_store.store_session_func(
address,
signal_buffer_data(buffer), signal_buffer_len(buffer),
+ user_buffer_data, user_buffer_len,
context->session_store.user_data);
complete:
@@ -1114,6 +1144,9 @@ int signal_protocol_sender_key_store_key(signal_protocol_store_context *context,
{
int result = 0;
signal_buffer *buffer = 0;
+ signal_buffer *user_buffer = 0;
+ uint8_t *user_buffer_data = 0;
+ size_t user_buffer_len = 0;
assert(context);
assert(context->sender_key_store.store_sender_key);
@@ -1124,9 +1157,16 @@ int signal_protocol_sender_key_store_key(signal_protocol_store_context *context,
goto complete;
}
+ user_buffer = sender_key_record_get_user_record(record);
+ if(user_buffer) {
+ user_buffer_data = signal_buffer_data(user_buffer);
+ user_buffer_len = signal_buffer_len(user_buffer);
+ }
+
result = context->sender_key_store.store_sender_key(
sender_key_name,
signal_buffer_data(buffer), signal_buffer_len(buffer),
+ user_buffer_data, user_buffer_len,
context->sender_key_store.user_data);
complete:
@@ -1141,13 +1181,14 @@ int signal_protocol_sender_key_load_key(signal_protocol_store_context *context,
{
int result = 0;
signal_buffer *buffer = 0;
+ signal_buffer *user_buffer = 0;
sender_key_record *result_record = 0;
assert(context);
assert(context->sender_key_store.load_sender_key);
result = context->sender_key_store.load_sender_key(
- &buffer, sender_key_name,
+ &buffer, &user_buffer, sender_key_name,
context->sender_key_store.user_data);
if(result < 0) {
goto complete;
@@ -1177,7 +1218,13 @@ complete:
signal_buffer_free(buffer);
}
if(result >= 0) {
+ if(user_buffer) {
+ sender_key_record_set_user_record(result_record, user_buffer);
+ }
*record = result_record;
}
+ else {
+ signal_buffer_free(user_buffer);
+ }
return result;
}