/*
 * Copyright (c) Paul Stankovski
 * Free for all non-commercial use unless this directive conflicts with
 * other applicable copyright statement(s), patent holders, laws or such.
 */
#include "black_box_incremental_maxterm.h"
#include "black_box_bit_set_utils.h"
#include "black_box_ciphers.h"
#include "bittwiddling.h"
#include "memory_utils.h"
#include "log_utils.h"
#include "press_any_key.h"
#include "assert_utils.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Character-class tests used when parsing hexadecimal strings (DIGIT is also
 * used by ODD_CHARACTER further down). The argument is evaluated more than
 * once, so only pass side-effect-free expressions. */
#define UPPER_HEX(c) (((c) >= 'A') && ((c) <= 'F'))
#define LOWER_HEX(c) (((c) >= 'a') && ((c) <= 'f'))
#define DIGIT(c) (((c) >= '0') && ((c) <= '9'))

/* Returns the value 0..15 of the hex digit c, or -1 if c is not a hex digit. */
static int hexDigitValue(char c) {
  if (DIGIT(c))     return c - '0';
  if (LOWER_HEX(c)) return c - 'a' + 10;
  if (UPPER_HEX(c)) return c - 'A' + 10;
  return -1;
}

/*
 * Decodes the hexadecimal string str into buf, two characters per byte with
 * the high nibble first. Upper- and lower-case hex digits are both accepted.
 *
 * buf must have room for strlen(str) / 2 bytes; the caller is responsible for
 * matching the string length to the buffer size.
 *
 * Returns 0 on success, 1 if str has odd length or contains a character that
 * is not a hex digit. str comes from the command line, so bad characters are
 * reported as an error instead of tripping an ASSERT. The whole string is
 * validated before anything is written, so buf is untouched on failure.
 */
static int stringToHexBuffer(BYTE* buf, const char *str) {
  size_t i, len = strlen(str);

  if (len & 1) return 1;
  for (i=0; i<len; i++)
    if (hexDigitValue(str[i]) < 0) return 1;
  for (i=0; i<len; i+=2)
    *buf++ = (BYTE)((hexDigitValue(str[i]) << 4) | hexDigitValue(str[i + 1]));
  return 0;
}

/*
 * Writes the upper-case hexadecimal form of the len bytes at buf into str,
 * two characters per byte (high nibble first), and NUL-terminates it.
 * str must have room for 2 * len + 1 characters.
 *
 * Returns 0 on success. Returns 1 if a nibble falls outside 0..15; in that
 * case str is left unterminated.
 */
static int hexBufferToString(char *str, const BYTE* buf, int len) {
  static const char digits[] = "0123456789ABCDEF";
  char *out = str;
  int k;

  for (k = 0; k < len; ++k) {
    int high = buf[k] >> 4;
    int low = buf[k] & 0x0F;

    if (high < 0 || high > 15 || low < 0 || low > 15)
      return 1;
    out[0] = digits[high];
    out[1] = digits[low];
    out += 2;
  }
  *out = 0;
  return 0;
}

/* Character classes used by lc() to sanitize file names: ASCII letters,
 * digits and '.' are kept, everything else is an "odd" character. Arguments
 * are evaluated multiple times, so pass only side-effect-free expressions. */
#define UPPER_LETTER(c) (((c) >= 'A') && ((c) <= 'Z'))
#define LOWER_LETTER(c) (((c) >= 'a') && ((c) <= 'z'))
#define DOT(c) ((c) == '.')
#define ODD_CHARACTER(c) (!UPPER_LETTER(c) && !LOWER_LETTER(c) && !DIGIT(c) && !DOT(c))
/*
 * Sanitizes the NUL-terminated string str in place for use as a file name:
 * upper-case ASCII letters become lower-case, and every character that is not
 * a letter, digit or '.' is replaced by '_'.
 */
static void lc(char *str) {
  char *p;

  for (p = str; *p != '\0'; ++p) {
    if (UPPER_LETTER(*p)) {
      *p = (char)(*p + ('a' - 'A'));
    } else if (ODD_CHARACTER(*p)) {
      *p = '_';
    }
  }
}

/*
 * Builds the output file name
 *   incremental_maxterm__<cipher>__key_<hex>__iv_<hex>__par_bits_<n>.<ext>
 * into fileName and sanitizes it with lc().
 *
 * fileName must hold at least FILE_NAME_BUF_SIZE (512) bytes, which is what
 * openLogFile() passes in.
 *
 * Returns fileName on success. Returns NULL if the key/iv cannot be converted
 * to hex or if the complete name would not fit in the buffer. The previous
 * unbounded sprintf() could overflow the caller's buffer.
 */
