/*
 * Copyright (c) Paul Stankovski
 * Free for all non-commercial use unless this directive conflicts with
 * other applicable copyright statement(s), patent holders, laws or such.
 */
#include "memory_tracker.h"
#include <stdlib.h> /* malloc, free */
#include <string.h> /* memset, memcpy */
#include <stdio.h> /* printf */
#include <pthread.h> /* mutexes */
#include "assert_utils.h"
#include "platform_types.h"
#include "memxor.h"
#include "memrnd.h"

/*******************************************************************************
 * Definitions, types and variables
 ******************************************************************************/
/* Number of entries in the allocation table (Info). When every entry is in
   use, memoryTrackerMalloc falls back to plain, untracked malloc. */
#define MAX_SIMULTANEOUS_ALLOCATIONS 1000

static int initialized = 0; /* set by initialize(), cleared again by shutDownMemoryTracker() */
static double totalNumAllocations = 0; /* memoryTrackerMalloc calls since the last initialize() (double presumably to avoid int overflow -- TODO confirm) */
static int maxNumSimultaneousAllocations = 0; /* high-water mark of numPendingDeallocations */
static int numPendingDeallocations = 0; /* allocations not yet freed, tracked and untracked alike */
static int slotSpaceExhausted = 0; /* set once an allocation could not obtain a table slot */
/* One entry of the allocation table; an entry whose p is NULL is a free slot. */
typedef struct {
  int bytes;        /* size requested by the caller (boundary markers excluded) */
  void *p;          /* pointer handed to the caller, i.e. just past the pre marker */
  const char *file; /* allocation call site, as passed to memoryTrackerMalloc */
  int line;
} AllocInfo;
static AllocInfo Info[MAX_SIMULTANEOUS_ALLOCATIONS];

/* Forward declaration: the default handler is defined among the report functions. */
static void defaultGarbagingCallback(
  int bytes,
  const char *file,
  int line,
  int field,
  const char *fromFile,
  int fromLine);
/* Handler invoked when a boundary marker is found corrupted. NULL disables the
   checks (shutDownMemoryTracker temporarily reinstates the default handler). */
static memoryGarbagingWarningCallback garbaging = defaultGarbagingCallback;

/* Marker written immediately before and immediately after every tracked block. */
static const UINT64 abba = U64C(0xABBAACDCABBAACDC); /* boundary marker */

/* mutexes */
static pthread_mutex_t mallocMutex = PTHREAD_MUTEX_INITIALIZER; /* guards the allocation table: malloc, free, garbaging scans and the memory report */
static pthread_mutex_t memoryWriteMutex = PTHREAD_MUTEX_INITIALIZER; /* serializes the memory writes of the thread-safe (T) variants */

/*******************************************************************************
 * Utilities
 ******************************************************************************/
static void clearAllocInfoSlot(int i) {
  ASSERT(i >= 0 && i < MAX_SIMULTANEOUS_ALLOCATIONS, "Unexpected index range!\n");
  Info[i].bytes = 0;
  Info[i].p = NULL;
  Info[i].file = NULL;
  Info[i].line = 0;
}

/* Marks every entry of the allocation table as free. */
static void initAllocInfo(void)
{
  int slot = 0;

  while (slot < MAX_SIMULTANEOUS_ALLOCATIONS) {
    clearAllocInfoSlot(slot);
    slot++;
  }
}

/* Returns the initialization flag: nonzero after initialize(), zero after shutDownMemoryTracker(). */
static int isInitialized(void)
{
  return initialized;
}

/* (Re)sets the tracker: zeroes the statistics, empties the allocation table
   and raises the initialized flag. The registered garbaging callback is left untouched. */
static void initialize(void)
{
  totalNumAllocations = 0;
  maxNumSimultaneousAllocations = 0;
  numPendingDeallocations = 0;
  slotSpaceExhausted = 0;

  initAllocInfo();
  initialized = 1;
}

