/*
 * Copyright (c) Paul Stankovski
 * Free for all non-commercial use unless this directive conflicts with
 * other applicable copyright statement(s), patent holders, laws or such.
 */
#include "black_box_hight.h"

/* Size of one HIGHT data block in bytes (the state X[0..7] below). */
#define BLOCK_LEN 8
/* Number of HIGHT rounds; the key schedule provides 4 subkey bytes per round (SK[128]). */
#define NUM_ROUNDS 32

/*
 * The Jaechul Sung test vectors appear to have been generated with a version of
 * the algorithm that uses the round subkeys in a different order.
 * (Note: the final output block also appears to be in reverse byte order.)
 * Define the macro below to build that version. In that mode
 * blackBoxHIGHTEncryption also ignores the caller's key and input buffer and
 * uses the fixed test-vector key with an all-zero input block instead.
 * If the macro is left undefined, the code follows our interpretation of the
 * specification.
 */
/* #define SUNG_TEST_VECTOR_VERSION */

/*
 * Per-call HIGHT context. The caller sets key and out; HIGHT_init fills WK
 * and SK from key.
 */
typedef struct {
  BYTE *key;     /* master key; bytes key[0]..key[15] are read */
  BYTE *out;     /* destination written by the HIGHT_encryptBlock* routines */
  BYTE WK[8];    /* whitening key bytes, filled by whiteningKeyGeneration */
  BYTE SK[128];  /* round subkey bytes; round i uses SK[4*i .. 4*i+3] */
} HIGHT_info;


/*
 * Derive the 128 round subkey bytes B->SK from the master key B->key.
 *
 * The key is indexed from its last byte: write MK[k] for B->key[15 - k].
 * For row r and column c (both 0..7), with m = (c - r) mod 8:
 *   SK[16r + c]     = MK[m]     + delta[16r + c]      (mod 256)
 *   SK[16r + c + 8] = MK[m + 8] + delta[16r + c + 8]  (mod 256)
 * delta[0] is 0x5A; every later constant is obtained by feeding
 * (bit3 XOR bit0) back in at bit 7 and shifting the state right by one.
 */
void subkeyGeneration(HIGHT_info *B) {
  BYTE delta[128];
  UINT32 state = 0x5A;
  int n, r, c;

  /* round constants */
  delta[0] = (BYTE)state;
  for (n = 1; n < 128; n++) {
    UINT32 feedback = ((state >> 3) ^ state) & 1;
    state = (state | (feedback << 7)) >> 1;
    delta[n] = (BYTE)state;
  }

  /* subkeys */
  for (r = 0; r < 8; r++) {
    for (c = 0; c < 8; c++) {
      const int m = (c - r + 8) & 7;   /* (c - r) mod 8 */
      const int lo = 16 * r + c;
      B->SK[lo]     = (BYTE)(B->key[15 - m] + delta[lo]);
      B->SK[lo + 8] = (BYTE)(B->key[7 - m]  + delta[lo + 8]);
    }
  }
}

/*
 * Fill the 8 whitening key bytes from the master key:
 *   WK[0..3] = key[3], key[2], key[1], key[0]
 *   WK[4..7] = key[15], key[14], key[13], key[12]
 */
void whiteningKeyGeneration(HIGHT_info *B) {
  int n;

  for (n = 0; n < 4; n++) {
    B->WK[n]     = B->key[3 - n];
    B->WK[n + 4] = B->key[15 - n];
  }
}

/*
 * Expand the master key B->key into the round subkeys (B->SK) and the
 * whitening keys (B->WK). Both derivations only read B->key and write
 * disjoint fields, so their relative order is irrelevant.
 */
void keySchedule(HIGHT_info *B) {
  subkeyGeneration(B);
  whiteningKeyGeneration(B);
}

/*
 * Prepare *B for encryption by running the key schedule on B->key, which
 * fills B->WK and B->SK. B->key must be set before calling.
 */
void HIGHT_init(HIGHT_info *B) {
  keySchedule(B);
}