static char *getFileName(char *fileName, bbCipher cipher, const BYTE *key, const BYTE *iv, int numParBits, const char *ext) {
  enum { FILE_NAME_BUF_SIZE = 512, MAX_KEY_IV_SIZE_IN_BYTES = 512 };
  const char *name = blackBoxCipherName(cipher);
  /* two hex characters per byte plus the terminating NUL */
  char keyStr[2 * MAX_KEY_IV_SIZE_IN_BYTES + 1], ivStr[2 * MAX_KEY_IV_SIZE_IN_BYTES + 1];
  int keySizeInBytes, ivSizeInBytes, n;

  blackBoxInfo(cipher, &keySizeInBytes, &ivSizeInBytes, NULL, NULL);
  ASSERT(keySizeInBytes <= MAX_KEY_IV_SIZE_IN_BYTES && ivSizeInBytes <= MAX_KEY_IV_SIZE_IN_BYTES, "Unexpected key or iv size!\n");

  if (hexBufferToString(keyStr, key, keySizeInBytes)) return NULL;
  if (hexBufferToString(ivStr, iv, ivSizeInBytes)) return NULL;
  n = snprintf(fileName, FILE_NAME_BUF_SIZE, "incremental_maxterm__%s__key_%s__iv_%s__par_bits_%d.%s", name, keyStr, ivStr, numParBits, ext);
  if (n < 0 || n >= FILE_NAME_BUF_SIZE) return NULL; /* encoding error or name truncated */
  lc(fileName);
  return fileName;
}

/*
 * Fills buf, which is the cipher's key buffer (isIvBuf == 0) or iv buffer
 * (isIvBuf != 0), according to the command-line argument arg:
 *   "0"  -> all bytes 0x00                              returns 0
 *   "1"  -> all bytes 0xFF                              returns 1
 *   "2"  -> random bytes (MEMRND)                       returns 2
 *   hex string of exactly 2 * buffer-size characters    returns 3
 * Returns -1 for any other argument.
 */
static int readBitBufferFromArg(bbCipher cipher, BYTE *buf, int isIvBuf, const char *arg) {
  int len;
  int bufSizeInBytes;

  blackBoxInfo(cipher, isIvBuf ? NULL : (&bufSizeInBytes), isIvBuf ? (&bufSizeInBytes) : NULL, NULL, NULL);

  len = (int)strlen(arg);
  if (len == 1) {
    /* Match the character itself: atoi() would turn any non-digit
     * (e.g. "x") into 0 and silently select the ZEROS fill. */
    switch (arg[0]) {
    case '0': MEMSET(buf, 0, bufSizeInBytes); return 0;
    case '1': MEMSET(buf, 0xFF, bufSizeInBytes); return 1;
    case '2': MEMRND(buf, bufSizeInBytes); return 2;
    default: break;
    }
  } else if (len == bufSizeInBytes * 2) {
    if (!stringToHexBuffer(buf, arg))
      return 3;
  }
  return -1;
}

/*
 * Echoes argc and every argv entry through the logger. Each line is logged
 * with LOGNOFLUSH, and a final empty LOGALL call ends the sequence
 * (presumably flushing it — confirm against log_utils).
 */
static void printAllArguments(int argc, char **argv) {
  int idx = 0;

  logger(NULL, LOGNOFLUSH, "argc = %d\n", argc);
  while (idx < argc) {
    logger(NULL, LOGNOFLUSH, "argv[%d] = %s\n", idx, argv[idx]);
    ++idx;
  }
  logger(NULL, LOGALL, "");
}

/* Parses the numeric cipher id in arg into *cipher and logs the id and its name. */
static void getCipher(bbCipher *cipher, const char *arg) {
  const int id = atoi(arg);

  *cipher = (bbCipher)id;
  logger(NULL, LOGALL, "\ncipher = %d = %s\n\n", *cipher, blackBoxCipherName(*cipher));
}

/*
 * Fills key from the command-line argument arg. The argument is "0" (zeros),
 * "1" (ones), "2" (random) or an explicit hex string of the key length.
 * Logs the fill kind followed by the key bytes.
 *
 * Returns 0 on success, or 1 after logging "Invalid key fill!" when arg is
 * not an accepted fill.
 */