/* Writes the boundary marker into the first sizeof(abba) bytes of the raw
   block (the pre field). memcpy avoids type-punning the buffer through a
   UINT64 pointer, so the store is valid regardless of alignment or aliasing. */
static void setPreGarbagingField(void *block) {
  memcpy(block, &abba, sizeof(abba));
}

/* Writes the boundary marker right after the size user bytes of the raw block
   (layout: [pre marker][size bytes][post marker]).
   That address is misaligned whenever size is not a multiple of 8, so a direct
   UINT64 store would be undefined behavior (and faults on strict-alignment
   CPUs); memcpy stores the bytes without any alignment requirement. */
static void setPostGarbagingField(void *block, int size) {
  BYTE *p = (BYTE*)block;
  memcpy(p + sizeof(abba) + size, &abba, sizeof(abba));
}

static int preFieldGarbaging(void *block) {
  BYTE *p = (BYTE*)block;
  return *((UINT64*)(p - sizeof(abba))) != abba;
}

static int postFieldGarbaging(void *block, int size) {
  BYTE *p = (BYTE*)block;
  return *((UINT64*)(p + size)) != abba;
}

/*
 * Scans every tracked allocation and reports corrupted pre/post boundary
 * markers through the registered garbaging callback. No check is made while
 * no callback is registered (garbaging == NULL).
 * file/line identify the call site that triggered the scan and are passed to
 * the callback as the "detected from" location.
 * Takes no lock; threaded callers use
 * checkAllTrackedAllocationsForGarbagingThreadSafe.
 */
static void checkAllTrackedAllocationsForGarbaging(const char *file, int line) {
  int i;

  if (!isInitialized())
    initialize();

  for (i = 0; i < MAX_SIMULTANEOUS_ALLOCATIONS; i++) {
    if (!Info[i].p)
      continue; /* free slot */

    /* Report through the *registered* handler: the original called
       defaultGarbagingCallback here, silently ignoring any user callback. */
    if (garbaging && preFieldGarbaging(Info[i].p))
      garbaging(Info[i].bytes, Info[i].file, Info[i].line, BUFFER_GARBAGING_AT_PRE_FIELD, file, line);

    if (garbaging && postFieldGarbaging(Info[i].p, Info[i].bytes))
      garbaging(Info[i].bytes, Info[i].file, Info[i].line, BUFFER_GARBAGING_AT_POST_FIELD, file, line);
  }
}

/* Locked variant of checkAllTrackedAllocationsForGarbaging. It holds
   mallocMutex (the same mutex as malloc/free), so the allocation table cannot
   be modified while it is being scanned. */
static void checkAllTrackedAllocationsForGarbagingThreadSafe(const char *file, int line)
{
  pthread_mutex_t *tableLock = &mallocMutex;

  pthread_mutex_lock(tableLock);
  checkAllTrackedAllocationsForGarbaging(file, line);
  pthread_mutex_unlock(tableLock);
}

/* Returns the index of the first table entry whose pointer equals p, or -1 if
   there is none. Passing NULL therefore yields the first free slot. */
static int locateSlot(void *p)
{
  int idx = 0;

  while (idx < MAX_SIMULTANEOUS_ALLOCATIONS && Info[idx].p != p)
    idx++;

  return (idx < MAX_SIMULTANEOUS_ALLOCATIONS) ? idx : -1;
}

/*
 * Call before your application closes to release the memory still held by
 * the tracker.
 *
 * First it runs a status pass with no overview/details output. During that
 * pass the default garbaging handler is temporarily installed, so corruption
 * of any still-tracked block is always reported. Then it frees every tracked
 * block and clears the initialized flag, so the next tracker call starts
 * afresh. Untracked blocks (allocated while the slot table was full) cannot
 * be freed here.
 *
 * The mutexes are statically initialized (PTHREAD_MUTEX_INITIALIZER) and are
 * deliberately NOT destroyed. The tracker re-initializes lazily on its next
 * use, and locking a destroyed mutex (e.g. from a late memoryTrackerTFree) is
 * undefined behavior.
 */
