#include "sender_key_record.h"

#include <string.h>

#include "sender_key_state.h"
#include "sender_key.h"
#include "utlist.h"
#include "LocalStorageProtocol.pb-c.h"
#include "signal_protocol_internal.h"

/* Upper bound on the number of sender key states kept in one record. When a
 * new state is added and the list grows beyond this, the oldest states (at
 * the tail of the list) are evicted. */
#define MAX_STATES 5

/**
 * Node of the doubly-linked list (utlist DL_* macros) of sender key states
 * held by a record. The node holds a reference to its state, which is
 * released when the node is removed from the record.
 */
typedef struct sender_key_state_node {
    sender_key_state *state;
    struct sender_key_state_node *prev;
    struct sender_key_state_node *next;
} sender_key_state_node;

/**
 * A sender key record: a list of sender key states (most recently added at
 * the head) plus an optional opaque user buffer that the record owns.
 */
struct sender_key_record {
    /* Must remain the first member: the record is passed around and
     * destroyed through signal_type_base pointers. */
    signal_type_base base;

    /* Head of the state list; NULL when the record holds no states. */
    sender_key_state_node *sender_key_states_head;

    /* Opaque caller-supplied buffer, freed together with the record. */
    signal_buffer *user_record;

    /* Context used for logging and for creating contained objects. */
    signal_context *global_context;
};

/**
 * Allocate an empty sender key record.
 *
 * @param record         receives the new record on success
 * @param global_context context stored in the record for later use
 * @return 0 on success, SG_ERR_NOMEM if allocation fails
 */
int sender_key_record_create(sender_key_record **record,
        signal_context *global_context)
{
    /* calloc gives us an all-zero record: no states, no user buffer. */
    sender_key_record *instance = calloc(1, sizeof(*instance));
    if(instance == 0) {
        return SG_ERR_NOMEM;
    }

    SIGNAL_INIT(instance, sender_key_record_destroy);
    instance->global_context = global_context;

    *record = instance;
    return 0;
}

/**
 * Serialize a record (its sender key states) into a protobuf-encoded buffer.
 * The user_record buffer is not part of the serialized form.
 *
 * @param buffer receives a newly allocated buffer on success; left untouched
 *               on failure
 * @param record the record to serialize
 * @return 0 on success, SG_ERR_NOMEM on allocation failure,
 *         SG_ERR_INVALID_PROTO_BUF if packing produced an unexpected size,
 *         or a negative error from sender_key_state_serialize_prepare
 */
int sender_key_record_serialize(signal_buffer **buffer, sender_key_record *record)
{
    int result = 0;
    size_t result_size = 0;
    size_t i = 0;
    Textsecure__SenderKeyRecordStructure record_structure = TEXTSECURE__SENDER_KEY_RECORD_STRUCTURE__INIT;
    sender_key_state_node *cur_node = 0;
    signal_buffer *result_buf = 0;
    uint8_t *data;
    size_t len;

    if(record->sender_key_states_head) {
        size_t count;
        DL_COUNT(record->sender_key_states_head, cur_node, count);

        /* Guard the array size multiplication against overflow. */
        if(count > SIZE_MAX / sizeof(Textsecure__SenderKeyStateStructure *)) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        record_structure.senderkeystates = malloc(sizeof(Textsecure__SenderKeyStateStructure *) * count);
        if(!record_structure.senderkeystates) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        i = 0;
        DL_FOREACH(record->sender_key_states_head, cur_node) {
            Textsecure__SenderKeyStateStructure *state_structure =
                    malloc(sizeof(Textsecure__SenderKeyStateStructure));
            if(!state_structure) {
                result = SG_ERR_NOMEM;
                break;
            }
            textsecure__sender_key_state_structure__init(state_structure);

            result = sender_key_state_serialize_prepare(cur_node->state, state_structure);
            if(result < 0) {
                /* This element is not yet counted in n_senderkeystates, so the
                 * cleanup loop below would never see it; release it here.
                 * NOTE(review): assumes prepare_free tolerates a partially
                 * populated structure (all fields start zeroed by __init). */
                sender_key_state_serialize_prepare_free(state_structure);
                break;
            }
            record_structure.senderkeystates[i++] = state_structure;
        }
        /* Only fully prepared elements are counted, so cleanup frees exactly those. */
        record_structure.n_senderkeystates = i;
        if(result < 0) {
            goto complete;
        }
    }

    len = textsecure__sender_key_record_structure__get_packed_size(&record_structure);

    result_buf = signal_buffer_alloc(len);
    if(!result_buf) {
        result = SG_ERR_NOMEM;
        goto complete;
    }

    data = signal_buffer_data(result_buf);
    result_size = textsecure__sender_key_record_structure__pack(&record_structure, data);
    if(result_size != len) {
        signal_buffer_free(result_buf);
        result = SG_ERR_INVALID_PROTO_BUF;
        result_buf = 0;
        goto complete;
    }

complete:
    if(record_structure.senderkeystates) {
        for(i = 0; i < record_structure.n_senderkeystates; i++) {
            if(record_structure.senderkeystates[i]) {
                sender_key_state_serialize_prepare_free(record_structure.senderkeystates[i]);
            }
        }
        free(record_structure.senderkeystates);
    }

    if(result >= 0) {
        *buffer = result_buf;
    }
    return result;
}