static int getKey(bbCipher cipher, BYTE *key, const char *arg) {
  int keySizeInBytes;
  int keyFill = readBitBufferFromArg(cipher, key, 0, arg);

  blackBoxInfo(cipher, &keySizeInBytes, NULL, NULL, NULL);
  if (keyFill == -1) {
    logger(NULL, LOGALL, "\nInvalid key fill!\n");
    return 1;
  }
  /* message previously said "iv" (copy/paste from getIV) */
  ASSERT(keyFill >=0 && keyFill <= 3, "Unexpected key fill!\n");
  logger(NULL, LOGNOFLUSH, "Key fill = ");
  switch (keyFill) {
  case 0: logger(NULL, LOGNOFLUSH, "ZEROS    ="); break;
  case 1: logger(NULL, LOGNOFLUSH, "ONES     ="); break;
  case 2: logger(NULL, LOGNOFLUSH, "RANDOM   ="); break;
  case 3: logger(NULL, LOGNOFLUSH, "EXPLICIT ="); break;
  default: break; /* unreachable: guarded by the ASSERT above */
  }
  logBuf(NULL, LOGNOFLUSH, key, keySizeInBytes, 0, 0);
  logger(NULL, LOGALL, "\n");
  return 0;
}

/*
 * Fills iv from the command-line argument arg and logs the fill kind
 * (ZEROS, ONES, RANDOM or EXPLICIT) followed by the iv bytes.
 *
 * Returns 0 on success, or 1 after logging "Invalid iv fill!" when arg is
 * not an accepted fill.
 */
static int getIV(bbCipher cipher, BYTE *iv, const char *arg) {
  /* indexed by the fill code returned from readBitBufferFromArg() */
  static const char *const fillLabel[] = {
    "ZEROS    =", "ONES     =", "RANDOM   =", "EXPLICIT ="
  };
  int ivBytes;
  int fill = readBitBufferFromArg(cipher, iv, 1, arg);

  blackBoxInfo(cipher, NULL, &ivBytes, NULL, NULL);
  if (fill == -1) {
    logger(NULL, LOGALL, "\nInvalid iv fill!\n");
    return 1;
  }
  ASSERT(fill >=0 && fill <= 3, "Unexpected iv fill!\n");

  logger(NULL, LOGNOFLUSH, " IV fill = ");
  logger(NULL, LOGNOFLUSH, "%s", fillLabel[fill]);
  logBuf(NULL, LOGNOFLUSH, iv, ivBytes, 0, 0);
  logger(NULL, LOGALL, "\n\n");
  return 0;
}

/* Parses the number of parallel bits from arg into *numParBits and logs it. */
static void getNumParallellBits(int *numParBits, const char *arg) {
  int value = atoi(arg);

  *numParBits = value;
  logger(NULL, LOGALL, "Parallell bits = %d\n\n", value);
}

/*
 * Opens (truncating) the log file whose name getFileName() derives from the
 * cipher, key, iv and number of parallel bits, and logs its name.
 *
 * Returns the open stream, or NULL (after logging why) if the name could not
 * be built or the file could not be opened. Previously a failing fopen() was
 * silent.
 */
static FILE *openLogFile(bbCipher cipher, BYTE *key, BYTE *iv, int parallellBits) {
  FILE *f;
  char logFileName[512];

  if (!getFileName(logFileName, cipher, key, iv, parallellBits, "txt")) {
    logger(NULL, LOGALL, "Log file name error!\n");
    return NULL;
  }
  f = fopen(logFileName, "w");
  if (!f) {
    logger(NULL, LOGALL, "Could not open log file %s!\n", logFileName);
    return NULL;
  }
  logger(NULL, LOGALL, "Log file: %s\n\n", logFileName);
  return f;
}

/*
 * Parses a bit set description from argv and logs it. The layout is:
 *   argv[0]                 number of key bits nk
 *   argv[1 .. nk]           key bit indices
 *   argv[nk + 1]            number of iv bits ni
 *   argv[nk + 2 .. nk+1+ni] iv bit indices
 * No bounds checking is done: keyBit/ivBit must have room for nk/ni entries
 * and argv must contain all of the entries above.
 */