void shutDownMemoryTracker(void) {
  int i;
  memoryGarbagingWarningCallback prevGarbagingCallback;

  if (!isInitialized())
    initialize();

  /* check for garbaging */
  prevGarbagingCallback = garbaging;
  garbaging = defaultGarbagingCallback; /* temporarily set default garbaging report handler */
  reportMemoryStatus(NULL, NULL, NULL);
  garbaging = prevGarbagingCallback; /* restore previous handler */

  /* return all tracked memory; the raw block starts one marker before the user pointer */
  for (i = 0; i < MAX_SIMULTANEOUS_ALLOCATIONS; i++) {
    if (!Info[i].p)
      continue;
    free((BYTE*)Info[i].p - sizeof(abba));
    clearAllocInfoSlot(i);
  }
  initialized = 0;
}

/*******************************************************************************
 * Report functions
 ******************************************************************************/
/* Installs f as the handler for boundary-marker corruption reports. Passing
   NULL disables the checks, except during shutDownMemoryTracker, which
   temporarily reinstates the default handler. */
void registerGarbagingWarningCallback(memoryGarbagingWarningCallback f)
{
  if (isInitialized() == 0) {
    initialize();
  }
  garbaging = f;
}

/*
 * Default garbaging handler. It formats a message describing the corrupted
 * field ("pre" or "post") and passes it to ASSERT_ALWAYS (see assert_utils.h
 * for what that does).
 *   bytes, file, line     - size and allocation site of the corrupted buffer
 *   field                 - BUFFER_GARBAGING_AT_PRE_FIELD or ..._POST_FIELD
 *   fromFile, fromLine    - call site where the corruption was detected
 * Location parts that are NULL or have a non-positive line are left out.
 *
 * snprintf bounds the output: file names have arbitrary length, and the
 * previous unbounded sprintf could overflow text[]. An over-long message is
 * truncated instead.
 */
static void defaultGarbagingCallback(
  int bytes,
  const char *file,
  int line,
  int field,
  const char *fromFile,
  int fromLine) {
  char text[512];
  const char *where = (field == BUFFER_GARBAGING_AT_PRE_FIELD) ? "pre" : "post";

  if (fromFile && fromLine > 0 && file && line > 0)
    snprintf(text, sizeof(text), "Memory garbaging at %s field of %d byte buffer allocated at line %d in file %s (detected from line %d in file %s)!", where, bytes, line, file, fromLine, fromFile);
  else if (file && line > 0)
    snprintf(text, sizeof(text), "Memory garbaging at %s field of %d byte buffer allocated at line %d in file %s!", where, bytes, line, file);
  else if (fromFile && fromLine > 0)
    snprintf(text, sizeof(text), "Memory garbaging detected at %s field (detected from line %d in file %s)!", where, fromLine, fromFile);
  else
    snprintf(text, sizeof(text), "Memory garbaging detected!");
  ASSERT_ALWAYS(text);
}

/*
 * Produces a memory status report through the supplied callbacks. Any of them
 * may be NULL to skip that part:
 *   overview - called once with the running statistics;
 *   details  - called once per still-tracked (unfreed) allocation;
 *   end      - called once at the very end.
 * Each visited allocation also has its boundary markers checked. Corruption is
 * reported through the registered garbaging callback (if any), with no
 * "detected from" location (NULL, -1).
 * mallocMutex (the same mutex as malloc/free) is held for the whole report, so
 * the allocation table cannot change underneath it. Callbacks therefore must
 * not call functions that lock mallocMutex again; it is a default,
 * non-recursive mutex.
 */
