#include "sender_key_state.h"

#include <stdlib.h>
#include <string.h>
#include <assert.h>

#include "sender_key.h"
#include "utlist.h"
#include "LocalStorageProtocol.pb-c.h"
#include "signal_protocol_internal.h"

/* Upper bound on the number of cached sender message keys kept per
 * sender_key_state. sender_key_state_add_sender_message_key() evicts entries
 * from the head of the list (the oldest appended) once this is exceeded. */
#define MAX_MESSAGE_KEYS 2000

typedef struct sender_message_key_node sender_message_key_node;

/* One element of a sender_key_state's message-key list. The list is managed
 * with utlist's DL_* macros, which require the `prev`/`next` member names.
 * The node holds one reference on `key`. */
struct sender_message_key_node {
    sender_message_key *key;
    sender_message_key_node *prev;
    sender_message_key_node *next;
};

/*
 * Opaque sender key state.
 *
 * `base` must remain the first member: sender_key_state_destroy() casts a
 * signal_type_base pointer back to sender_key_state, and the SIGNAL_INIT /
 * SIGNAL_REF / SIGNAL_UNREF macros operate on that common header.
 */
struct sender_key_state
{
    signal_type_base base;

    /* Sender key ID, serialized as `senderkeyid`. */
    uint32_t key_id;
    /* Current chain key; this struct holds a reference. */
    sender_chain_key *chain_key;
    /* Required public signing key; this struct holds a reference. */
    ec_public_key *signature_public_key;
    /* Optional private signing key (may be NULL); referenced when present. */
    ec_private_key *signature_private_key;
    /* utlist doubly-linked list of cached message keys; new keys are appended
     * at the tail, so the head is the oldest entry. */
    sender_message_key_node *message_keys_head;

    /* Borrowed context pointer; not reference-counted by this object. */
    signal_context *global_context;
};

/*
 * Allocate a new sender_key_state.
 *
 * chain_key and signature_public_key are mandatory (SG_ERR_INVAL otherwise);
 * signature_private_key may be NULL. The new state takes its own reference on
 * every key it stores; global_context is stored without a reference.
 * Returns 0 and writes the new object to *state, or SG_ERR_NOMEM if the
 * allocation fails.
 */
int sender_key_state_create(sender_key_state **state,
        uint32_t id, sender_chain_key *chain_key,
        ec_public_key *signature_public_key, ec_private_key *signature_private_key,
        signal_context *global_context)
{
    sender_key_state *new_state;

    if(chain_key == 0 || signature_public_key == 0) {
        return SG_ERR_INVAL;
    }

    /* calloc gives the same all-zero starting state as malloc + memset. */
    new_state = calloc(1, sizeof(*new_state));
    if(new_state == 0) {
        return SG_ERR_NOMEM;
    }
    SIGNAL_INIT(new_state, sender_key_state_destroy);

    new_state->key_id = id;
    new_state->global_context = global_context;

    SIGNAL_REF(chain_key);
    new_state->chain_key = chain_key;

    SIGNAL_REF(signature_public_key);
    new_state->signature_public_key = signature_public_key;

    if(signature_private_key != 0) {
        SIGNAL_REF(signature_private_key);
        new_state->signature_private_key = signature_private_key;
    }

    *state = new_state;
    return 0;
}

/*
 * Serialize `state` into a newly allocated signal_buffer holding a packed
 * SenderKeyStateStructure protobuf.
 *
 * On success (return >= 0) *buffer receives the buffer, owned by the caller.
 * On failure *buffer is untouched and a negative SG_ERR_* code is returned.
 */
int sender_key_state_serialize(signal_buffer **buffer, sender_key_state *state)
{
    int rc;
    size_t packed_len;
    size_t written;
    Textsecure__SenderKeyStateStructure *proto;
    signal_buffer *out = 0;

    proto = malloc(sizeof(*proto));
    if(!proto) {
        return SG_ERR_NOMEM;
    }
    textsecure__sender_key_state_structure__init(proto);

    rc = sender_key_state_serialize_prepare(state, proto);
    if(rc >= 0) {
        packed_len = textsecure__sender_key_state_structure__get_packed_size(proto);
        out = signal_buffer_alloc(packed_len);
        if(!out) {
            rc = SG_ERR_NOMEM;
        }
        else {
            written = textsecure__sender_key_state_structure__pack(proto, signal_buffer_data(out));
            /* A short pack means the structure was inconsistent. */
            if(written != packed_len) {
                signal_buffer_free(out);
                out = 0;
                rc = SG_ERR_INVALID_PROTO_BUF;
            }
        }
    }

    /* Releases everything serialize_prepare attached, even on failure. */
    sender_key_state_serialize_prepare_free(proto);

    if(rc >= 0) {
        *buffer = out;
    }
    return rc;
}