/* Rotate the 8-bit value x left by n bits (1 <= n <= 7), masked to 8 bits. */
#define ROTL(x, n) ((((x) << (n)) | ((x) >> (8 - (n)))) & 0xFF)

/* HIGHT auxiliary function F0(x) = (x <<< 1) ^ (x <<< 2) ^ (x <<< 7). */
static unsigned int F0(unsigned int x) {
  const unsigned int r1 = ROTL(x, 1);
  const unsigned int r2 = ROTL(x, 2);
  const unsigned int r7 = ROTL(x, 7);
  return r1 ^ r2 ^ r7;
}

/* HIGHT auxiliary function F1(x) = (x <<< 3) ^ (x <<< 4) ^ (x <<< 6). */
static unsigned int F1(unsigned int x) {
  const unsigned int r3 = ROTL(x, 3);
  const unsigned int r4 = ROTL(x, 4);
  const unsigned int r6 = ROTL(x, 6);
  return r3 ^ r4 ^ r6;
}
/*
 * Encrypt the BLOCK_LEN-byte block x with the key schedule in *B and write
 * the ciphertext to B->out (BLOCK_LEN bytes). x is read completely before
 * B->out is written.
 *
 * Input byte x[7] becomes state byte 0 and x[0] becomes state byte 7.
 * The output byte order depends on SUNG_TEST_VECTOR_VERSION (see the two
 * final-transformation branches).
 */
void HIGHT_encryptBlock(HIGHT_info *B, const BYTE *x) {
  BYTE cur[8];
  BYTE nxt[8];
  BYTE *dst = B->out;
  int r, k;

  /* initial transformation: reverse input order, whiten even bytes with WK[0..3] */
  cur[0] = (BYTE)(x[7] + B->WK[0]);
  cur[1] = x[6];
  cur[2] = (BYTE)(x[5] ^ B->WK[1]);
  cur[3] = x[4];
  cur[4] = (BYTE)(x[3] + B->WK[2]);
  cur[5] = x[2];
  cur[6] = (BYTE)(x[1] ^ B->WK[3]);
  cur[7] = x[0];

  for (r = 0; r < NUM_ROUNDS; r++) {
    const BYTE *sk = &B->SK[4 * r];

    /* odd bytes of the new state are plain copies of the old even bytes */
    nxt[1] = cur[0];
    nxt[3] = cur[2];
    nxt[5] = cur[4];
    nxt[7] = cur[6];

    nxt[0] = (BYTE)(cur[7] ^ (F0(cur[6]) + sk[3]));
    nxt[4] = (BYTE)(cur[3] ^ (F0(cur[2]) + sk[1]));
#ifdef SUNG_TEST_VECTOR_VERSION
    nxt[2] = (BYTE)(cur[1] + (F1(cur[0]) ^ sk[0]));
    nxt[6] = (BYTE)(cur[5] + (F1(cur[4]) ^ sk[2]));
#else /* according to specification */
    nxt[2] = (BYTE)(cur[1] + (F1(cur[0]) ^ sk[2]));
    nxt[6] = (BYTE)(cur[5] + (F1(cur[4]) ^ sk[0]));
#endif

    for (k = 0; k < 8; k++)
      cur[k] = nxt[k];
  }

  /* final transformation: whiten odd bytes with WK[4..7] and emit */
#ifdef SUNG_TEST_VECTOR_VERSION
  dst[0] = (BYTE)(cur[1] + B->WK[4]);
  dst[1] = cur[2];
  dst[2] = (BYTE)(cur[3] ^ B->WK[5]);
  dst[3] = cur[4];
  dst[4] = (BYTE)(cur[5] + B->WK[6]);
  dst[5] = cur[6];
  dst[6] = (BYTE)(cur[7] ^ B->WK[7]);
  dst[7] = cur[0];
#else /* according to specification */
  dst[0] = cur[0];
  dst[1] = (BYTE)(cur[7] ^ B->WK[7]);
  dst[2] = cur[6];
  dst[3] = (BYTE)(cur[5] + B->WK[6]);
  dst[4] = cur[4];
  dst[5] = (BYTE)(cur[3] ^ B->WK[5]);
  dst[6] = cur[2];
  dst[7] = (BYTE)(cur[1] + B->WK[4]);
#endif
}