void reportMemoryStatus(memoryInfoOverviewCallback overview, memoryInfoDetailsCallback details, memoryInfoEnd end) {
  int slot;
  int visited = 0;

  pthread_mutex_lock(&mallocMutex);
  if (!isInitialized())
    initialize();

  if (overview)
    overview(totalNumAllocations, maxNumSimultaneousAllocations, numPendingDeallocations, slotSpaceExhausted);

  if (numPendingDeallocations > 0) {
    for (slot = 0; slot < MAX_SIMULTANEOUS_ALLOCATIONS; slot++) {
      AllocInfo *entry = &Info[slot];

      if (entry->p == NULL)
        continue; /* free slot */

      if (details)
        details(entry->bytes, entry->file, entry->line);

      if (garbaging && preFieldGarbaging(entry->p))
        garbaging(entry->bytes, entry->file, entry->line, BUFFER_GARBAGING_AT_PRE_FIELD, NULL, -1);

      if (garbaging && postFieldGarbaging(entry->p, entry->bytes))
        garbaging(entry->bytes, entry->file, entry->line, BUFFER_GARBAGING_AT_POST_FIELD, NULL, -1);

      /* every pending allocation has been seen: skip the rest of the table */
      if (++visited == numPendingDeallocations)
        break;
    }
  }

  if (end)
    end();
  pthread_mutex_unlock(&mallocMutex);
}

/* Default overview callback: prints the running allocation statistics on
   stdout, framed by banner lines, followed by whether every allocation was
   tracked. Parameter names differ from the file-scope statistics to avoid
   shadowing them. */
static void defaultOverviewCallback(
  int total,
  int peak,
  int pending,
  int exhausted) {
  const char *trackingNote = exhausted
    ? "Some allocations may not have been tracked (out of slots)!\n"
    : "All allocations have been tracked\n";

  printf("****************** MEMORY TRACKER STATUS REPORT ******************\n");
  printf("Total number of allocations                = %d\n", total);
  printf("Maximum number of simultaneous allocations = %d\n", peak);
  printf("Number of pending deallocations            = %d\n", pending);
  fputs(trackingNote, stdout);
  printf("******************************************************************\n");
}

/* Default details callback: prints one line on stdout for each block that
   has not been freed. file may be NULL when the allocation was made without a
   source location. Passing NULL to %s is undefined behavior, so such blocks
   are printed as "(unknown)". */
static void defaultMemoryInfoDetailsCallback(
  int bytes,
  const char *file,
  int line) {
  printf("%7d bytes allocated at line %3d in file %s\n", bytes, line, file ? file : "(unknown)");
}

/* Default end-of-report callback: prints the closing banner on stdout. */
static void defaultMemoryInfoEnd(void)
{
  printf("*************** END OF MEMORY TRACKER STATUS REPORT **************\n");
}

void reportMemoryStatusDefault(void) {
  reportMemoryStatus(defaultOverviewCallback, defaultMemoryInfoDetailsCallback, defaultMemoryInfoEnd);
}

/*******************************************************************************
 * Memory operations
 ******************************************************************************/
void *memoryTrackerMalloc(size_t size, const char *file, int line) {
  void *p;
  int i;

  if (!isInitialized())
    initialize();

  /* stats */
  totalNumAllocations++;
  numPendingDeallocations++;
  if (numPendingDeallocations > maxNumSimultaneousAllocations)
    maxNumSimultaneousAllocations = numPendingDeallocations;

  /* find first empty slot */
  i = locateSlot(NULL);
  if (i == -1) { /* no empty slots */
    slotSpaceExhausted = 1;
    p = malloc(size); /* use regular malloc */
    ASSERT(p, "Allocation failed (using regular malloc)!");
    return p;
  }

  /* prepare buffer */
  p = malloc(size + sizeof(abba) * 2); /* add pre and post garbaging fields */
  setPreGarbagingField(p);
  setPostGarbagingField(p, size);

  /* store info */
  Info[i].bytes = size;
  Info[i].p = (BYTE*)p + sizeof(abba);
  Info[i].file = file;
  Info[i].line = line;
  ASSERT(Info[i].p, "Allocation failed!");

  checkAllTrackedAllocationsForGarbaging(file, line);
  return Info[i].p;
}