/*
 * Parse a packed SenderKeyStateStructure from `data`/`len` and build a
 * sender_key_state from it.
 *
 * Returns SG_ERR_INVALID_PROTO_BUF if the bytes do not unpack. On success
 * *state receives a new reference owned by the caller; on failure any
 * partially built state is released and *state is left untouched.
 */
int sender_key_state_deserialize(sender_key_state **state, const uint8_t *data, size_t len, signal_context *global_context)
{
    int rc;
    sender_key_state *decoded = 0;
    Textsecure__SenderKeyStateStructure *proto;

    proto = textsecure__sender_key_state_structure__unpack(0, len, data);
    if(proto == 0) {
        return SG_ERR_INVALID_PROTO_BUF;
    }

    rc = sender_key_state_deserialize_protobuf(&decoded, proto, global_context);
    textsecure__sender_key_state_structure__free_unpacked(proto, 0);

    if(decoded != 0) {
        if(rc >= 0) {
            *state = decoded;
        }
        else {
            SIGNAL_UNREF(decoded);
        }
    }
    return rc;
}

/*
 * Populate a caller-initialized SenderKeyStateStructure from `state`.
 *
 * Sub-messages are heap-allocated and attached to `state_structure` as soon
 * as they are created. The caller must therefore always release the
 * structure with sender_key_state_serialize_prepare_free(), on success and
 * on failure alike.
 *
 * Seed byte fields point directly into the state's signal_buffers (no copy),
 * so the filled structure must not outlive `state`. The signing key byte
 * fields are filled by ec_public_key_serialize_protobuf() /
 * ec_private_key_serialize_protobuf() and are freed by ..._prepare_free().
 *
 * Returns 0 on success or a negative SG_ERR_* code.
 */
int sender_key_state_serialize_prepare(sender_key_state *state, Textsecure__SenderKeyStateStructure *state_structure)
{
    int result = 0;
    Textsecure__SenderKeyStateStructure__SenderChainKey *chain_key_structure = 0;
    Textsecure__SenderKeyStateStructure__SenderSigningKey *signing_key_structure = 0;
    signal_buffer *chain_key_seed = 0;

    assert(state);
    assert(state_structure);

    /* Sender key ID */
    state_structure->has_senderkeyid = 1;
    state_structure->senderkeyid = state->key_id;

    /* Sender chain key (seed bytes are borrowed from the chain key) */
    chain_key_structure = malloc(sizeof(*chain_key_structure));
    if(!chain_key_structure) {
        result = SG_ERR_NOMEM;
        goto complete;
    }
    textsecure__sender_key_state_structure__sender_chain_key__init(chain_key_structure);
    state_structure->senderchainkey = chain_key_structure;

    chain_key_structure->iteration = sender_chain_key_get_iteration(state->chain_key);
    chain_key_structure->has_iteration = 1;

    chain_key_seed = sender_chain_key_get_seed(state->chain_key);
    chain_key_structure->seed.data = signal_buffer_data(chain_key_seed);
    chain_key_structure->seed.len = signal_buffer_len(chain_key_seed);
    chain_key_structure->has_seed = 1;

    /* Sender signing key */
    signing_key_structure = malloc(sizeof(*signing_key_structure));
    if(!signing_key_structure) {
        result = SG_ERR_NOMEM;
        goto complete;
    }
    textsecure__sender_key_state_structure__sender_signing_key__init(signing_key_structure);
    state_structure->sendersigningkey = signing_key_structure;

    if(state->signature_public_key) {
        result = ec_public_key_serialize_protobuf(&(signing_key_structure->public_), state->signature_public_key);
        if(result < 0) {
            goto complete;
        }
        signing_key_structure->has_public_ = 1;
    }

    if(state->signature_private_key) {
        result = ec_private_key_serialize_protobuf(&(signing_key_structure->private_), state->signature_private_key);
        if(result < 0) {
            goto complete;
        }
        signing_key_structure->has_private_ = 1;
    }

    /* Sender message keys */
    if(state->message_keys_head) {
        sender_message_key_node *cur_node = 0;
        size_t count = 0;
        size_t i = 0;

        DL_COUNT(state->message_keys_head, cur_node, count);

        /* Guard the size multiplication below against size_t overflow. */
        if(count > SIZE_MAX / sizeof(*state_structure->sendermessagekeys)) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        state_structure->sendermessagekeys = malloc(sizeof(*state_structure->sendermessagekeys) * count);
        if(!state_structure->sendermessagekeys) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        DL_FOREACH(state->message_keys_head, cur_node) {
            Textsecure__SenderKeyStateStructure__SenderMessageKey *message_key_structure;
            signal_buffer *seed;

            message_key_structure = malloc(sizeof(*message_key_structure));
            if(!message_key_structure) {
                result = SG_ERR_NOMEM;
                break;
            }
            textsecure__sender_key_state_structure__sender_message_key__init(message_key_structure);
            state_structure->sendermessagekeys[i] = message_key_structure;

            message_key_structure->iteration = sender_message_key_get_iteration(cur_node->key);
            message_key_structure->has_iteration = 1;

            /* Seed bytes are borrowed from the message key, not copied. */
            seed = sender_message_key_get_seed(cur_node->key);
            message_key_structure->seed.data = signal_buffer_data(seed);
            message_key_structure->seed.len = signal_buffer_len(seed);
            message_key_structure->has_seed = 1;

            i++;
        }

        /* Only the first i entries were allocated. prepare_free() relies on
         * this count when the loop stopped early on allocation failure. */
        state_structure->n_sendermessagekeys = i;
    }

complete:
    return result;
}

