/* lz77.c 5.99 KB */
#include "hashmap.h"
#include "bitwriter.h"
#include "lz77.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

/* Farthest back-reference distance the encoder allows: indexed positions that
 * fall more than this many bytes behind the cursor are evicted from the match
 * index, and every candidate distance is asserted to stay within it. */
const size_t DISTANCE_MAX = 32768;

/* One link of an IntQueue chain: the stored value and the link that follows it
 * (NULL on the last link). */
typedef struct IntQueueNode
{
    size_t value;
    struct IntQueueNode *next;
} IntQueueNode;

/* FIFO of integer values kept as a singly linked chain of IntQueueNode. */
struct IntQueue
{
    uint64_t size;      /* number of nodes currently linked */
    IntQueueNode *head; /* oldest node (next to be popped); NULL when empty */
    IntQueueNode *tail; /* most recently pushed node; NULL when empty */
};
typedef struct IntQueue IntQueue;

/*
 * Allocate a new, empty queue.
 *
 * Returns a heap-allocated IntQueue with size 0 and no nodes; the caller
 * releases it with IntQueue_free(). Never returns NULL: allocation failure
 * terminates the process. The check is explicit (not an assert) so that it
 * still fires when assertions are compiled out with NDEBUG.
 */
IntQueue *IntQueue_init(void)
{
    IntQueue *queue = malloc(sizeof *queue);
    if (queue == NULL)
    {
        fputs("IntQueue_init: out of memory\n", stderr);
        abort();
    }
    queue->size = 0;
    queue->head = NULL;
    queue->tail = NULL;
    return queue;
}

/* Release every node still linked into the queue, then the queue object
 * itself. The queue pointer must not be used afterwards. */
void IntQueue_free(IntQueue *queue)
{
    IntQueueNode *cursor = queue->head;
    while (cursor != NULL)
    {
        /* Grab the successor before the current node is released. */
        IntQueueNode *doomed = cursor;
        cursor = cursor->next;
        free(doomed);
    }
    free(queue);
}

/* Hand back the first (oldest) node so a caller can walk the chain through
 * ->next; the result is NULL when the queue holds no nodes. */
IntQueueNode *IntQueue_get_head(IntQueue *queue)
{
    IntQueueNode *first = queue->head;
    return first;
}

/*
 * Append `value` to the back of the queue.
 *
 * A new node is heap-allocated for the value and linked after the current
 * tail (or becomes both head and tail of an empty queue). Allocation failure
 * terminates the process; the check is explicit rather than an assert so it
 * is still performed when NDEBUG is defined.
 */
void IntQueue_push(IntQueue *queue, uint64_t value)
{
    IntQueueNode *node = malloc(sizeof *node);
    if (node == NULL)
    {
        fputs("IntQueue_push: out of memory\n", stderr);
        abort();
    }
    node->value = value; /* stored in the node's size_t field */
    node->next = NULL;
    queue->size++;

    if (queue->head == NULL && queue->tail == NULL)
    {
        /* First element: it is both ends of the chain. */
        queue->head = node;
        queue->tail = node;
    }
    else
    {
        queue->tail->next = node;
        queue->tail = node;
    }
}

/*
 * Return the value at the front of the queue without removing it.
 *
 * The queue must not be empty; as in IntQueue_pop(), that precondition is
 * checked with an assert instead of silently dereferencing a NULL head.
 */
uint64_t IntQueue_peek(IntQueue *queue)
{
    assert(queue->head);
    return queue->head->value;
}

/* Remove and free the front node. The queue must be non-empty (asserted).
 * When the last node is removed the tail pointer is cleared as well. */
void IntQueue_pop(IntQueue *queue)
{
    assert(queue->head);

    IntQueueNode *front = queue->head;
    IntQueueNode *successor = front->next;

    queue->head = successor;
    if (successor == NULL)
    {
        queue->tail = NULL;
    }
    queue->size -= 1;

    free(front);
}

/* Report whether the queue's element count is zero. */
bool IntQueue_empty(IntQueue *queue)
{
    if (queue->size > 0)
    {
        return false;
    }
    return true;
}