/*
 * Tracked replacement for free(). block must be NULL or a pointer returned
 * by memoryTrackerMalloc. file/line identify the freeing call site and are
 * passed to the garbaging callback as the "detected from" location.
 *
 * For a tracked block, both boundary markers are checked before the raw block
 * is released and its slot cleared; afterwards all remaining tracked blocks
 * are scanned. A pointer without a slot is assumed to be an untracked
 * allocation (made while the table was full) and is passed to plain free().
 *
 * Like free(), a NULL block is a no-op and leaves the statistics unchanged.
 */
void memoryTrackerFree(void *block, const char *file, int line) {
  int i;

  if (!isInitialized())
    initialize();

  /* Must be handled before locateSlot(): locateSlot(NULL) matches the first
     *empty* slot, and the code below would then read and free memory located
     just before address NULL. */
  if (!block)
    return;

  /* stats */
  numPendingDeallocations--;

  /* locate */
  i = locateSlot(block);
  if (i == -1) { /* allocation not registered because info slots were exhausted (which is ok) */
    ASSERT(slotSpaceExhausted, "Could not find corresponding allocation!");
    free(block); /* regular free */
    return;
  }

  /* check pre garbaging field */
  if (garbaging && preFieldGarbaging(block))
    garbaging(Info[i].bytes, Info[i].file, Info[i].line, BUFFER_GARBAGING_AT_PRE_FIELD, file, line);

  /* check post garbaging field */
  if (garbaging && postFieldGarbaging(block, Info[i].bytes))
    garbaging(Info[i].bytes, Info[i].file, Info[i].line, BUFFER_GARBAGING_AT_POST_FIELD, file, line);

  /* free the raw block, which starts one marker before the user pointer */
  free((BYTE*)block - sizeof(abba));

  /* clear slot */
  clearAllocInfoSlot(i);

  checkAllTrackedAllocationsForGarbaging(file, line);
}

/* Tracked memcpy: copies n bytes from src to dst, then scans all tracked
   allocations for marker corruption, reporting file/line as the detecting
   call site. Returns memcpy's result. */
void *memoryTrackerMemcpy(void *dst, const void *src, size_t n, const char *file, int line)
{
  void *result;

  if (isInitialized() == 0) {
    initialize();
  }
  result = memcpy(dst, src, n);
  checkAllTrackedAllocationsForGarbaging(file, line);

  return result;
}

/* Tracked memset: fills n bytes at s with c, then scans all tracked
   allocations for marker corruption, reporting file/line as the detecting
   call site. Returns memset's result. */
void *memoryTrackerMemset(void *s, int c, size_t n, const char *file, int line)
{
  void *result;

  if (isInitialized() == 0) {
    initialize();
  }
  result = memset(s, c, n);
  checkAllTrackedAllocationsForGarbaging(file, line);

  return result;
}

/* Tracked memxor (from memxor.h): applies memxor to n bytes of dst and src,
   then scans all tracked allocations for marker corruption, reporting
   file/line as the detecting call site. Returns memxor's result. */
void *memoryTrackerMemxor(void *dst, void *src, size_t n, const char *file, int line)
{
  void *result;

  if (isInitialized() == 0) {
    initialize();
  }
  result = memxor(dst, src, n);
  checkAllTrackedAllocationsForGarbaging(file, line);

  return result;
}

/* Tracked memrnd (from memrnd.h): applies memrnd to n bytes at dst, then
   scans all tracked allocations for marker corruption, reporting file/line as
   the detecting call site. Returns memrnd's result. */
void *memoryTrackerMemrnd(void *dst, size_t n, const char *file, int line)
{
  void *result;

  if (isInitialized() == 0) {
    initialize();
  }
  result = memrnd(dst, n);
  checkAllTrackedAllocationsForGarbaging(file, line);

  return result;
}