/*
 * Encrypt the single block x with the same round structure as
 * HIGHT_encryptBlock, and also expose the intermediate cipher state.
 *
 * After each of the first NUM_ROUNDS - 1 rounds, the raw 8-byte state
 * X[0..7] is XOR-ed, with no whitening and no byte reordering, into the next
 * 8 bytes starting at B->out. After the last round, the whitened ciphertext
 * block is XOR-ed into the following 8 bytes and also copied to lastBlock.
 * The ciphertext byte order is the same as HIGHT_encryptBlock's.
 *
 * numBlocks caps how many 8-byte blocks are written: the function returns as
 * soon as that many have been produced, so at most
 * min(numBlocks, NUM_ROUNDS) * 8 bytes are written at B->out. lastBlock is
 * only written when numBlocks >= NUM_ROUNDS. numBlocks must be >= 1; with a
 * value <= 0 the early-exit test never fires and all NUM_ROUNDS blocks are
 * written. B->out itself is not advanced (a local copy is used).
 */
void HIGHT_encryptBlock_xor_withInitOutput(HIGHT_info *B, const BYTE *x, BYTE *lastBlock, int numBlocks) {

  BYTE X[8];                            /* cipher state */
  BYTE t0, t1, t2, t3, t4, t5, t6, t7;  /* previous-round bytes; t1/t3/t5/t7 are reused for whitened output */
  BYTE *out = B->out;                   /* local write cursor; B->out is left unchanged */
  int i;

  /* initial transformation: x[7..0] -> X[0..7], even bytes whitened with WK[0..3] */
  X[0] = (BYTE)(x[7] + B->WK[0]); X[1] = x[6];
  X[2] = (BYTE)(x[5] ^ B->WK[1]); X[3] = x[4];
  X[4] = (BYTE)(x[3] + B->WK[2]); X[5] = x[2];
  X[6] = (BYTE)(x[1] ^ B->WK[3]); X[7] = x[0];

  /* all rounds but the last, each followed by an intermediate output block */
  for (i=0; i<NUM_ROUNDS-1; i++) {
    /* odd bytes of the new state are copies of the old even bytes */
    t1 = X[1]; t3 = X[3]; t5 = X[5]; t7 = X[7];
    X[1] = t0 = X[0];
    X[3] = t2 = X[2];
    X[5] = t4 = X[4];
    X[7] = t6 = X[6];
#ifdef SUNG_TEST_VECTOR_VERSION
    X[0] = (BYTE)(t7 ^ (F0(t6) + B->SK[4*i+3]));
    X[2] = (BYTE)(t1 + (F1(t0) ^ B->SK[4*i+0]));
    X[4] = (BYTE)(t3 ^ (F0(t2) + B->SK[4*i+1]));
    X[6] = (BYTE)(t5 + (F1(t4) ^ B->SK[4*i+2]));
#else /* according to specification */
    X[0] = (BYTE)(t7 ^ (F0(t6) + B->SK[4*i+3]));
    X[2] = (BYTE)(t1 + (F1(t0) ^ B->SK[4*i+2]));
    X[4] = (BYTE)(t3 ^ (F0(t2) + B->SK[4*i+1]));
    X[6] = (BYTE)(t5 + (F1(t4) ^ B->SK[4*i]));
#endif

    /* XOR the raw round state into the output (state order, not whitened) */
    *out++ ^= X[0];
    *out++ ^= X[1];
    *out++ ^= X[2];
    *out++ ^= X[3];
    *out++ ^= X[4];
    *out++ ^= X[5];
    *out++ ^= X[6];
    *out++ ^= X[7];
    if (--numBlocks == 0) return;  /* requested block count reached; lastBlock is not written */
  }

  /* last round (i == NUM_ROUNDS - 1 after the loop) */
  t1 = X[1]; t3 = X[3]; t5 = X[5]; t7 = X[7];
  X[1] = t0 = X[0];
  X[3] = t2 = X[2];
  X[5] = t4 = X[4];
  X[7] = t6 = X[6];
#ifdef SUNG_TEST_VECTOR_VERSION
  X[0] = (BYTE)(t7 ^ (F0(t6) + B->SK[4*i+3]));
  X[2] = (BYTE)(t1 + (F1(t0) ^ B->SK[4*i+0]));
  X[4] = (BYTE)(t3 ^ (F0(t2) + B->SK[4*i+1]));
  X[6] = (BYTE)(t5 + (F1(t4) ^ B->SK[4*i+2]));
#else /* according to specification */
  X[0] = (BYTE)(t7 ^ (F0(t6) + B->SK[4*i+3]));
  X[2] = (BYTE)(t1 + (F1(t0) ^ B->SK[4*i+2]));
  X[4] = (BYTE)(t3 ^ (F0(t2) + B->SK[4*i+1]));
  X[6] = (BYTE)(t5 + (F1(t4) ^ B->SK[4*i]));
#endif

  /* final transformation: whiten odd bytes with WK[4..7] */
  t1 = (BYTE)(X[1] + B->WK[4]);
  t3 = (BYTE)(X[3] ^ B->WK[5]);
  t5 = (BYTE)(X[5] + B->WK[6]);
  t7 = (BYTE)(X[7] ^ B->WK[7]);

#ifdef SUNG_TEST_VECTOR_VERSION
  /* output last block (same byte order as HIGHT_encryptBlock) */
  *out++ ^= t1; *out++ ^= X[2];
  *out++ ^= t3; *out++ ^= X[4];
  *out++ ^= t5; *out++ ^= X[6];
  *out++ ^= t7; *out   ^= X[0];

  /* extract last block: plain ciphertext, used by the caller for chaining */
  lastBlock[0] = t1; lastBlock[1] = X[2];
  lastBlock[2] = t3; lastBlock[3] = X[4];
  lastBlock[4] = t5; lastBlock[5] = X[6];
  lastBlock[6] = t7; lastBlock[7] = X[0];
#else /* according to specification */
  /* output last block (same byte order as HIGHT_encryptBlock) */
  *out++ ^= X[0]; *out++ ^= t7;
  *out++ ^= X[6]; *out++ ^= t5;
  *out++ ^= X[4]; *out++ ^= t3;
  *out++ ^= X[2]; *out   ^= t1;

  /* extract last block: plain ciphertext, used by the caller for chaining */
  lastBlock[0] = X[0]; lastBlock[1] = t7;
  lastBlock[2] = X[6]; lastBlock[3] = t5;
  lastBlock[4] = X[4]; lastBlock[5] = t3;
  lastBlock[6] = X[2]; lastBlock[7] = t1;
#endif
}