static void getBitSet(int *numKeyBits, int *keyBit, int *numIvBits, int *ivBit, char **argv) {
  int k;
  int nKey = atoi(argv[0]);
  char **keyArgs = argv + 1;            /* the nKey key bit indices */
  int nIv = atoi(keyArgs[nKey]);        /* iv bit count follows the key bits */
  char **ivArgs = keyArgs + nKey + 1;   /* the nIv iv bit indices */

  for (k = 0; k < nKey; k++) keyBit[k] = atoi(keyArgs[k]);
  for (k = 0; k < nIv; k++)  ivBit[k]  = atoi(ivArgs[k]);
  *numKeyBits = nKey;
  *numIvBits = nIv;

  logger(NULL, LOGNOFLUSH, "Num bits = %d + %d = %d\n", nKey, nIv, nKey + nIv);
  logger(NULL, LOGNOFLUSH, "Key bits = "); logBitSet(NULL, LOGNOFLUSH, keyBit, nKey);
  logger(NULL, LOGNOFLUSH, "\n IV bits = "); logBitSet(NULL, LOGNOFLUSH, ivBit, nIv);
  logger(NULL, LOGALL, "\n\n");
}

static void logHeader(FILE *logFile, bbCipher cipher, BYTE *key, BYTE *iv, int numKeyBits, int *keyBit, int numIvBits, int *ivBit, int parallellBits) {
  int i, keySizeInBytes, ivSizeInBytes, suppressedBytes, implicitBlockSizeInBytes;
  int len = strlen(blackBoxCipherName(cipher));

  blackBoxInfo(cipher, &keySizeInBytes, &ivSizeInBytes, &suppressedBytes, &implicitBlockSizeInBytes);
  for (i=0; i<len+55; i++) logger(logFile, LOGNOFLUSH, "*");
  logger(logFile, LOGNOFLUSH, "\n*\n* Incremental %s Maxterm\n*\n* Key =", blackBoxCipherName(cipher));
  logBuf(logFile, LOGNOFLUSH, key, keySizeInBytes, 0, 0);
  logger(logFile, LOGNOFLUSH, " (%d bytes)\n*  IV =", keySizeInBytes);
  logBuf(logFile, LOGNOFLUSH, iv, ivSizeInBytes, 0, 0);
  logger(logFile, LOGNOFLUSH, " (%d bytes)\n*\n* Num bits = %d + %d = %d\n* Key bits = ", ivSizeInBytes, numKeyBits, numIvBits, numKeyBits + numIvBits);
  logBitSet(logFile, LOGNOFLUSH, keyBit, numKeyBits);
  logger(logFile, LOGNOFLUSH, "\n*  IV bits = ");
  logBitSet(logFile, LOGNOFLUSH, ivBit, numIvBits);
  logger(logFile, LOGNOFLUSH, "\n*\n* %d parallell bit%s\n*\n", parallellBits, parallellBits == 1 ? "" : "s");
  for (i=0; i<len+55; i++) logger(logFile, LOGNOFLUSH, "*");
  logger(logFile, LOGALL, "\n\n");
}

/*
 * Logs one line for a partial result held in buf:
 *   [<date>, <time since start>], Bit set size NN, RRRR zero rounds, <buf bytes>
 * bitSetSize is the 0-based index of the last bit added (-1 for the empty
 * set), so the logged size is bitSetSize + 1. The zero-round count comes from
 * numInitialZeroBitsByBlocks() for block ciphers and numInitialZeroBits()
 * otherwise.
 */
static void logBitSetResult(FILE *logFile, bbCipher cipher, int bitSetSize, BYTE *buf, time_t start) {
  int keyBytes, ivBytes, suppressed, blockBytes, zeroRounds;

  blackBoxInfo(cipher, &keyBytes, &ivBytes, &suppressed, &blockBytes);
  if (blackBoxCipherType(cipher) == kBlockCipher)
    zeroRounds = numInitialZeroBitsByBlocks(buf, suppressed, blockBytes);
  else
    zeroRounds = numInitialZeroBits(buf, suppressed);

  logger(logFile, LOGNOFLUSH, "[");
  logDateRaw(logFile, LOGNOFLUSH);
  logger(logFile, LOGNOFLUSH, ", ");
  logTimeRaw(logFile, LOGNOFLUSH, start);
  logger(logFile, LOGNOFLUSH, "], Bit set size %2d, %4d zero rounds,", bitSetSize + 1, zeroRounds);
  logBuf(logFile, LOGNOFLUSH, buf, suppressed, 0, 0);
  logger(logFile, LOGALL, "\n");
}

