package cs2110;

import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.BitSet;
import java.util.HashMap;
import java.util.PriorityQueue;

/**
 * Huffman compression of text files. `encode` builds a prefix tree from the character
 * frequencies of a source file and writes the per-character encoding table plus the packed bits
 * to an output file; `decode` reverses that process.
 *
 * Encoded file layout (written with an `ObjectOutputStream`): the encoding table
 * (`HashMap<Character,String>`), the packed bits (`BitSet`), and finally an `int` holding the
 * exact number of encoded bits. The trailing `int` is needed because a `BitSet` does not remember
 * `0` bits that follow its last `1` bit. Files that lack the trailing `int` are still readable.
 */
public class HuffmanCoding {

    /**
     * The elements stored in our priority queue, which consist of the prefix subtree, `symbolTree`,
     * and the total `count` of all characters appearing in its leaves. The priorities of these
     * elements are their counts, with ties broken alphabetically by the root label of the subtree
     * (to make this procedure deterministic).
     */
    record SymbolCount(ImmutableBinaryTree<String> symbolTree, int count) implements Comparable<SymbolCount> {

        /**
         * Orders by ascending `count`, then by the root label of `symbolTree`.
         */
        @Override
        public int compareTo(SymbolCount o) {
            // `Integer.compare` rather than subtraction, which can overflow for large counts.
            int byCount = Integer.compare(count, o.count);
            if (byCount != 0) {
                return byCount;
            }
            return symbolTree.root.compareTo(o.symbolTree.root); // break ties alphabetically
        }
        // Note: Implementing `Comparable` establishes the "natural priority order" that will be
        // used when inserting these `SymbolCount` objects into a PriorityQueue.
    }

    /**
     * Returns a HashMap associating each character that appears in the given `input` String with
     * the number of times that it occurs. Returns an empty map when `input` is empty.
     */
    static HashMap<Character,Integer> getFrequencyMap(String input) {
        HashMap<Character, Integer> frequencies = new HashMap<>();
        for (char c : input.toCharArray()) {
            frequencies.merge(c, 1, Integer::sum);
        }
        return frequencies;
    }

    /**
     * Initializes the priority queue used by the Huffman encoding procedure. At initialization,
     * the priority queue contains one leaf node (i.e., tree with `null` left and right subtrees)
     * for each character that occurs in the source text, prioritized by their frequency.
     */
    static PriorityQueue<SymbolCount> initializePQueue(HashMap<Character,Integer> frequencies) {
        PriorityQueue<SymbolCount> pQueue = new PriorityQueue<>();
        for (HashMap.Entry<Character, Integer> entry : frequencies.entrySet()) {
            char symbol = entry.getKey();
            pQueue.add(new SymbolCount(
                    new ImmutableBinaryTree<>(Character.toString(symbol), null, null),
                    entry.getValue()));
        }
        return pQueue;
    }

    /**
     * Carries out the Huffman encoding procedure to build and return the prefix tree associated
     * with this text. Each node in the tree is labeled with the characters found in all leaves of
     * its subtree, in order from left to right. The given `pQueue` is emptied by this call.
     * Requires that `pQueue` is not empty; an empty queue causes `remove()` to throw a
     * `NoSuchElementException`.
     */
    static ImmutableBinaryTree<String> buildPrefixTree(PriorityQueue<SymbolCount> pQueue) {
        // Repeatedly merge the two lowest-priority subtrees until only one remains.
        while (pQueue.size() > 1) {
            SymbolCount s1 = pQueue.remove();
            SymbolCount s2 = pQueue.remove();
            // The second-removed subtree becomes the left child, the first-removed the right.
            pQueue.add(new SymbolCount(
                    new ImmutableBinaryTree<>(s2.symbolTree.root + s1.symbolTree.root, s2.symbolTree, s1.symbolTree),
                    s1.count + s2.count));
        }
        return pQueue.remove().symbolTree; // the sole remaining element is the complete prefix tree
    }

    /**
     * Returns the binary string encoding of the given character `c` from the given `prefixTree`:
     * a `0` for every step into a left subtree and a `1` for every step into a right subtree on
     * the path from the root to the leaf labeled `c`. If `prefixTree` is itself the leaf for `c`,
     * the result is the empty string.
     * Requires that `c` is present in one of the leaves of the `prefixTree`.
     */
    static String getBinaryEncoding(char c, ImmutableBinaryTree<String> prefixTree) {
        assert prefixTree.root.indexOf(c) >= 0 : c + " cannot be encoded using this prefix tree";

        String target = Character.toString(c);
        BinaryTree<String> current = prefixTree; // our current position in the tree traversal
        StringBuilder sb = new StringBuilder(); // will hold the binary encoding of this character
        while (!current.root.equals(target)) {
            // A node's label lists the characters in its leaves, so `indexOf` tells us which
            // side contains `c`.
            if (current.left() != null && current.left().root.indexOf(c) >= 0) { // c in left subtree
                current = current.left();
                sb.append('0');
            } else { // c in right subtree
                current = current.right();
                sb.append('1');
            }
        }
        return sb.toString();
    }