/******************************************************************************
 * Black box variants
 ******************************************************************************/
/*
 * COPY_BUF_8: dst[0..7] = src[0..7].   XOR_BUF_8: dst[0..7] ^= src[0..7].
 *
 * Implemented byte by byte through unsigned char pointers. The buffers passed
 * in (local BYTE arrays, iv, inBuf + k*BLOCK_LEN, outBuf + k*BLOCK_LEN) have no
 * 8-byte alignment guarantee. Reading or writing them through a UINT64 pointer,
 * as an earlier version did, is undefined behavior: it breaks the strict-aliasing
 * rule and can fault on strict-alignment targets. Optimizing compilers
 * typically still emit a single 8-byte move where that is legal.
 *
 * Both macros expand to a braced block, so they are used without a trailing ';'.
 */
#define COPY_BUF_8(dst, src) {                                  \
    unsigned char *d_ = (unsigned char *)(dst);                 \
    const unsigned char *s_ = (const unsigned char *)(src);     \
    int k_;                                                     \
    for (k_ = 0; k_ < 8; k_++) d_[k_] = s_[k_];                 \
  }
#define XOR_BUF_8(dst, src) {                                   \
    unsigned char *d_ = (unsigned char *)(dst);                 \
    const unsigned char *s_ = (const unsigned char *)(src);     \
    int k_;                                                     \
    for (k_ = 0; k_ < 8; k_++) d_[k_] ^= s_[k_];                \
  }