/*
 * Computes the xor over a growing bit set step by step and logs each partial
 * result.
 *
 * The bit set is keyBit[0..numKeyBits-1] followed by ivBit[0..numIvBits-1].
 * xorOverBitSet() is first run over the empty set. Then, for each bit in
 * turn, that bit is toggled in key or iv and xorOverBitSet() is run over all
 * previously handled bits. buf is zeroed once and never cleared between
 * calls. Presumably xorOverBitSet() xors its result into buf (TODO confirm),
 * so after step k buf would hold the xor over the first k+1 bits.
 *
 * The contents of key and iv are modified during the computation and are
 * restored before returning, on the error path too (previously a failure
 * inside the loop left the toggled bit set).
 *
 * Returns 0 on success, 1 if xorOverBitSet() reports an error.
 */
static int incrementalMaxtermInternal(FILE *logFile, bbCipher cipher, BYTE *key, BYTE *iv, int numKeyBits, int *keyBit, int numIvBits, int *ivBit, int parallellBits) {
  int suppressedBytes, implicitBlockSizeInBytes;
  int inLen, currBit, result = 1;
  BYTE *buf, *in;
  time_t start = time(NULL);

  /* allocate auxiliary buffers */
  blackBoxInfo(cipher, NULL, NULL, &suppressedBytes, &implicitBlockSizeInBytes);
  inLen = implicitBlockSizeInBytes;
  buf  = MALLOC(suppressedBytes); MEMSET(buf, 0, suppressedBytes);
  in   = MALLOC(inLen);           MEMSET(in,  0, inLen);

  logHeader(logFile, cipher, key, iv, numKeyBits, keyBit, numIvBits, ivBit, parallellBits);

  /* initial xor (over empty bit set) */
  if (xorOverBitSet(cipher, key, 0, keyBit, iv, 0, ivBit, in, inLen, buf, parallellBits)) {
    logger(logFile, LOGALL, "xorOverBitSet error 2!\n");
    goto cleanup;
  }
  logBitSetResult(logFile, cipher, -1, buf, start);

  /* now add one bit at a time: key bits first, then iv bits */
  for (currBit = 0; currBit < numKeyBits + numIvBits; currBit++) {
    int isKeyBit = currBit < numKeyBits;
    BYTE *b = isKeyBit ? key : iv;
    int bit = isKeyBit ? keyBit[currBit] : ivBit[currBit - numKeyBits];
    /* the previously handled bits: a prefix of keyBit, then a prefix of ivBit */
    int currNumKeyBits = isKeyBit ? currBit : numKeyBits;
    int currNumIvBits = isKeyBit ? 0 : currBit - numKeyBits;
    int err;

    flipBufBit(b, bit); /* toggle bit in key or iv */
    err = xorOverBitSet(cipher, key, currNumKeyBits, keyBit, iv, currNumIvBits, ivBit, in, inLen, buf, parallellBits);
    flipBufBit(b, bit); /* restore the toggle before anything else, also on error */
    if (err) {
      logger(logFile, LOGALL, "xorOverBitSet error in loop!\n");
      goto cleanup;
    }
    logBitSetResult(logFile, cipher, currBit, buf, start); /* log partial result */
  }
  result = 0;

cleanup:
  FREE(buf);
  FREE(in);
  return result;
}

#define MAX_NUM_BITS 80
int blackBoxIncrementalMaxterm(int argc, char **argv) {
  FILE *logFile;
  bbCipher cipher;
  int numKeyBits, numIvBits;
  int keyBit[MAX_NUM_BITS], ivBit[MAX_NUM_BITS];
  BYTE *key, *iv;
  int keySizeInBytes, ivSizeInBytes;
  int parallellBits;

  /* process arguments */
  printAllArguments(argc, argv); /* print all arguments */
  getCipher(&cipher, argv[1]); /* get cipher */
  blackBoxInfo(cipher, &keySizeInBytes, &ivSizeInBytes, NULL, NULL);
  key = MALLOC(keySizeInBytes * sizeof(BYTE));
  iv = MALLOC(ivSizeInBytes * sizeof(BYTE));
  if (getKey(cipher, key, argv[2])) { FREE(key); FREE(iv); return 1; } /* get key fill */
  if (getIV(cipher, iv, argv[3])) { FREE(key); FREE(iv); return 1; } /* get iv fill */
  getNumParallellBits(&parallellBits, argv[4]); /* get num parallell bits */
  getBitSet(&numKeyBits, keyBit, &numIvBits, ivBit, &argv[5]); /* get bit set */
  logFile = openLogFile(cipher, key, iv, parallellBits); /* open log file */

  if (incrementalMaxtermInternal(logFile, cipher, key, iv, numKeyBits, keyBit, numIvBits, ivBit, parallellBits)) { FREE(key); FREE(iv); return 1; }

  fclose(logFile);
  FREE(key);
  FREE(iv);
  return 0;
}
