#include <stdio.h>
#include <assert.h>
#include <string.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdbool.h>

#include <inflate.h>

/* BTYPE values from the 2-bit block-type field of a block header
 * (RFC 1951 section 3.2.3). */
#define FIXED 1
#define DYNAMIC 2

/* Placeholder stored in huffman_t.min_codes for bit lengths that have no
 * codes. Parenthesized so the negative literal expands safely inside any
 * surrounding expression. */
#define NO_CODE (-1)

/*
 * Constants for the FIXED (BTYPE 01) Huffman code
 */

/*
 * Symbols of the fixed literal/length code, grouped by code length and
 * listed in canonical-code order (RFC 1951 section 3.2.6). These are the
 * per-length alphabets referenced by HUFFMAN_FIXED; use that table rather
 * than these arrays directly.
 */

/* 7-bit codes: length symbols 256..279 */
int _FIXED_7[24] = {
    256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267,
    268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279
};

/* 8-bit codes: literals 0..143 first, then length symbols 280..287 */
int _FIXED_8[152] = {
      0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
     12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
     24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
     36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
     48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
     60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
     72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
     84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
     96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
    108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
    120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
    132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
    280, 281, 282, 283, 284, 285, 286, 287
};

/* 9-bit codes: literals 144..255 */
int _FIXED_9[112] = {
    144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
    156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
    168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
    180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
    192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203,
    204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215,
    216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227,
    228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
    240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
    252, 253, 254, 255
};

/*
 * Canonical decoding table for the fixed literal/length code.
 * Only bit lengths 7, 8 and 9 are populated: 24 + 152 + 112 = 288 symbols.
 * min_codes holds the first canonical code of each used length:
 * 0 = 0000000, 48 = 00110000, 400 = 110010000.
 */
huffman_t HUFFMAN_FIXED = {
    /* bit length:  0  1  2  3  4  5  6   7    8    9 10 11 12 13 14 15 */
    .bl_counts = {  0, 0, 0, 0, 0, 0, 0, 24, 152, 112, 0, 0, 0, 0, 0, 0 },

    .alphabet = {
        NULL, NULL, NULL, NULL, NULL, NULL, NULL, /* lengths 0-6   */
        _FIXED_7, _FIXED_8, _FIXED_9,             /* lengths 7-9   */
        NULL, NULL, NULL, NULL, NULL, NULL        /* lengths 10-15 */
    },

    .min_codes = {
        NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE,
        0, 48, 400,
        NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE
    }
};


/* Size of the code-length alphabet (symbols 0..18) used in dynamic headers */
#define N_CL_ALPHABET 19

/* Number of distance symbols that may appear in compressed data (0..29) */
#define N_DISTS 30

/* Symbols for the fixed distance code: code value i decodes to distance
 * symbol i, so this is simply the identity mapping. */
int DIST_ALPHABET[N_DISTS] = {
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
    10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
    20, 21, 22, 23, 24, 25, 26, 27, 28, 29
};

/*
 * Decoding table for the fixed distance code: all 30 distance symbols use
 * 5-bit codes starting at 00000. The 5-bit values 30 and 31 have no entry,
 * so read_chunk rejects them.
 */
huffman_t HUFFMAN_FIXED_DISTS = {
    /* bit length:  0  1  2  3  4   5  6  7  8  9 10 11 12 13 14 15 */
    .bl_counts = {  0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },

    .alphabet = {
        NULL, NULL, NULL, NULL, NULL,             /* lengths 0-4  */
        DIST_ALPHABET,                            /* length 5     */
        NULL, NULL, NULL, NULL, NULL,             /* lengths 6-10 */
        NULL, NULL, NULL, NULL, NULL              /* lengths 11-15 */
    },

    .min_codes = {
        NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE,
        0,
        NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE,
        NO_CODE, NO_CODE, NO_CODE, NO_CODE, NO_CODE
    }
};

/*
 * Length symbols 257..285 (RFC 1951 section 3.2.5).
 * For a decoded symbol s, let k = s - LENGTH_OFFSET. The match length is
 * LEN_TABLE[k] plus an extra value of LEN_ADDITIONAL[k] bits read from the
 * stream.
 */
#define LENGTH_OFFSET 257
int LEN_TABLE[29] = {
      3,   4,   5,   6,   7,   8,   9,  10,
     11,  13,  15,  17,  19,  23,  27,  31,
     35,  43,  51,  59,  67,  83,  99, 115,
    131, 163, 195, 227, 258
};