/*
 * Black box variant without intermediate round output.
 *
 * For each of the numOutputBytes / BLOCK_LEN blocks k:
 *   pt_k  = (k == 0 ? iv : ct_{k-1}) XOR inBuf[k*BLOCK_LEN .. k*BLOCK_LEN+7]
 *   ct_k  = HIGHT encryption of pt_k under key
 *   outBuf[k*BLOCK_LEN .. k*BLOCK_LEN+7] ^= ct_k
 * outBuf is XOR-ed into, not overwritten.
 *
 * Returns 0 on success (including numOutputBytes == 0), -1 if
 * numInputBytes < numOutputBytes or numOutputBytes is not a multiple of
 * BLOCK_LEN.
 */
int HIGHT_xor(const BYTE *key, const BYTE *iv, const BYTE *inBuf, unsigned int numInputBytes, BYTE *outBuf, unsigned int numOutputBytes) {
  HIGHT_info ctx;
  BYTE pt[BLOCK_LEN];
  BYTE ct[BLOCK_LEN];
  const BYTE *chain = iv;   /* XOR-ed into the next plaintext: iv first, then the previous ciphertext */
  const int numBlocks = numOutputBytes / BLOCK_LEN;
  int blk;

  if (numOutputBytes == 0)
    return 0;
  if (numInputBytes < numOutputBytes || (numOutputBytes % BLOCK_LEN) != 0)
    return -1;

  ctx.key = (BYTE *)key;
  ctx.out = ct;   /* every ciphertext block lands in ct */
  HIGHT_init(&ctx);

  for (blk = 0; blk < numBlocks; blk++) {
    COPY_BUF_8(pt, chain)
    XOR_BUF_8(pt, inBuf + blk * BLOCK_LEN)
    HIGHT_encryptBlock(&ctx, pt);
    XOR_BUF_8(outBuf + blk * BLOCK_LEN, ct)
    chain = ct;
  }
  return 0;
}

/*
 * Black box variant that also outputs the internal state after each of the
 * first NUM_ROUNDS - 1 rounds of the first block.
 *
 * Output layout in outBuf, in BLOCK_LEN-byte blocks (everything is XOR-ed
 * into outBuf, not overwritten):
 *   0 .. NUM_ROUNDS-2 : intermediate states while encrypting block 0
 *   NUM_ROUNDS-1      : ciphertext of block 0
 *   NUM_ROUNDS + k    : ciphertext of block k + 1
 * Block 0's plaintext is iv XOR inBuf[0..7]. Each later plaintext is the
 * previous ciphertext XOR the next BLOCK_LEN bytes of inBuf.
 *
 * Requirements:
 *   - numOutputBytes is a multiple of BLOCK_LEN;
 *   - inBuf holds at least one block and at least
 *     numOutputBytes - (NUM_ROUNDS - 1) * BLOCK_LEN bytes;
 *   - inBuf and outBuf must not overlap. Only identical pointers are detected:
 *     output is written ahead of input consumption.
 *
 * Returns 0 on success (or when numOutputBytes == 0), -1 on invalid arguments.
 */