/*
 * Release a structure filled by sender_key_state_serialize_prepare(),
 * including the top-level allocation itself.
 *
 * Seed byte fields are borrowed from the sender key state and are NOT freed.
 * The signing-key byte fields are owned by the structure and are freed here.
 * A NULL argument is a no-op.
 */
void sender_key_state_serialize_prepare_free(Textsecure__SenderKeyStateStructure *state_structure)
{
    /* size_t matches the type of n_sendermessagekeys, so the loop cannot
     * wrap on large counts. */
    size_t i;

    if(!state_structure) {
        return;
    }

    /* Only the wrapper is owned; its seed bytes are borrowed. */
    free(state_structure->senderchainkey);

    if(state_structure->sendersigningkey) {
        /* Key bytes were allocated by ec_*_serialize_protobuf(). free(NULL)
         * is a no-op, so unset keys need no guard. */
        free(state_structure->sendersigningkey->public_.data);
        free(state_structure->sendersigningkey->private_.data);
        free(state_structure->sendersigningkey);
    }

    if(state_structure->sendermessagekeys) {
        for(i = 0; i < state_structure->n_sendermessagekeys; i++) {
            free(state_structure->sendermessagekeys[i]);
        }
        free(state_structure->sendermessagekeys);
    }

    free(state_structure);
}

/*
 * Build a sender_key_state from an unpacked SenderKeyStateStructure.
 *
 * The structure must provide:
 *   - a sender key ID,
 *   - a chain key (iteration + seed),
 *   - a public signing key.
 * Otherwise SG_ERR_INVALID_PROTO_BUF is returned.
 *
 * The private signing key and the message-key list are optional. Message-key
 * entries lacking an iteration or a seed are skipped.
 *
 * On success *state receives a new reference owned by the caller. On failure
 * *state is left untouched and every intermediate object is released.
 */