int LEN_ADDITIONAL[29] = {
    0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1, 2, 2, 2, 2,
    3, 3, 3, 3, 4, 4, 4, 4,
    5, 5, 5, 5, 0
};

/*
 * Distance symbols 0..29, indexed directly by the decoded symbol.
 * The distance is DIST_TABLE[d] plus an extra value of DIST_ADDITIONAL[d]
 * bits read from the stream.
 */
int DIST_TABLE[30] = {
        1,     2,     3,     4,     5,     7,     9,    13,
       17,    25,    33,    49,    65,    97,   129,   193,
      257,   385,   513,   769,  1025,  1537,  2049,  3073,
     4097,  6145,  8193, 12289, 16385, 24577
};
int DIST_ADDITIONAL[30] = {
     0,  0,  0,  0,  1,  1,  2,  2,
     3,  3,  4,  4,  5,  5,  6,  6,
     7,  7,  8,  8,  9,  9, 10, 10,
    11, 11, 12, 12, 13, 13
};

/*
 * Constants for the DYNAMIC (BTYPE 10) Huffman code
 */

/* Number of literal/length symbols that may appear in data (0..285) */
#define N_LITERALS 286

/*
 * A dynamic block header transmits the 3-bit lengths of the code-length
 * alphabet in this permuted order (RFC 1951 section 3.2.7); entry i says
 * which code-length symbol the i-th transmitted length belongs to.
 */
int CL_ORDER[N_CL_ALPHABET] = {
    16, 17, 18,  0,  8,  7,  9,  6, 10,  5,
    11,  4, 12,  3, 13,  2, 14,  1, 15
};

/*
 * Global read cursor: index of the next input bit to consume, counted from
 * bit 0 (the least-significant bit) of byte 0 of the input buffer.
 * get_next_bit() advances it, reset_pos() rewinds it, and read_block()
 * moves it forward to byte boundaries.
 */
int _CUR_BIT = 0;

/* Rewind the bit cursor to the very start of the input buffer. */
void reset_pos()
{
    _CUR_BIT = 0;
}

/*
 * Return the bit at the global cursor _CUR_BIT (0 or 1) and advance the
 * cursor by one. Within each byte, bits are consumed from the
 * least-significant bit upwards, as DEFLATE requires.
 *
 * The byte is read as unsigned char. The previous `char mask = 1 << 7`
 * converted 128 to a (typically signed) char, which is
 * implementation-defined.
 *
 * NOTE(review): there is no bounds check against the buffer length; the
 * caller must not read past the end of the input.
 */
int get_next_bit(char *buf) {
    int byte_idx = _CUR_BIT / 8;
    int bit_idx = _CUR_BIT % 8;

    _CUR_BIT += 1;

    unsigned char byte = (unsigned char) buf[byte_idx];
    return (byte >> bit_idx) & 1;
}

/*
 * Read n bits via get_next_bit() and pack them into an int.
 *
 * reverse == true:  the first bit read is the least-significant bit
 *                   (DEFLATE's order for header fields and extra bits).
 * reverse == false: the first bit read is the most-significant bit.
 *
 * Returns 0 when n <= 0.
 */
int get_n_bits(char *buf, int n, bool reverse) {
    int value = 0;
    int step = 0;

    while (step < n) {
        int bit = get_next_bit(buf);
        int place = reverse ? step : (n - 1 - step);
        value |= bit << place;
        step++;
    }

    return value;
}

/*
 * Decode one symbol using the canonical Huffman table hf.
 *
 * Bits are appended to a growing code MSB-first, one per candidate length.
 * At each length that has codes, the code matches if
 * code - min_codes[len] < bl_counts[len]. Canonical construction ensures a
 * code that did not match at a shorter length is never below min_codes
 * here. The matched symbol is returned.
 *
 * Exits the process if no code of length <= MAX_LENGTH matches.
 */
int read_chunk(char *buf, huffman_t hf) {
    int code = 0;

    /* Length 0 never carries a code, so start at 1. */
    for (int len = 1; len <= MAX_LENGTH; len++) {
        code = (code << 1) | get_next_bit(buf);

        if (hf.alphabet[len] == NULL) {
            continue;           /* no codes of this length */
        }

        int offset = code - hf.min_codes[len];
        if (offset < hf.bl_counts[len]) {
            return hf.alphabet[len][offset];
        }
    }

    fprintf(stderr, "error: could not match length\n");
    exit(1);
}