/* thread-safe variants */
/* Thread-safe memoryTrackerMalloc: the whole call runs under mallocMutex. */
void *memoryTrackerTMalloc(size_t size, const char *file, int line)
{
  void *block;

  pthread_mutex_lock(&mallocMutex);
  block = memoryTrackerMalloc(size, file, line);
  pthread_mutex_unlock(&mallocMutex);

  return block;
}

/* Thread-safe memoryTrackerFree: the whole call, including its corruption
   scan, runs under mallocMutex. */
void memoryTrackerTFree(void *block, const char *file, int line)
{
  pthread_mutex_t *tableLock = &mallocMutex;

  pthread_mutex_lock(tableLock);
  memoryTrackerFree(block, file, line);
  pthread_mutex_unlock(tableLock);
}

/*
 * Thread-safe memoryTrackerMemcpy. The copy is serialized on memoryWriteMutex;
 * the corruption scan that follows runs under mallocMutex.
 *
 * Lazy initialization is left to that scan, which performs it under
 * mallocMutex, the lock guarding the allocation table. Initializing here under
 * memoryWriteMutex could race with memoryTrackerTMalloc initializing under
 * mallocMutex, and could wipe a slot that call had just filled.
 */
void *memoryTrackerTMemcpy(void *dst, const void *src, size_t n, const char *file, int line) {
  void *p;
  pthread_mutex_lock(&memoryWriteMutex);
  p = memcpy(dst, src, n);
  pthread_mutex_unlock(&memoryWriteMutex);
  checkAllTrackedAllocationsForGarbagingThreadSafe(file, line);
  return p;
}

/*
 * Thread-safe memoryTrackerMemset. The fill is serialized on memoryWriteMutex;
 * the corruption scan that follows runs under mallocMutex.
 *
 * Lazy initialization is left to that scan, which performs it under
 * mallocMutex, the lock guarding the allocation table. Initializing here under
 * memoryWriteMutex could race with memoryTrackerTMalloc initializing under
 * mallocMutex, and could wipe a slot that call had just filled.
 */
void *memoryTrackerTMemset(void *s, int c, size_t n, const char *file, int line) {
  void *p;
  pthread_mutex_lock(&memoryWriteMutex);
  p = memset(s, c, n);
  pthread_mutex_unlock(&memoryWriteMutex);
  checkAllTrackedAllocationsForGarbagingThreadSafe(file, line);
  return p;
}

/*
 * Thread-safe memoryTrackerMemxor. The memxor call is serialized on
 * memoryWriteMutex; the corruption scan that follows runs under mallocMutex.
 *
 * Lazy initialization is left to that scan, which performs it under
 * mallocMutex, the lock guarding the allocation table. Initializing here under
 * memoryWriteMutex could race with memoryTrackerTMalloc initializing under
 * mallocMutex, and could wipe a slot that call had just filled.
 */
void *memoryTrackerTMemxor(void *dst, void *src, size_t n, const char *file, int line) {
  void *p;
  pthread_mutex_lock(&memoryWriteMutex);
  p = memxor(dst, src, n);
  pthread_mutex_unlock(&memoryWriteMutex);
  checkAllTrackedAllocationsForGarbagingThreadSafe(file, line);
  return p;
}

/*
 * Thread-safe memoryTrackerMemrnd. The memrnd call is serialized on
 * memoryWriteMutex; the corruption scan that follows runs under mallocMutex.
 *
 * Lazy initialization is left to that scan, which performs it under
 * mallocMutex, the lock guarding the allocation table. Initializing here under
 * memoryWriteMutex could race with memoryTrackerTMalloc initializing under
 * mallocMutex, and could wipe a slot that call had just filled.
 */
void *memoryTrackerTMemrnd(void *dst, size_t n, const char *file, int line) {
  void *p;
  pthread_mutex_lock(&memoryWriteMutex);
  p = memrnd(dst, n);
  pthread_mutex_unlock(&memoryWriteMutex);
  checkAllTrackedAllocationsForGarbagingThreadSafe(file, line);
  return p;
}