/**
 * Build a new record from a protobuf-encoded buffer. States are appended in
 * the order they appear in the encoded data.
 *
 * @param record         receives the newly created record on success
 * @param data           encoded bytes
 * @param len            number of bytes in data
 * @param global_context context for the new record and its states
 * @return 0 on success, SG_ERR_INVALID_PROTO_BUF if the data cannot be
 *         unpacked, SG_ERR_NOMEM on allocation failure, or a negative error
 *         from state deserialization
 */
int sender_key_record_deserialize(sender_key_record **record, const uint8_t *data, size_t len, signal_context *global_context)
{
    int result = 0;
    sender_key_record *result_record = 0;
    Textsecure__SenderKeyRecordStructure *record_structure = 0;
    size_t i;

    record_structure = textsecure__sender_key_record_structure__unpack(0, len, data);
    if(!record_structure) {
        result = SG_ERR_INVALID_PROTO_BUF;
        goto complete;
    }

    result = sender_key_record_create(&result_record, global_context);
    if(result < 0) {
        goto complete;
    }

    for(i = 0; i < record_structure->n_senderkeystates; i++) {
        sender_key_state *state_element = 0;
        sender_key_state_node *state_node = 0;

        result = sender_key_state_deserialize_protobuf(&state_element, record_structure->senderkeystates[i], global_context);
        if(result < 0) {
            goto complete;
        }

        state_node = malloc(sizeof(sender_key_state_node));
        if(!state_node) {
            /* The state is not linked into the record yet, so unref'ing the
             * record would not release it; drop it here to avoid a leak. */
            SIGNAL_UNREF(state_element);
            result = SG_ERR_NOMEM;
            goto complete;
        }

        state_node->state = state_element;
        DL_APPEND(result_record->sender_key_states_head, state_node);
    }

complete:
    if(record_structure) {
        textsecure__sender_key_record_structure__free_unpacked(record_structure, 0);
    }
    if(result_record) {
        if(result < 0) {
            SIGNAL_UNREF(result_record);
        }
        else {
            *record = result_record;
        }
    }
    return result;
}

/**
 * Create an independent deep copy of a record.
 *
 * The states are copied by serializing other_record and deserializing the
 * result. The user_record buffer, which is not part of the serialized form,
 * is duplicated separately.
 *
 * @param record         receives the copy on success; untouched on failure
 * @param other_record   record to copy (must not be NULL)
 * @param global_context context for the copy (must not be NULL)
 * @return 0 on success or a negative SG_ERR_* code
 */