/*
 * Build a canonical Huffman decoding table from per-symbol code lengths
 * (RFC 1951 section 3.2.2).
 *
 * lens[sym] is the code length of symbol sym for 0 <= sym < n_symbols;
 * 0 means the symbol is unused. The returned table is heap-allocated and
 * owned by the caller, who must release it with destroy_huffman().
 *
 * Exits the process on allocation failure or if a length lies outside
 * 0..MAX_LENGTH. An out-of-range length would otherwise index bl_counts
 * out of bounds.
 */
huffman_t *make_huffman(int *lens, int n_symbols) {
    huffman_t *hf = malloc(sizeof *hf);
    if (hf == NULL) {
        fprintf(stderr, "error: memory error\n");
        exit(1);
    }

    for (int len = 0; len <= MAX_LENGTH; len++) {
        hf->bl_counts[len] = 0;
        hf->alphabet[len] = NULL;
        hf->min_codes[len] = NO_CODE;
    }

    /* Count how many symbols use each code length. */
    for (int sym = 0; sym < n_symbols; sym++) {
        int len = lens[sym];
        if (len < 0 || len > MAX_LENGTH) {
            fprintf(stderr, "error: invalid code length\n");
            exit(1);
        }
        if (len != 0) {
            hf->bl_counts[len] += 1;
        }
    }

    /*
     * Smallest code of each length, computed as in the RFC:
     * code = (code + bl_count[len - 1]) << 1, with bl_counts[0] == 0.
     */
    int code = 0;
    for (int len = 1; len <= MAX_LENGTH; len++) {
        code = (code + hf->bl_counts[len - 1]) << 1;
        if (hf->bl_counts[len] != 0) {
            hf->min_codes[len] = code;
        }
    }

    /*
     * Canonical codes of one length are consecutive in symbol order, so the
     * symbols of each length, in increasing order, form that length's alphabet.
     */
    for (int len = 1; len <= MAX_LENGTH; len++) {
        int count = hf->bl_counts[len];
        if (count == 0) {
            continue;
        }

        int *symbols = malloc((size_t) count * sizeof *symbols);
        if (symbols == NULL) {
            fprintf(stderr, "error: memory error\n");
            exit(1);
        }

        int cur = 0;
        for (int sym = 0; sym < n_symbols; sym++) {
            if (lens[sym] == len) {
                symbols[cur++] = sym;
            }
        }
        hf->alphabet[len] = symbols;
    }

    return hf;
}

/*
 * Free a table returned by make_huffman(), including its per-length
 * alphabets. Passing NULL is a no-op.
 * Must not be called on the static HUFFMAN_FIXED / HUFFMAN_FIXED_DISTS tables.
 */
void destroy_huffman(huffman_t *hf) {
    if (hf == NULL) {
        return;
    }
    for (int len = 0; len <= MAX_LENGTH; len++) {
        free(hf->alphabet[len]);    /* free(NULL) is a no-op */
    }
    free(hf);
}

/*
 * Decode num_codes code lengths from a dynamic block header using the
 * code-length table hf_codes (RFC 1951 section 3.2.7).
 *   0..15 : literal code length
 *   16    : repeat previous length 3..6 times   (2 extra bits)
 *   17    : repeat length 0 for 3..10 times     (3 extra bits)
 *   18    : repeat length 0 for 11..138 times   (7 extra bits)
 *
 * Returns a heap array of num_symbols ints (caller frees). Entries at or
 * after num_codes are 0.
 *
 * Exits the process on malformed input. This covers num_codes > num_symbols,
 * a code 16 with no previous length, and a repeat that runs past num_codes.
 * Each of these previously read or wrote outside the array.
 */
int *read_lens(char *buf, huffman_t *hf_codes, int num_symbols, int num_codes) {
    if (num_codes > num_symbols) {
        fprintf(stderr, "error: too many code lengths\n");
        exit(1);
    }

    int *lens = calloc((size_t) num_symbols, sizeof *lens);
    if (lens == NULL && num_symbols > 0) {
        fprintf(stderr, "error: memory error\n");
        exit(1);
    }

    int cur_lit = 0;
    while (cur_lit < num_codes) {
        int sym = read_chunk(buf, *hf_codes);
        int value;
        int rep;

        if (sym <= 15) {
            value = sym;
            rep = 1;
        } else if (sym == 16) {
            if (cur_lit == 0) {
                fprintf(stderr, "error: repeat with no previous length\n");
                exit(1);
            }
            value = lens[cur_lit - 1];
            rep = get_n_bits(buf, 2, true) + 3;
        } else if (sym == 17) {
            value = 0;
            rep = get_n_bits(buf, 3, true) + 3;
        } else if (sym == 18) {
            value = 0;
            rep = get_n_bits(buf, 7, true) + 11;
        } else {
            fprintf(stderr, "error: invalid length\n");
            exit(1);
        }

        if (rep > num_codes - cur_lit) {
            fprintf(stderr, "error: length repeat overruns code lengths\n");
            exit(1);
        }
        for (int j = 0; j < rep; j++) {
            lens[cur_lit++] = value;
        }
    }
    return lens;
}

