#include "mempool.h"

#include <cassert.h>

#include <stdlib.h>
#include <string.h>

/// Build the initial free list for a pool whose blocks are all free.
///
/// Each block's next_free is set to the index of the block that follows it.
/// The final block's next_free is set to 0.
static void init_free_list(mempool* pool) {
  assert(pool);
  const size_t last = pool->num_blocks - 1;
  size_t       idx  = 0;
  while (idx < last) {
    pool->block_info[idx].next_free = idx + 1;
    ++idx;
  }
  pool->block_info[last].next_free = 0;
}

/// Initialize a memory pool of `num_blocks` blocks of `block_size_bytes` each.
///
/// `block_info` and `blocks` must either both be given or both be null:
///  - Both given: the pool uses the caller's arrays, which are zero-filled
///    here. The pool is marked non-dynamic.
///  - Both null: the arrays are allocated (zeroed) with calloc and the pool is
///    marked dynamic.
///
/// Traps are enabled on the new pool.
///
/// Returns true on success. Returns false if a dynamic allocation fails; in
/// that case nothing is leaked and the pool's block_info/blocks pointers are
/// null.
bool mempool_make_(
    mempool* pool, BlockInfo* block_info, void* blocks, size_t num_blocks,
    size_t block_size_bytes) {
  assert(pool);
  assert((block_info && blocks) || (!block_info && !blocks));
  assert(num_blocks >= 1);

  pool->block_size_bytes = block_size_bytes;
  pool->num_blocks       = num_blocks;
  pool->num_used_blocks  = 0;
  pool->head             = 0;
  pool->used             = 0;
  pool->trap             = true;

  // Initialize blocks and block info.
  if (!block_info) {
    block_info    = calloc(num_blocks, sizeof(BlockInfo));
    blocks        = calloc(num_blocks, block_size_bytes);
    pool->dynamic = true;
    if ((block_info == NULL) || (blocks == NULL)) {
      // One of the two allocations may have succeeded; release it so it is not
      // leaked. free(NULL) is a no-op, so no guards are needed.
      free(block_info);
      free(blocks);
      // Leave the pool with null pointers so that tearing down a dynamic pool
      // after a failed make never frees indeterminate pointers.
      pool->block_info = NULL;
      pool->blocks     = NULL;
      return false;
    }
  } else {
    memset(blocks, 0, num_blocks * block_size_bytes);
    memset(block_info, 0, num_blocks * sizeof(BlockInfo));
    pool->dynamic = false;
  }
  pool->block_info = block_info;
  pool->blocks     = blocks;

  init_free_list(pool);

  return true;
}

/// Release the storage owned by the pool.
///
/// Only pools flagged as dynamic free their block_info and blocks arrays; both
/// pointers are nulled afterwards. Non-dynamic pools are left untouched.
void mempool_del_(mempool* pool) {
  assert(pool);
  if (!pool->dynamic) {
    return;
  }
  // free() accepts a null pointer, so no per-pointer checks are required.
  free(pool->block_info);
  pool->block_info = 0;
  free(pool->blocks);
  pool->blocks = 0;
}

/// Reset the pool so that every block is free again.
///
/// The free-list head, the used-list head and the used-block count are reset
/// to 0, the block storage and block info are zero-filled, and the free list
/// is rebuilt.
void mempool_clear_(mempool* pool) {
  assert(pool);
  pool->head = 0;
  pool->used = 0;
  // Every block is free after a clear, so the used-block count must go back
  // to zero along with the lists.
  pool->num_used_blocks = 0;
  memset(pool->blocks, 0, pool->num_blocks * pool->block_size_bytes);
  memset(pool->block_info, 0, pool->num_blocks * sizeof(BlockInfo));
  init_free_list(pool);
}