int HIGHT_xor_withInitOutput(const BYTE *key, const BYTE *iv, const BYTE *inBuf, unsigned int numInputBytes, BYTE *outBuf, unsigned int numOutputBytes) {
  HIGHT_info ctx;
  BYTE pt[BLOCK_LEN];
  BYTE ct[BLOCK_LEN];
  int i;
  int numBlocks = numOutputBytes / BLOCK_LEN;
  const unsigned int numSuppressedBytes = (NUM_ROUNDS - 1) * BLOCK_LEN;

  if (numOutputBytes == 0) return 0;
  /* the first block is always encrypted, so one full input block is always read */
  if (numInputBytes < BLOCK_LEN) return -1;
  /* same as numInputBytes + numSuppressedBytes < numOutputBytes, without unsigned wrap-around */
  if (numOutputBytes > numSuppressedBytes && numOutputBytes - numSuppressedBytes > numInputBytes) return -1;
  if ((numOutputBytes % BLOCK_LEN) != 0) return -1;
  if (inBuf == outBuf) return -1;

  ctx.key = (BYTE*)key;
  ctx.out = outBuf;

  HIGHT_init(&ctx);

  /* first NUM_ROUNDS output blocks: (NUM_ROUNDS - 1) intermediate states + first ciphertext */
  COPY_BUF_8(pt, iv)
  XOR_BUF_8(pt, inBuf)
  HIGHT_encryptBlock_xor_withInitOutput(&ctx, pt, ct, numBlocks);
  if (numBlocks <= NUM_ROUNDS)
    return 0;
  ctx.out = ct;   /* from here on each ciphertext block is produced in ct */
  numBlocks -= NUM_ROUNDS;

  /* remaining blocks: chain each ciphertext into the next plaintext */
  for (i=0; i<numBlocks; i++) {
    COPY_BUF_8(pt, ct)
    XOR_BUF_8(pt, inBuf + (i + 1) * BLOCK_LEN)
    HIGHT_encryptBlock(&ctx, pt);
    XOR_BUF_8(outBuf + (i + NUM_ROUNDS) * BLOCK_LEN, ct)
  }
  return 0;
}

/******************************************************************************
 * Black box API
 ******************************************************************************/
/*
 * Black box entry point. Dispatches to HIGHT_xor_withInitOutput when
 * withInitRoundOutput is non-zero, otherwise to HIGHT_xor, and returns its
 * result (0 on success, -1 on invalid arguments).
 *
 * With SUNG_TEST_VECTOR_VERSION defined, the caller's key and input are
 * ignored. The fixed test-vector key and a single all-zero input block are
 * used instead. Only BLOCK_LEN input bytes exist in that mode, so that is the
 * length reported to the callee. Requests that need more input are rejected
 * instead of reading past the zero block.
 */
int blackBoxHIGHTEncryption(const BYTE *key, const BYTE *iv, const BYTE *inBuf, unsigned int numInputBytes, BYTE *outBuf, unsigned int numOutputBytes, int withInitRoundOutput) {
#ifdef SUNG_TEST_VECTOR_VERSION
  static const BYTE sungKey[16] = {
    0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
    0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff
  };
  static const BYTE zeroBlock[BLOCK_LEN] = {0};
  (void)key;
  (void)inBuf;
  (void)numInputBytes;
  if (withInitRoundOutput)
      return HIGHT_xor_withInitOutput(sungKey, iv, zeroBlock, BLOCK_LEN, outBuf, numOutputBytes);
  return HIGHT_xor(sungKey, iv, zeroBlock, BLOCK_LEN, outBuf, numOutputBytes);
#else
  if (withInitRoundOutput)
      return HIGHT_xor_withInitOutput(key, iv, inBuf, numInputBytes, outBuf, numOutputBytes);
  return HIGHT_xor(key, iv, inBuf, numInputBytes, outBuf, numOutputBytes);
#endif
}

/******************************************************************************
 * Basic cipher information
 ******************************************************************************/
/*
 * Report the black box parameters through the given out-pointers; any
 * pointer that is null is skipped.
 *   keySizeInBytes           : 16
 *   ivSizeInBytes            : BLOCK_LEN
 *   suppressedBytes          : (NUM_ROUNDS - 1) * BLOCK_LEN, the intermediate
 *                              round output of HIGHT_xor_withInitOutput
 *   implicitBlockSizeInBytes : BLOCK_LEN
 */
void getBlackBoxHIGHTInfo(int *keySizeInBytes, int *ivSizeInBytes, int *suppressedBytes, int *implicitBlockSizeInBytes) {
  const int keyBytes = 16;
  const int blockBytes = BLOCK_LEN;
  const int suppressed = (NUM_ROUNDS - 1) * BLOCK_LEN;

  if (keySizeInBytes)
    *keySizeInBytes = keyBytes;
  if (ivSizeInBytes)
    *ivSizeInBytes = blockBytes;
  if (suppressedBytes)
    *suppressedBytes = suppressed;
  if (implicitBlockSizeInBytes)
    *implicitBlockSizeInBytes = blockBytes;
}