/*
 * Copy one stored (BTYPE 00) block to `out` (RFC 1951 section 3.2.4).
 * Skips to the next byte boundary, then reads LEN and NLEN (16 bits each,
 * little-endian; NLEN must be LEN's one's complement) and copies LEN raw
 * bytes. The cursor is byte-aligned on return.
 * NOTE(review): the input size is not known here, so a corrupt LEN can read
 * past the end of `buf`.
 */
static void read_stored_block(char *buf, FILE *out) {
    while (_CUR_BIT % 8 != 0) {
        _CUR_BIT += 1;
    }

    int len = get_n_bits(buf, 16, true);
    int nlen = get_n_bits(buf, 16, true);
    if ((len ^ 0xFFFF) != nlen) {
        fprintf(stderr, "error: corrupt stored block length\n");
        exit(1);
    }

    for (int i = 0; i < len; i++) {
        if (fputc(get_n_bits(buf, 8, true), out) == EOF) {
            fprintf(stderr, "error: could not write output\n");
            exit(1);
        }
    }
}

/*
 * Read a dynamic (BTYPE 10) block header and build its literal/length and
 * distance tables (RFC 1951 section 3.2.7). The two sets of code lengths
 * are decoded as ONE sequence, because repeat codes 16-18 may cross from
 * the literal/length lengths into the distance lengths. The caller owns
 * both tables and frees them with destroy_huffman().
 */
static void read_dynamic_tables(char *buf, huffman_t **hf_lit, huffman_t **hf_dist) {
    int n_lit = get_n_bits(buf, 5, true) + 257;   /* HLIT + 257  */
    int n_dist = get_n_bits(buf, 5, true) + 1;    /* HDIST + 1   */
    int n_cl = get_n_bits(buf, 4, true) + 4;      /* HCLEN + 4   */

    /* Code lengths for the code-length alphabet, sent in CL_ORDER order */
    int lens_cl[N_CL_ALPHABET] = {0};
    for (int i = 0; i < n_cl; i++) {
        lens_cl[CL_ORDER[i]] = get_n_bits(buf, 3, true);
    }
    huffman_t *hf_codes = make_huffman(lens_cl, N_CL_ALPHABET);

    int n_total = n_lit + n_dist;
    int *lens = read_lens(buf, hf_codes, n_total, n_total);
    destroy_huffman(hf_codes);

    *hf_lit = make_huffman(lens, n_lit);
    *hf_dist = make_huffman(lens + n_lit, n_dist);
    free(lens);
}

/*
 * Append `length` bytes to `out`, each copied from `dist` bytes before the
 * current end of the output. Copying one byte at a time lets overlapping
 * matches (dist < length) repeat bytes written earlier in the same copy.
 * Exits if `dist` reaches before the start of the output.
 */
static void copy_match(FILE *out, int length, int dist) {
    for (int i = 0; i < length; i++) {
        if (fseek(out, -(long) dist, SEEK_END) != 0) {
            fprintf(stderr, "error: invalid distance\n");
            exit(1);
        }
        int val = fgetc(out);
        if (val == EOF || fseek(out, 0, SEEK_END) != 0 || fputc(val, out) == EOF) {
            fprintf(stderr, "error: could not copy match\n");
            exit(1);
        }
    }
}

/*
 * Decode one DEFLATE block starting at the global bit cursor and append the
 * decompressed bytes to `out`. `out` must be opened for update ("wb+"),
 * because back-references are resolved by reading earlier output back.
 *
 * Blocks are bit-packed back to back. Only the end of the final block
 * (BFINAL set) is padded to a byte boundary. Aligning after every block, as
 * this function used to do, corrupts multi-block streams.
 */
