package cs2110;

import java.util.LinkedList;

/**
 * A map implementation backed by a chaining hash table.
 *
 * <p>Each bucket is a {@link LinkedList} of (key, value) entries whose keys hash to that bucket.
 * The table doubles in capacity whenever inserting a new key would push the load factor
 * (size / number of buckets) above {@link #MAX_LOAD_FACTOR}.
 *
 * <p>Keys must be non-null; this precondition is checked with {@code assert} statements, so it is
 * only enforced when assertions are enabled.
 */
public class HashMap<K,V> implements Map<K,V>{

    /**
     * Represents a (key,value) pair in this map. Entries are immutable; overwriting a key's value
     * replaces its entry.
     */
    private record Entry<K,V> (K key, V value) { }

    /**
     * The backing storage of this HashMap. Invariant: every slot holds a non-null (possibly empty)
     * list, and every entry in `buckets[i]` has a key whose `index` is `i`.
     */
    private LinkedList<Entry<K,V>>[] buckets;

    /**
     * The number of entries stored in this HashMap.
     */
    private int size;

    /**
     * The initial number of buckets in the hash table.
     */
    public static final int INITIAL_CAPACITY = 5;

    /**
     * The maximum load factor (size / number of buckets) permissible before resizing.
     */
    public static final double MAX_LOAD_FACTOR = 0.75;

    /**
     * Construct a new, initially empty, hash map.
     */
    public HashMap() {
        buckets = emptyTable(INITIAL_CAPACITY);
        size = 0;
    }

    /**
     * Constructs and returns a chaining hash table with `capacity` buckets, each an empty list.
     */
    @SuppressWarnings("unchecked")
    private LinkedList<Entry<K,V>>[] emptyTable(int capacity) {
        // Generic array creation is not allowed in Java, so a raw array is created and cast. This
        // is safe because the array never escapes this class and only ever holds
        // LinkedList<Entry<K,V>> instances.
        LinkedList<Entry<K,V>>[] table = new LinkedList[capacity];
        for (int i = 0; i < capacity; i++) {
            table[i] = new LinkedList<>();
        }
        return table;
    }

    /**
     * Returns the index of the bucket in `buckets` that `key` belongs to, a value in
     * [0, buckets.length). Requires `key` to be non-null.
     */
    private int index(K key) {
        // `%` keeps the magnitude below buckets.length, so Math.abs cannot overflow here even
        // for negative hash codes.
        return Math.abs(key.hashCode() % buckets.length);
    }

    /**
     * Reassigns `buckets` to an array with double the capacity and re-hashes all entries.
     */
    private void doubleCapacity() {
        LinkedList<Entry<K,V>>[] oldBuckets = buckets;
        buckets = emptyTable(buckets.length * 2);
        for (LinkedList<Entry<K,V>> bucket : oldBuckets) {
            for (Entry<K,V> entry : bucket) {
                // Entries are immutable, so the same objects can be moved into the new table.
                buckets[index(entry.key)].add(entry);
            }
        }
    }

    /**
     * Returns the position of an entry with the given `key` in the `i`th bucket, or returns -1
     * if there is no such entry. Requires `key` to be non-null.
     */
    private int findInBucket(K key, int i) {
        int j = 0;
        for (Entry<K,V> entry : buckets[i]) {
            if (entry.key.equals(key)) {
                return j;
            }
            j++;
        }
        return -1;
    }

    /**
     * Associates `value` with `key`, replacing any value previously associated with `key`.
     * Requires `key` to be non-null.
     */
    @Override
    public void put(K key, V value) {
        assert key != null;

        int i = index(key);
        int j = findInBucket(key, i);
        if (j >= 0) { // already in map, overwrite; size is unchanged so no resize is needed
            buckets[i].set(j, new Entry<>(key, value));
            return;
        }

        // Not in map yet: grow first if adding this entry would exceed the max load factor.
        if ((double) (size + 1) / buckets.length > MAX_LOAD_FACTOR) {
            doubleCapacity();
            i = index(key); // bucket count changed, so the key's bucket may have moved
        }
        buckets[i].add(new Entry<>(key, value));
        size += 1;
    }

    /**
     * Returns whether this map contains an entry for `key`. Requires `key` to be non-null.
     */
    @Override
    public boolean containsKey(K key) {
        assert key != null;
        return findInBucket(key, index(key)) >= 0;
    }

    /**
     * Returns the value associated with `key`. Requires `key` to be non-null and present in this
     * map.
     */
    @Override
    public V get(K key) {
        assert key != null;

        int i = index(key);
        int j = findInBucket(key, i);
        assert j >= 0 : "key not in map";
        return buckets[i].get(j).value;
    }

    /**
     * Returns the number of entries in this map.
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * Removes the entry for `key` and returns its value. Requires `key` to be non-null and present
     * in this map.
     */
    @Override
    public V remove(K key) {
        assert key != null;

        int i = index(key);
        int j = findInBucket(key, i);
        assert j >= 0 : "key not in map";
        // Remove before touching `size`, so a violated precondition (with assertions disabled)
        // makes LinkedList.remove throw without corrupting this map's size.
        V removed = buckets[i].remove(j).value;
        size -= 1;
        return removed;
    }

    /**
     * Returns a new set containing every key in this map.
     */
    @Override
    public Set<K> keySet() {
        Set<K> set = new HashSet<>();
        for (LinkedList<Entry<K,V>> bucket : buckets) {
            for (Entry<K,V> entry : bucket) {
                set.add(entry.key);
            }
        }
        return set;
    }
}