/// Allocate one block from the pool.
///
/// Takes the block at the head of the free list, marks it used, pushes it onto
/// the front of the used list and returns a pointer to its storage.
///
/// If the block at the head of the free list is already in use the pool is
/// full: FAIL() is invoked when traps are enabled, and 0 is returned.
void* mempool_alloc_(mempool* pool) {
  assert(pool);

  const size_t index = pool->head;
  BlockInfo*   info  = &pool->block_info[index];

  if (!info->used) {
    void* block = pool->blocks + index * pool->block_size_bytes;

    // Move the block from the free list to the front of the used list.
    info->used      = true;
    info->next_used = pool->used;
    pool->used      = index;
    pool->head      = info->next_free;
    info->next_free = 0;

    pool->num_used_blocks += 1;

    return block;
  }

  // Pool is full.
  if (pool->trap) {
    FAIL("mempool allocation failed, increase the pool's capacity.");
  }
  return 0;
}

/// Return a block to the pool and null out the caller's pointer.
///
/// `block_ptr` points to a pointer previously obtained from this pool. The
/// block's storage is zero-filled, the block is pushed onto the front of the
/// free list, the used-block count is decremented and `*block_ptr` is set
/// to 0.
///
/// Freeing a block that is not marked used trips an assertion.
void mempool_free_(mempool* pool, void** block_ptr) {
  assert(pool);
  assert(block_ptr);

  // Recover the block's index from its byte offset into the block storage.
  const size_t byte_offset = (size_t)((uint8_t*)*block_ptr - pool->blocks);
  const size_t index       = byte_offset / pool->block_size_bytes;
  assert(index < pool->num_blocks);
  BlockInfo* entry = &pool->block_info[index];

  // Disallow double-frees.
  assert(entry->used);

  // Scrub the block so that the next allocation does not observe stale data.
  memset(*block_ptr, 0, pool->block_size_bytes);

  // Mark the block free and push it onto the front of the free list.
  entry->used      = false;
  entry->next_used = 0;
  entry->next_free = pool->head;
  pool->head       = index;
  // NOTE(review): only the used-list head is reset here; other blocks'
  // next_used links are not updated -- confirm how the used list is consumed.
  if (pool->used == index) {
    pool->used = 0;
  }

  pool->num_used_blocks -= 1;

  *block_ptr = 0;
}

/// Return a pointer to the storage of the block at `block_index`.
///
/// The index must be within the pool's capacity and the block must currently
/// be in use; both conditions are asserted.
void* mempool_get_block_(const mempool* pool, size_t block_index) {
  assert(pool);
  assert(block_index < pool->num_blocks);
  assert(pool->block_info[block_index].used);
  const size_t byte_offset = block_index * pool->block_size_bytes;
  return &pool->blocks[byte_offset];
}

/// Return the index of the block that `block` points to.
///
/// The index is the block's byte offset from the start of the pool's storage
/// divided by the block size. The offset is asserted to be a multiple of the
/// block size.
size_t mempool_get_block_index_(const mempool* pool, const void* block) {
  assert(pool);
  const uint8_t* block_bytes = (const uint8_t*)block;
  const size_t   offset      = block_bytes - pool->blocks;
  assert(offset % pool->block_size_bytes == 0);
  return offset / pool->block_size_bytes;
}

/// Return the size, in bytes, of a single block of the pool.
size_t mempool_block_size_bytes_(const mempool* pool) {
  assert(pool);
  const size_t block_size = pool->block_size_bytes;
  return block_size;
}

/// Return the total number of blocks in the pool.
size_t mempool_capacity_(const mempool* pool) {
  assert(pool);
  const size_t capacity = pool->num_blocks;
  return capacity;
}

/// Return the number of blocks currently counted as in use.
size_t mempool_size_(const mempool* pool) {
  assert(pool);
  const size_t used_count = pool->num_used_blocks;
  return used_count;
}

/// Turn the pool's trap flag on or off.
///
/// The flag is read by the allocator: when it is set, a failed allocation
/// invokes FAIL() before returning 0.
void mempool_enable_traps_(mempool* pool, bool enable) {
  assert(pool);

  pool->trap = enable;
}