void read_block(char *buf, FILE *out) {
    /* Block header: BFINAL (1 bit), then BTYPE (2 bits, LSB first) */
    bool bfinal = get_next_bit(buf);
    int btype = get_n_bits(buf, 2, true);

    if (btype == 0) {
        /* BTYPE 00: stored block; it leaves the cursor byte-aligned. */
        read_stored_block(buf, out);
        return;
    }

    /* By default, use the fixed tables */
    huffman_t *hf = &HUFFMAN_FIXED;
    huffman_t *hf_dist = &HUFFMAN_FIXED_DISTS;

    if (btype == DYNAMIC) {
        read_dynamic_tables(buf, &hf, &hf_dist);
    } else if (btype != FIXED) {
        fprintf(stderr, "error: unrecognized btype\n");
        exit(1);
    }

    /* Decode symbols until end-of-block (256) */
    int chunk_val;
    while ((chunk_val = read_chunk(buf, *hf)) != 256) {
        if (chunk_val < 256) {
            /* Literal byte; fputc avoids writing the first byte of an int,
             * which is wrong on big-endian hosts. */
            if (fputc(chunk_val, out) == EOF) {
                fprintf(stderr, "error: could not write output\n");
                exit(1);
            }
            continue;
        }

        /* Length symbol: only 257..285 are valid (286/287 never occur) */
        if (chunk_val >= N_LITERALS) {
            fprintf(stderr, "error: invalid length code\n");
            exit(1);
        }
        int len_idx = chunk_val - LENGTH_OFFSET;
        int length = LEN_TABLE[len_idx] + get_n_bits(buf, LEN_ADDITIONAL[len_idx], true);

        /* Distance symbol: only 0..29 are valid */
        int dist_code = read_chunk(buf, *hf_dist);
        if (dist_code >= N_DISTS) {
            fprintf(stderr, "error: invalid distance code\n");
            exit(1);
        }
        int dist = DIST_TABLE[dist_code] + get_n_bits(buf, DIST_ADDITIONAL[dist_code], true);

        copy_match(out, length, dist);
    }

    /* Only the end of the stream is padded to a byte boundary */
    if (bfinal) {
        while (_CUR_BIT % 8 != 0) {
            _CUR_BIT += 1;
        }
    }

    if (btype == DYNAMIC) {
        destroy_huffman(hf);
        destroy_huffman(hf_dist);
    }
}

/*
 * Strip a trailing ".deflate" suffix from fname in place, so fname becomes
 * the output path. Exits the process unless the text after the LAST '.' is
 * exactly ".deflate". The old check compared only 8 characters, so
 * "x.deflatefoo" and "dir.deflate/file" were wrongly accepted and truncated.
 */
void truncate_suffix(char *fname) {
    char *dot = strrchr(fname, '.');

    if (dot == NULL || strcmp(dot, ".deflate") != 0) {
        fprintf(stderr, "error: must be a .deflate file\n");
        exit(1);
    }
    *dot = '\0';
}

/*
 * Decompress the raw DEFLATE stream in `fp` into a file named after fname
 * with its ".deflate" suffix removed (fname is modified in place).
 *
 * The whole input is read into memory first. Blocks are then decoded until
 * the bit cursor reaches the end of the input. The cursor is reset at the
 * start so that repeated calls do not continue from a previous stream.
 * Exits the process on any I/O or allocation error.
 */
void inflate(FILE *fp, char *fname) {
    truncate_suffix(fname);

    /* Read the file into a buffer */
    long size = -1;
    if (fseek(fp, 0, SEEK_END) == 0) {
        size = ftell(fp);
    }
    if (size < 0) {
        fprintf(stderr, "error: could not read input\n");
        exit(1);
    }
    rewind(fp);

    /* malloc(0) may legally return NULL, so always request at least 1 byte */
    char *buf = malloc(size > 0 ? (size_t) size : 1);
    if (buf == NULL) {
        fprintf(stderr, "error: memory error\n");
        exit(1);
    }
    if (fread(buf, 1, (size_t) size, fp) != (size_t) size) {
        fprintf(stderr, "error: could not read input\n");
        exit(1);
    }

    /* Create the output file; update mode so back-references can re-read it */
    FILE *out = fopen(fname, "wb+");
    if (out == NULL) {
        fprintf(stderr, "error: could not open output file\n");
        exit(1);
    }

    reset_pos();
    while (_CUR_BIT < 8 * size) {
        read_block(buf, out);
    }

    /* fclose flushes buffered output, so its failure is a write failure */
    if (fclose(out) != 0) {
        fprintf(stderr, "error: could not write output file\n");
        exit(1);
    }
    free(buf);
}