int sender_key_record_copy(sender_key_record **record, sender_key_record *other_record, signal_context *global_context)
{
    int rc;
    signal_buffer *serialized = 0;
    sender_key_record *duplicate = 0;

    assert(other_record);
    assert(global_context);

    rc = sender_key_record_serialize(&serialized, other_record);
    if(rc >= 0) {
        uint8_t *bytes = signal_buffer_data(serialized);
        size_t bytes_len = signal_buffer_len(serialized);
        rc = sender_key_record_deserialize(&duplicate, bytes, bytes_len, global_context);
    }

    if(rc >= 0 && other_record->user_record) {
        duplicate->user_record = signal_buffer_copy(other_record->user_record);
        if(!duplicate->user_record) {
            rc = SG_ERR_NOMEM;
        }
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    if(rc < 0) {
        SIGNAL_UNREF(duplicate);
        return rc;
    }

    *record = duplicate;
    return rc;
}

/**
 * Report whether the record holds no sender key states.
 *
 * @param record record to inspect (must not be NULL)
 * @return 1 if the record has no states, 0 otherwise
 */
int sender_key_record_is_empty(sender_key_record *record)
{
    assert(record);
    return record->sender_key_states_head == 0;
}

/**
 * Fetch the state at the head of the record's list (the most recently added).
 *
 * @param record record to query (must not be NULL)
 * @param state  receives the head state on success; untouched otherwise
 * @return 0 on success, SG_ERR_INVALID_KEY_ID if the record is empty
 */
int sender_key_record_get_sender_key_state(sender_key_record *record, sender_key_state **state)
{
    assert(record);

    if(!record->sender_key_states_head) {
        signal_log(record->global_context, SG_LOG_ERROR, "No key state in record!");
        return SG_ERR_INVALID_KEY_ID;
    }

    *state = record->sender_key_states_head->state;
    return 0;
}

/**
 * Find the first state in the record whose key id equals key_id.
 *
 * @param record record to search (must not be NULL)
 * @param state  receives the matching state on success; untouched otherwise
 * @param key_id key id to look for
 * @return 0 on success, SG_ERR_INVALID_KEY_ID if no state matches
 */
int sender_key_record_get_sender_key_state_by_id(sender_key_record *record, sender_key_state **state, uint32_t key_id)
{
    sender_key_state_node *cur_node;
    assert(record);

    DL_FOREACH(record->sender_key_states_head, cur_node) {
        if(sender_key_state_get_key_id(cur_node->state) == key_id) {
            *state = cur_node->state;
            return 0;
        }
    }

    /* key_id is an unsigned 32-bit value: print it through unsigned long so
     * the conversion specifier matches on every platform and ids above
     * INT_MAX are not shown as negative. */
    signal_log(record->global_context, SG_LOG_ERROR, "No keys for: %lu", (unsigned long)key_id);
    return SG_ERR_INVALID_KEY_ID;
}

/**
 * Create a new sender key state and push it onto the head of the record.
 * If the record then holds more than MAX_STATES states, the oldest ones
 * (at the tail) are removed and released.
 *
 * @param record                record to modify (must not be NULL)
 * @param id                    key id of the new state
 * @param iteration             chain key iteration
 * @param chain_key             chain key bytes
 * @param signature_public_key  public signing key
 * @param signature_private_key private signing key, or NULL when only the
 *                              public half is known
 * @return 0 on success or a negative SG_ERR_* code; on failure the record
 *         is left unchanged
 */
static int sender_key_record_add_sender_key_state_impl(sender_key_record *record,
        uint32_t id, uint32_t iteration, signal_buffer *chain_key,
        ec_public_key *signature_public_key, ec_private_key *signature_private_key)
{
    int rc;
    int num_states;
    sender_chain_key *chain = 0;
    sender_key_state *new_state = 0;
    sender_key_state_node *new_node;
    sender_key_state_node *iter;

    assert(record);

    rc = sender_chain_key_create(&chain, iteration, chain_key, record->global_context);
    if(rc < 0) {
        goto cleanup;
    }

    rc = sender_key_state_create(&new_state, id, chain,
            signature_public_key, signature_private_key,
            record->global_context);
    if(rc < 0) {
        goto cleanup;
    }

    new_node = malloc(sizeof(sender_key_state_node));
    if(!new_node) {
        rc = SG_ERR_NOMEM;
        goto cleanup;
    }
    new_node->state = new_state;
    DL_PREPEND(record->sender_key_states_head, new_node);

    /* Trim the list from the tail (oldest entries) down to MAX_STATES. */
    DL_COUNT(record->sender_key_states_head, iter, num_states);
    for(; num_states > MAX_STATES; --num_states) {
        sender_key_state_node *oldest = record->sender_key_states_head->prev;
        DL_DELETE(record->sender_key_states_head, oldest);
        if(oldest->state) {
            SIGNAL_UNREF(oldest->state);
        }
        free(oldest);
    }

cleanup:
    /* Our local reference to the chain key is always dropped here. */
    SIGNAL_UNREF(chain);
    if(rc < 0) {
        SIGNAL_UNREF(new_state);
    }
    return rc;
}

/**
 * Add a state that carries only the public signing key. The new state goes
 * to the head of the record, and older states beyond MAX_STATES are evicted.
 *
 * @return 0 on success or a negative SG_ERR_* code
 */
int sender_key_record_add_sender_key_state(sender_key_record *record,
        uint32_t id, uint32_t iteration, signal_buffer *chain_key, ec_public_key *signature_key)
{
    /* No private key is available for states added this way. */
    return sender_key_record_add_sender_key_state_impl(record, id, iteration,
            chain_key, signature_key, 0);
}

/**
 * Replace every state in the record with a single new state that carries
 * both halves of the signing key pair.
 *
 * @param record             record to modify (must not be NULL)
 * @param id                 key id of the new state
 * @param iteration          chain key iteration
 * @param chain_key          chain key bytes
 * @param signature_key_pair signing key pair for the new state
 * @return 0 on success or a negative SG_ERR_* code
 */
int sender_key_record_set_sender_key_state(sender_key_record *record,
        uint32_t id, uint32_t iteration, signal_buffer *chain_key, ec_key_pair *signature_key_pair)
{
    ec_public_key *public_key;
    ec_private_key *private_key;

    assert(record);

    /* Release every existing state; the loop ends with the head at NULL. */
    while(record->sender_key_states_head) {
        sender_key_state_node *victim = record->sender_key_states_head;
        DL_DELETE(record->sender_key_states_head, victim);
        if(victim->state) {
            SIGNAL_UNREF(victim->state);
        }
        free(victim);
    }

    public_key = ec_key_pair_get_public(signature_key_pair);
    private_key = ec_key_pair_get_private(signature_key_pair);

    return sender_key_record_add_sender_key_state_impl(record, id, iteration,
            chain_key, public_key, private_key);
}

/**
 * Return the opaque user buffer attached to the record, or NULL if none.
 * The record keeps ownership of the returned buffer.
 */
signal_buffer *sender_key_record_get_user_record(const sender_key_record *record)
{
    signal_buffer *attached;

    assert(record);
    attached = record->user_record;
    return attached;
}

/**
 * Attach an opaque user buffer to the record. The record takes ownership of
 * it: the buffer is freed when it is replaced or when the record is destroyed.
 *
 * @param record      record to modify (must not be NULL)
 * @param user_record buffer to attach, or NULL to clear the current one
 */
void sender_key_record_set_user_record(sender_key_record *record, signal_buffer *user_record)
{
    assert(record);

    /* Re-setting the buffer that is already attached would otherwise free it
     * and then store the dangling pointer (use-after-free / double free). */
    if(record->user_record == user_record) {
        return;
    }

    if(record->user_record) {
        signal_buffer_free(record->user_record);
    }
    record->user_record = user_record;
}

/**
 * Destructor registered through SIGNAL_INIT. It releases every contained
 * state, the attached user buffer (if any), and the record itself.
 *
 * @param type the record, viewed through its leading signal_type_base
 */
void sender_key_record_destroy(signal_type_base *type)
{
    sender_key_record *record = (sender_key_record *)type;

    /* Drain the state list from the head; it is NULL once the loop exits. */
    while(record->sender_key_states_head) {
        sender_key_state_node *victim = record->sender_key_states_head;
        DL_DELETE(record->sender_key_states_head, victim);
        if(victim->state) {
            SIGNAL_UNREF(victim->state);
        }
        free(victim);
    }

    if(record->user_record) {
        signal_buffer_free(record->user_record);
    }

    free(record);
}