/* Shortest and longest match lengths the encoder emits as back-references. */
static const size_t MATCH_LEN_MIN = 3;
static const size_t MATCH_LEN_MAX = 258;

/* Pack the three bytes at `p` into the 24-bit key used by the match index
 * (first byte in the low 8 bits). */
static uint32_t pack_triple(const uint8_t *p)
{
    return (uint32_t) p[0] | ((uint32_t) p[1] << 8) | ((uint32_t) p[2] << 16);
}

/* Hashmap value destructor with the exact signature Hashmap_free is given,
 * so IntQueue_free is never called through an incompatible pointer type. */
static void free_queue_cb(void *queue)
{
    IntQueue_free(queue);
}

/*
 * Read the whole of `input_file` into a freshly allocated buffer.
 *
 * The stream is measured by seeking to its end and then rewound, so reading
 * always starts at offset 0. On success the buffer is returned (never NULL,
 * even for an empty file) and *length_out receives the number of bytes
 * actually read; the caller frees the buffer. Returns NULL if seeking,
 * allocation or reading fails.
 */
static uint8_t *slurp_input(FILE *input_file, size_t *length_out)
{
    if (fseek(input_file, 0, SEEK_END) != 0)
    {
        return NULL;
    }
    long end = ftell(input_file);
    if (end < 0 || fseek(input_file, 0, SEEK_SET) != 0)
    {
        return NULL;
    }

    size_t length = (size_t) end;
    /* malloc(0) may legitimately return NULL; always request at least 1 byte. */
    uint8_t *buffer = malloc(length > 0 ? length : 1);
    if (buffer == NULL)
    {
        return NULL;
    }

    if (length > 0)
    {
        size_t got = fread(buffer, 1, length, input_file);
        if (got != length)
        {
            if (ferror(input_file))
            {
                free(buffer);
                return NULL;
            }
            /* Early EOF: compress only the bytes that were really read
             * instead of the uninitialised remainder of the buffer. */
            length = got;
        }
    }

    *length_out = length;
    return buffer;
}

/*
 * Walk the candidate chain for the triple at `cursor` and report the longest
 * match. Each candidate's value is an earlier input offset whose three bytes
 * equal those at `cursor`. On ties the candidate met first in the chain wins
 * (strict comparison), i.e. the oldest position.
 */
static void find_longest_match(const uint8_t *inp, size_t input_len, size_t cursor,
                               const IntQueueNode *candidate,
                               size_t *best_len, size_t *best_dist)
{
    *best_len = 0;
    *best_dist = 0;

    for (; candidate != NULL; candidate = candidate->next)
    {
        size_t start = candidate->value;
        size_t dist = cursor - start;
        size_t len = 0;

        /* `len % dist` wraps the source index back into [start, cursor) so a
         * match may run longer than its distance (overlapping copy). */
        while (cursor + len < input_len
               && inp[cursor + len] == inp[start + len % dist]
               && len < MATCH_LEN_MAX)
        {
            len += 1;
        }
        assert(len >= MATCH_LEN_MIN);

        if (len > *best_len)
        {
            *best_dist = dist;
            *best_len = len;
        }
    }
}

/* Drop input offset `expired` from the match index; it must be the oldest
 * entry queued under its triple. Empty queues are freed and unmapped. */
static void evict_position(Hashmap *hashMap, const uint8_t *inp, size_t expired)
{
    uint32_t old_triple = pack_triple(inp + expired);
    assert(Hashmap_contains(hashMap, old_triple));
    IntQueue *queue = Hashmap_get(hashMap, old_triple);
    assert(IntQueue_peek(queue) == expired);
    IntQueue_pop(queue);
    if (IntQueue_empty(queue))
    {
        IntQueue_free(queue);
        Hashmap_delete(hashMap, old_triple);
    }
}

/* Record input offset `position` in the match index under its triple,
 * creating the per-triple queue on first use. */