int sender_key_state_deserialize_protobuf(sender_key_state **state, Textsecure__SenderKeyStateStructure *state_structure, signal_context *global_context)
{
    int result = 0;
    sender_key_state *result_state = 0;
    sender_chain_key *chain_key = 0;
    ec_public_key *signature_public_key = 0;
    ec_private_key *signature_private_key = 0;

    /* Chain key: decoded only when both iteration and seed are present. */
    if(state_structure->senderchainkey
            && state_structure->senderchainkey->has_iteration
            && state_structure->senderchainkey->has_seed) {
        signal_buffer *seed_buffer = signal_buffer_create(
                state_structure->senderchainkey->seed.data,
                state_structure->senderchainkey->seed.len);
        if(!seed_buffer) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        /* The temporary seed buffer is released right after the call. */
        result = sender_chain_key_create(&chain_key,
                state_structure->senderchainkey->iteration,
                seed_buffer,
                global_context);
        signal_buffer_free(seed_buffer);
        if(result < 0) {
            goto complete;
        }
    }

    /* Signing keys */
    if(state_structure->sendersigningkey) {
        if(state_structure->sendersigningkey->has_public_) {
            result = curve_decode_point(&signature_public_key,
                    state_structure->sendersigningkey->public_.data,
                    state_structure->sendersigningkey->public_.len,
                    global_context);
            if(result < 0) {
                goto complete;
            }
        }
        if(state_structure->sendersigningkey->has_private_) {
            result = curve_decode_private_point(&signature_private_key,
                    state_structure->sendersigningkey->private_.data,
                    state_structure->sendersigningkey->private_.len,
                    global_context);
            if(result < 0) {
                goto complete;
            }
        }
    }

    if(state_structure->has_senderkeyid && chain_key && signature_public_key) {
        size_t i;
        result = sender_key_state_create(&result_state,
                state_structure->senderkeyid, chain_key,
                signature_public_key, signature_private_key,
                global_context);
        if(result < 0) {
            goto complete;
        }

        for(i = 0; i < state_structure->n_sendermessagekeys; i++) {
            signal_buffer *seed_buffer;
            sender_message_key *message_key = 0;
            Textsecure__SenderKeyStateStructure__SenderMessageKey *message_key_structure =
                    state_structure->sendermessagekeys[i];

            if(!message_key_structure
                    || !message_key_structure->has_iteration
                    || !message_key_structure->has_seed) {
                continue;
            }

            seed_buffer = signal_buffer_create(
                    message_key_structure->seed.data,
                    message_key_structure->seed.len);
            if(!seed_buffer) {
                result = SG_ERR_NOMEM;
                goto complete;
            }

            result = sender_message_key_create(&message_key,
                    message_key_structure->iteration, seed_buffer,
                    global_context);
            signal_buffer_free(seed_buffer);
            if(result < 0) {
                goto complete;
            }

            /* add_sender_message_key() takes its own reference on success,
             * so drop ours unconditionally. Checking the result first would
             * leak message_key whenever the add fails. */
            result = sender_key_state_add_sender_message_key(result_state, message_key);
            SIGNAL_UNREF(message_key);
            if(result < 0) {
                goto complete;
            }
        }
    }
    else {
        result = SG_ERR_INVALID_PROTO_BUF;
    }

complete:
    /* sender_key_state_create() took its own references on these. */
    if(chain_key) {
        SIGNAL_UNREF(chain_key);
    }
    if(signature_public_key) {
        SIGNAL_UNREF(signature_public_key);
    }
    if(signature_private_key) {
        SIGNAL_UNREF(signature_private_key);
    }
    if(result >= 0) {
        *state = result_state;
    }
    else {
        if(result_state) {
            SIGNAL_UNREF(result_state);
        }
    }
    return result;
}

/*
 * Deep-copy `other_state` into a new sender_key_state.
 *
 * Implemented as a serialize/deserialize round trip. On success *state
 * receives a new reference owned by the caller; the intermediate buffer is
 * always freed.
 */
int sender_key_state_copy(sender_key_state **state, sender_key_state *other_state, signal_context *global_context)
{
    int rc;
    signal_buffer *serialized = 0;

    assert(other_state);
    assert(global_context);

    rc = sender_key_state_serialize(&serialized, other_state);
    if(rc >= 0) {
        rc = sender_key_state_deserialize(state,
                signal_buffer_data(serialized),
                signal_buffer_len(serialized),
                global_context);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }
    return rc;
}

/* Return the sender key ID stored in `state` (must be non-NULL). */
uint32_t sender_key_state_get_key_id(sender_key_state *state)
{
    uint32_t id;

    assert(state);
    id = state->key_id;
    return id;
}

/* Return the current chain key. The reference stays owned by `state`; no
 * extra reference is taken for the caller. */
sender_chain_key *sender_key_state_get_chain_key(sender_key_state *state)
{
    sender_chain_key *current;

    assert(state);
    current = state->chain_key;
    return current;
}

/*
 * Replace the chain key held by `state`. The state takes its own reference
 * on `chain_key` and drops its reference on the previous key.
 *
 * Safe when `chain_key` is the key already held, e.g. a borrowed pointer
 * from sender_key_state_get_chain_key().
 */