    /**
     * Uses Huffman encoding to compress the contents of the given `sourceFile`, writing the result
     * to the file at location `encodedFile`.
     *
     * An empty source file produces an empty encoding table and zero bits. A source file with
     * only one distinct character assigns that character the one-bit code `0`, so that every
     * occurrence still contributes a bit to the output.
     *
     * Throws an `IOException` if the source cannot be read or the encoded file cannot be written.
     */
    public static void encode(String sourceFile, String encodedFile) throws IOException {
        // Step 1: Read the contents of the sourceFile.
        String input = Files.readString(Paths.get(sourceFile));

        // Steps 2-3: Build the prefix tree from the character frequencies and derive each
        // character's binary encoding. Skipped entirely when there are no characters, since a
        // prefix tree cannot be built from an empty queue.
        HashMap<Character,Integer> frequencyMap = getFrequencyMap(input);
        HashMap<Character,String> encodingMap = new HashMap<>();
        if (!frequencyMap.isEmpty()) {
            ImmutableBinaryTree<String> prefixTree = buildPrefixTree(initializePQueue(frequencyMap));
            for (Character c : frequencyMap.keySet()) {
                String code = getBinaryEncoding(c, prefixTree);
                // A single-leaf tree yields an empty code; use "0" so the text is not lost.
                encodingMap.put(c, code.isEmpty() ? "0" : code);
            }
        }

        // Step 4: Pack the encoded text directly into a BitSet, counting the bits as we go.
        BitSet bits = new BitSet();
        int bitCount = 0;
        for (char c : input.toCharArray()) {
            String code = encodingMap.get(c);
            for (int i = 0; i < code.length(); i++) {
                if (code.charAt(i) == '1') {
                    bits.set(bitCount);
                }
                bitCount++;
            }
        }

        // Step 5: Write the encoding table, the packed bits, and the exact bit count.
        // The count goes last so that readers expecting only the two objects are unaffected.
        try (FileOutputStream fileOutputStream = new FileOutputStream(encodedFile);
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream)) {
            objectOutputStream.writeObject(encodingMap);
            objectOutputStream.writeObject(bits);
            objectOutputStream.writeInt(bitCount);
        }
    }

    /**
     * Decodes the given `encodedFile` by reversing the Huffman encoding process, writing the
     * result to the file at location `decodedFile`.
     *
     * Reads the encoding table and packed bits, then the trailing bit count if present. When the
     * count is absent, falls back to `BitSet.length()`, which cannot see `0` bits after the
     * last `1` bit, so such files may decode with trailing characters missing.
     *
     * NOTE(review): this uses Java-native deserialization; only decode files from a trusted source.
     *
     * Throws an `IOException` if a file cannot be read or written or the bit count is inconsistent
     * with the stored bits, and a `ClassNotFoundException` if deserialization cannot resolve a class.
     */
    @SuppressWarnings("unchecked")
    public static void decode(String encodedFile, String decodedFile) throws Exception {

        // Step 1: Read the contents of the encodedFile and reconstruct the encoding map and bits.
        HashMap<Character,String> encodingMap;
        BitSet bits;
        int bitCount;
        try (FileInputStream fileInputStream = new FileInputStream(encodedFile);
                ObjectInputStream objectInputStream = new ObjectInputStream(fileInputStream)) {
            encodingMap = (HashMap<Character,String>) objectInputStream.readObject();
            bits = (BitSet) objectInputStream.readObject();
            try {
                bitCount = objectInputStream.readInt();
            } catch (EOFException noStoredCount) {
                // File without a trailing bit count: best effort using the BitSet's own length.
                bitCount = bits.length();
            }
        }
        if (bitCount < bits.length()) {
            throw new IOException("Corrupt encoded file: bit count " + bitCount
                    + " is smaller than the number of stored bits " + bits.length());
        }

        // Step 2: "Flip" the encodingMap to obtain a decodingMap
        HashMap<String, Character> decodingMap = new HashMap<>();
        for (HashMap.Entry<Character, String> entry : encodingMap.entrySet()) {
            decodingMap.put(entry.getValue(), entry.getKey());
        }

        // Step 3: Scan over the bits, emitting a character each time the pending bits form a code.
        StringBuilder decoded = new StringBuilder(); // holds decoded file contents
        StringBuilder pending = new StringBuilder(); // bits of the character currently being decoded
        for (int i = 0; i < bitCount; i++) {
            pending.append(bits.get(i) ? '1' : '0');
            Character symbol = decodingMap.get(pending.toString());
            if (symbol != null) {
                decoded.append(symbol.charValue());
                pending.setLength(0);
            }
        }

        // Step 4: Write the decoded file contents to decodedFile.
        Files.writeString(Paths.get(decodedFile), decoded.toString());
    }

    /**
     * Returns the output file name used by `main` for the given `sourceFile`: the source name with
     * its extension (if any) replaced by `_encoded.txt`. A `.` that appears only in a directory
     * component is not treated as an extension separator.
     */
    private static String encodedFileNameFor(String sourceFile) {
        int lastSeparator = Math.max(sourceFile.lastIndexOf('/'), sourceFile.lastIndexOf('\\'));
        int lastDot = sourceFile.lastIndexOf('.');
        String stem = lastDot > lastSeparator ? sourceFile.substring(0, lastDot) : sourceFile;
        return stem + "_encoded.txt";
    }

    /**
     * Compresses the file named by the first program argument (default `verne.txt`) and prints
     * the original and compressed file sizes.
     */
    public static void main(String[] args) throws Exception {
        String sourceFile = (args.length == 0 ? "verne.txt" : args[0]);
        System.out.println("Compressing file \"" + sourceFile + "\"\n");
        System.out.println("Original file size: " + new File(sourceFile).length() + " bytes");
        String encodedFile = encodedFileNameFor(sourceFile);
        encode(sourceFile, encodedFile);
        System.out.println("Compressed file size: " + new File(encodedFile).length() + " bytes");
    }
}