static void index_position(Hashmap *hashMap, const uint8_t *inp, size_t position)
{
    uint32_t new_triple = pack_triple(inp + position);
    if (Hashmap_contains(hashMap, new_triple))
    {
        IntQueue *queue = Hashmap_get(hashMap, new_triple);
        IntQueue_push(queue, position);
    }
    else
    {
        IntQueue *newQueue = IntQueue_init();
        IntQueue_push(newQueue, position);
        Hashmap_update(hashMap, new_triple, newQueue);
    }
}

/*
 * Compress the whole of `input_file` into a single block written through a
 * BitWriter on `output_file`.
 *
 * The input is loaded into memory, then scanned left to right. At each
 * position the next three bytes are looked up in a hash map of earlier
 * positions (kept within DISTANCE_MAX); a hit emits the longest
 * (length, distance) pair found, a miss emits one literal byte. The final
 * one or two bytes are always literals. Neither FILE is closed.
 *
 * Returns the byte count reported by BitWriter_bytes_written() after the
 * final flush. If the input cannot be read or buffered the process is
 * terminated with a message on stderr.
 */
size_t write_lz77_stream(FILE *input_file, FILE *output_file)
{
    BitWriter *bitWriter = BitWriter_init(output_file);
    Hashmap *hashMap = Hashmap_init();

    /* read the input file to a buffer */
    /* TODO: use mmap? */
    size_t input_file_length = 0;
    uint8_t *inp = slurp_input(input_file, &input_file_length);
    if (inp == NULL)
    {
        fputs("write_lz77_stream: failed to read input file\n", stderr);
        abort();
    }

    /* write block headers */
    BitWriter_write_bit(bitWriter, 1); /* BFINAL */
    BitWriter_write_bit(bitWriter, 1); /* BTYPE: bits 1,0 — presumably the
                                          fixed-Huffman type of RFC 1951;
                                          confirm against BitWriter bit order */
    BitWriter_write_bit(bitWriter, 0);

    /* main lz77 loop */
    size_t cursor = 0;
    while (cursor < input_file_length)
    {
        size_t bytes_remaining = input_file_length - cursor;
        if (bytes_remaining < MATCH_LEN_MIN)
        {
            /* too short for a triple: just write literal */
            BitWriter_write_alpha(bitWriter, inp[cursor++]);
            continue;
        }

        size_t old_cursor = cursor;
        uint32_t triple = pack_triple(inp + cursor);

        if (Hashmap_contains(hashMap, triple))
        {
            IntQueue *queue = Hashmap_get(hashMap, triple);

            /* the oldest candidate must still be inside the window */
            assert(cursor - IntQueue_peek(queue) <= DISTANCE_MAX);

            size_t best_len = 0;
            size_t best_dist = 0;
            find_longest_match(inp, input_file_length, cursor,
                               IntQueue_get_head(queue), &best_len, &best_dist);

            /* write the length, distance pair and move on */
            BitWriter_write_length(bitWriter, best_len);
            BitWriter_write_distance(bitWriter, best_dist);
            cursor += best_len;
        }
        else
        {
            BitWriter_write_alpha(bitWriter, inp[cursor++]);
        }

        /* bring the index up to date for every position just consumed */
        for (; old_cursor < cursor; old_cursor++)
        {
            /* clean up hash map elements which no longer matter */
            if (old_cursor >= DISTANCE_MAX)
            {
                evict_position(hashMap, inp, old_cursor - DISTANCE_MAX);
            }

            /* add new triple to hash map; the triple starting exactly three
             * bytes before the end is skipped because no later position has
             * three bytes left to match against it */
            if (input_file_length - old_cursor > 3)
            {
                index_position(hashMap, inp, old_cursor);
            }
        }
    }

    /* end of block */
    BitWriter_write_bin(bitWriter, 0, 7); /* EOB */

    /* cleanup */
    BitWriter_flush(bitWriter);
    size_t compressed_size = BitWriter_bytes_written(bitWriter);
    BitWriter_free(bitWriter);

    Hashmap_free(hashMap, free_queue_cb);

    free(inp);

    return compressed_size;
}