void sender_key_state_set_chain_key(sender_key_state *state, sender_chain_key *chain_key)
{
    assert(state);
    assert(chain_key);

    /* Take the new reference before dropping the old one. If chain_key is the
     * key already held and the state owns its last reference, unreferencing
     * first would destroy it, and the following REF would touch freed memory. */
    SIGNAL_REF(chain_key);
    if(state->chain_key) {
        SIGNAL_UNREF(state->chain_key);
    }
    state->chain_key = chain_key;
}

/* Return the public signing key. The reference stays owned by `state`. */
ec_public_key *sender_key_state_get_signing_key_public(sender_key_state *state)
{
    ec_public_key *public_key;

    assert(state);
    public_key = state->signature_public_key;
    return public_key;
}

/* Return the private signing key, or NULL if the state was created without
 * one. The reference stays owned by `state`. */
ec_private_key *sender_key_state_get_signing_key_private(sender_key_state *state)
{
    ec_private_key *private_key;

    assert(state);
    private_key = state->signature_private_key;
    return private_key;
}

/* Return 1 if a cached message key with the given iteration exists, else 0.
 * This is a linear scan of the message-key list. */
int sender_key_state_has_sender_message_key(sender_key_state *state, uint32_t iteration)
{
    sender_message_key_node *node;
    int found = 0;

    assert(state);

    for(node = state->message_keys_head; node != 0 && !found; node = node->next) {
        found = (sender_message_key_get_iteration(node->key) == iteration);
    }
    return found;
}

/*
 * Append `message_key` to the tail of the state's message-key list. The list
 * takes its own reference on the key.
 *
 * If the list then holds more than MAX_MESSAGE_KEYS entries, the oldest
 * (head) entries are removed and unreferenced.
 * Returns 0, or SG_ERR_NOMEM if the list node cannot be allocated; in that
 * case no reference is taken.
 */
int sender_key_state_add_sender_message_key(sender_key_state *state, sender_message_key *message_key)
{
    sender_message_key_node *new_node;
    sender_message_key_node *iter;
    sender_message_key_node *oldest;
    int total = 0;

    assert(state);
    assert(message_key);

    new_node = malloc(sizeof(*new_node));
    if(new_node == 0) {
        return SG_ERR_NOMEM;
    }

    SIGNAL_REF(message_key);
    new_node->key = message_key;
    DL_APPEND(state->message_keys_head, new_node);

    DL_COUNT(state->message_keys_head, iter, total);
    for(; total > MAX_MESSAGE_KEYS; total--) {
        oldest = state->message_keys_head;
        DL_DELETE(state->message_keys_head, oldest);
        if(oldest->key) {
            SIGNAL_UNREF(oldest->key);
        }
        free(oldest);
    }

    return 0;
}

/*
 * Detach and return the first cached message key whose iteration matches,
 * or NULL if none does.
 *
 * The list's reference on the returned key is transferred to the caller,
 * who must eventually unreference it. Only the list node is freed here.
 */
sender_message_key *sender_key_state_remove_sender_message_key(sender_key_state *state, uint32_t iteration)
{
    sender_message_key_node *node;
    sender_message_key *found_key;

    assert(state);

    for(node = state->message_keys_head; node != 0; node = node->next) {
        if(sender_message_key_get_iteration(node->key) != iteration) {
            continue;
        }
        found_key = node->key;
        DL_DELETE(state->message_keys_head, node);
        free(node);
        /* Return immediately: node is freed, so node->next must not be read. */
        return found_key;
    }
    return 0;
}

/*
 * Destructor registered through SIGNAL_INIT in sender_key_state_create().
 * It drops the state's references on its keys and every cached message key,
 * then frees the list nodes and the state itself.
 */
void sender_key_state_destroy(signal_type_base *type)
{
    sender_key_state *state = (sender_key_state *)type;
    sender_message_key_node *node;

    SIGNAL_UNREF(state->chain_key);
    SIGNAL_UNREF(state->signature_public_key);
    SIGNAL_UNREF(state->signature_private_key);

    /* Pop from the head until the list is empty; this leaves
     * message_keys_head NULL. */
    while(state->message_keys_head) {
        node = state->message_keys_head;
        DL_DELETE(state->message_keys_head, node);
        if(node->key) {
            SIGNAL_UNREF(node->key);
        }
        free(node);
    }
    state->message_keys_head = 0;

    free(state);
}