#include<stdio.h>
#include<stdlib.h>
#include<unistd.h>
#include<malloc.h>
#include<sys/mman.h>
#include<math.h>
#include<errno.h>
#include<string.h>
#include<sys/wait.h>
#include<time.h>

#include "hybridGrouping.h"
#include "asm.h"
#include "memlib/memoryInspect.h"
#include "hammerlib/afunc.h"
#include "measure.h"
#include "group.h"
#include "bankSequence.h"

// NOTE(review): this global is not referenced by any function visible in this
// file - confirm whether another translation unit needs it.
char *ptr = NULL;
// Page size in bytes. Assigned in main from sysconf(_SC_PAGESIZE) and passed
// by analyzePages to addAddressToAddressGroup.
u_int64_t pageSize;

/*******************************************************************************
 * getAddressFunctions uses HammerLib to reverse engineer the address functions
 * of the system.
 *
 * Up to 10 attempts are made to fill nBanks * 2 address groups by calling
 * hlScanHugepages. An attempt is rejected when the first group holds more than
 * 110% of 16384 / nBanks items. When every attempt is rejected, the groups of
 * the last attempt are used anyway. Afterwards the number of groups is cut to
 * nBanks and hlFindMasks searches them for masks of at most 7 bits.
 *
 * @param nBanks: Number of banks in the system. Must not be 0 because it is
 *        used as a divisor.
 * @param nFunctions: Pointer to store the number of address functions found.
 *        It is set to 0 whenever NULL is returned.
 * @return List of address functions that contains *nFunctions bit masks, or
 *         NULL when nBanks is 0, a HammerLib structure could not be
 *         constructed or no masks were found. The list is the mask array of
 *         the HammerLib mask item container, which is not released here.
 ******************************************************************************/
u_int64_t *getAddressFunctions(u_int64_t nBanks, u_int64_t *nFunctions) {
  *nFunctions = 0;

  int totalErrors = 0;
  int nInitSets = 1;
  int verbosity = 2;
  int nChecks = 8;
  int iter = 100;
  int accessTimeCnt = -1;
  int getTime = 0;
  int vMode = 1;
  int scale = -1;
  int blockSize = sysconf(_SC_PAGESIZE);
  int measureRowSize = 0;
  int fenced = 1;
  int maxMaskBits = 7;
  int maxAttempts = 10;

  // nBanks is the divisor of the distribution check below.
  if(nBanks == 0) {
    if(verbosity >= 0) {
      printf("[ERROR]: The number of banks must not be 0.\n");
    }
    return NULL;
  }

  hlMaskItems *mItems = hlConstructMaskItems();
  if(mItems == NULL) {
    // Checked before hlFindMasks gets to see the pointer.
    if(verbosity >= 0) {
      printf("[ERROR]: Unable to construct the mask items.\n");
    }
    return NULL;
  }

  hlAddressGroups *aGroups = NULL;

  int groupValid = 0;
  for(int j = 0; j < maxAttempts && groupValid == 0; j++) {
    // Drop the groups of the previous attempt (aGroups is NULL on the first
    // pass, exactly as in the original call sequence).
    aGroups = hlDestructAddressGroups(aGroups);
    aGroups = hlConstructAddressGroups(nBanks * 2, blockSize);
    if(aGroups == NULL) {
      if(verbosity >= 0) {
        printf("[ERROR]: Unable to construct the address groups.\n");
      }
      return NULL;
    }
    totalErrors = 0;

    for(int i = 0; i < nInitSets; i++) {
      if(verbosity >= 2) {
        printf("\r[INFO]: Starting set %3d of %3d", i + 1, nInitSets);
      }
      if(verbosity >= 3) {
        printf("\n");
      }
      fflush(stdout);
      totalErrors += hlScanHugepages(1, aGroups, verbosity, nBanks, nChecks, iter, &accessTimeCnt, getTime, vMode, i, scale, blockSize, measureRowSize, fenced);
    }

    // Reject the attempt when the first group is more than 10% larger than
    // 16384 / nBanks items.
    if(aGroups->aGroup[0]->nItems > ((16384/nBanks) * 11 / 10)) {
      if(verbosity >= 1) {
        printf("\n[WARN]: The group is invalid because the distribution is not equal.\n");
        printf("[WARN]: Retry to find better matching groups.\n");
      }
      aGroups->nItems = nBanks;
      hlPrintAddressGroupStats(aGroups, verbosity);
    } else {
      groupValid = 1;
      printf("\n");
    }
  }

  // Cut aGroups to banks (because banks * 2 is used for better accuracy in the
  // beginning)
  aGroups->nItems = nBanks;
  hlPrintAddressGroupStats(aGroups, verbosity);

  if(verbosity >= 2) {
    printf("[INFO]: Scanned %d sets with %d errors (out of %ld addresses) in total.\n", nInitSets, totalErrors, nInitSets * 512 * sysconf(_SC_PAGESIZE) / blockSize);
  }

  if(verbosity >= 2) {
    printf("[INFO]: Scanning for masks. This may take a while.\n");
  }

  hlFindMasks(aGroups, maxMaskBits, mItems, verbosity);

  printf("\n");

  if(mItems->nItems == 0) {
    if(verbosity >= 0) {
      printf("[ERROR]: No masks found.\n");
    }
    // The address groups are no longer needed on the failure path either.
    hlDestructAddressGroups(aGroups);
    return NULL;
  }

  hlSortAddressGroups(aGroups, mItems, verbosity);

  if(verbosity >= 2) {
    hlPrintAddressGroupStats(aGroups, 2);
  }

  hlDestructAddressGroups(aGroups);

  // NOTE(review): only the mask array is handed out; the mItems container
  // itself stays allocated - confirm the intended ownership with HammerLib.
  *nFunctions = mItems->nItems;
  return (u_int64_t *)(mItems->masks);
}

/*******************************************************************************
 * printHelp takes the name of the binary and an exit code, prints the usage
 * page of the program and exits afterwards using the exit code. This function
 * does not return.
 *
 * @param binary: Name of the binary that was executed (e.g. argv[0]).
 * @param exitCode: Code that should be used as exit code when the execution is
 *        ended after the help message is printed.
 ******************************************************************************/
void printHelp(char *binary, u_int64_t exitCode) {
  printf("Usage: %s [-n nPages] [-s skipPerc] [-d] [-h]\n", binary);
  // main initialises nPages with 65536, so that is the default to advertise.
  printf("\tnPages number of 4K Pages to allocate (default: 65536)\n");
  printf("\tskipPerc percentage of addresses that is not grouped (default: 20)\n");
  printf("\t-d enable debug output when the mappings are printed\n");
  printf("\t-h print this help message and exit\n");
  // exit takes an int; make the narrowing explicit.
  exit((int)exitCode);
}


/*******************************************************************************
 * analyzePages implements the core functionality of the hybrid grouping PoC:
 * In the beginning, the threshold is measured. Afterwards, the number of banks
 * and initial groups are calculated. Next, the address functions are reverse
 * engineered using HammerLib. After that, the sequence of banks and the mapping
 * between group indices and bank numbers are calculated. Next, the addresses
 * are iterated and added to the groups based on the calculated block size and
 * the bank sequence.
 *
 * @param aInfo: Addresses that should be grouped
 * @param len: Number of addresses that should be grouped
 * @param skipPerc: Skip the first skipPerc percent of memory addresses. Due to
 *        the behaviour of the buddy allocator, the first blocks are small so
 *        there is no benefit in using hybridGrouping. The bigger the blocks
 *        are, the more advantage the usage of HybridGrouping brings.
 * @return 1 when the addresses were processed, 0 when no address functions
 *         could be determined (the caller in this file then retries with a
 *         fresh allocation).
 ******************************************************************************/
int analyzePages(addrInfo **aInfo, u_int64_t len, u_int64_t skipPerc) {
  // Initialization
  u_int64_t maxAddrs = 512;
  u_int64_t maxGroupComparisons = 5;
  u_int64_t nHugePages = 1;
  u_int64_t measurementsPerAddress = 5;
  u_int64_t maxFailedVerifications = 3;

  // Measure threshold
  u_int64_t threshold = measure_system_threshold(0, 25);

  // Calculate number of banks and initial groups
  u_int64_t nAddrs = len;
  if(len > maxAddrs) {
    nAddrs = maxAddrs;
  }

  addressGroups *aGroups = groupAddresses(aInfo, nAddrs, threshold, maxGroupComparisons, measurementsPerAddress, FL_NONE);

  // Assume the grouping is not valid yet because the first grouping often
  // results in twice as many groups as there are banks in the system.
  // Regroup until the number of groups has exactly one bit set.
  u_int64_t isValid = 0;
  while(!isValid) {
    aGroups = regroupAddresses(aGroups, threshold, maxGroupComparisons, measurementsPerAddress, FL_NONE);
    aGroups = removeSmallGroups(aGroups, 4, 0);

    printf("\nGroups after regroup: %ld\n", aGroups->nAddressGroups);

    if(countBits(aGroups->nAddressGroups) == 1) {
      isValid = 1;
    }
  }

  // Add physical address information, should be removed for productive use
  // since that requires root privileges. For the PoC it is needed for
  // evaluation.
  for(u_int64_t groupIdx = 0; groupIdx < aGroups->nAddressGroups; groupIdx++) {
    u_int64_t nGroupAddrs = aGroups->addressGroups[groupIdx]->nAddresses;
    int64_t *hvas = malloc(sizeof(int64_t) * nGroupAddrs);
    for(u_int64_t addrIdx = 0; addrIdx < nGroupAddrs; addrIdx++) {
      hvas[addrIdx] = (int64_t)aGroups->addressGroups[groupIdx]->addresses[addrIdx];
    }

    // Use a distinct name: the original shadowed the aInfo parameter here.
    addrInfo **groupInfo = getAddrInfoFromHva(hvas, nGroupAddrs);
    free(hvas);
    updatePhysicalInformation(groupInfo, nGroupAddrs);

    // The per-group list is not used afterwards; release it the same way main
    // releases the list it obtained from getAddrInfoFromHva.
    if(groupInfo != NULL) {
      for(u_int64_t addrIdx = 0; addrIdx < nGroupAddrs; addrIdx++) {
        destructAddrInfoItem(groupInfo[addrIdx]);
      }
      free(groupInfo);
    }
  }

  printf("Assuming %ld banks.\n", aGroups->nAddressGroups);


  // Measure address functions
  u_int64_t nFns = 0;
  //u_int64_t fns[] = {0x2000, 0x48000, 0x24000, 0x90000};
  u_int64_t *fns = getAddressFunctions(aGroups->nAddressGroups, &nFns);
  if(fns == NULL || nFns == 0) {
    // getAddressFunctions already reported the reason. Without address
    // functions no bank sequence can be calculated.
    // NOTE(review): aGroups is not released here; no destructor for it is
    // visible in this file.
    return 0;
  }

  // Calculate the sequence of banks
  bankInformation *bInfo = getBankInformation(fns, nFns, FL_NONE);

  // Calculate the group <-> bank mapping
  bInfo = addBankMappingInformation(bInfo, fns, nFns, aGroups, threshold, maxGroupComparisons, nHugePages, measurementsPerAddress);

  // Do not regroup the addresses that are already grouped
  u_int64_t blockSize = 1;
  u_int64_t contCnt = 0;

  u_int64_t digits = (u_int64_t)ceil(log10(len * 1.0));

  // Banks measured at the offsets of the current block size and of the doubled
  // block size. Both buffers are reused (resized) across iterations.
  u_int64_t *measuredBanks = NULL;
  u_int64_t *nextMeasuredBanks = NULL;

  u_int64_t lastNextBankMatched = 0;
  u_int64_t nFailedVerifications = 0;
  // Stays NULL until offsets for a doubled block size were requested once.
  uniqueBankOffset *nextUBankOffsets = NULL;

  if(len / 10 > nAddrs && skipPerc < 100) {
    // skip the first nPerc percent of allocated blocks, only for the PoC to
    // increase speed by skipping small blocks
    nAddrs = len / 100 * skipPerc;
  }

  u_int64_t nErrors = 0;
  // Iterate over all addresses, calculate the current block size and add the
  // addresses to the groups using the bank sequence. The condition is the
  // wrap-free form of "i < len - blockSize".
  for(u_int64_t i = nAddrs; i < len && blockSize < len - i; i += blockSize) {
    printf("Handle address %ld of %ld (block size: %ld)                   \r", i, len, blockSize);
    fflush(stdout);

    uniqueBankOffset *uBankOffsets = getBankOffsetForBlockSize(bInfo, blockSize);
    measuredBanks = realloc(measuredBanks, sizeof(u_int64_t) * uBankOffsets->nOffsets);

    // Once the block size has reached the length of the bank sequence, the
    // offsets obtained for the previous doubling stay in use.
    if(blockSize < bInfo->nBankSequence) {
      nextUBankOffsets = getBankOffsetForBlockSize(bInfo, blockSize * 2);
      nextMeasuredBanks = realloc(nextMeasuredBanks, sizeof(u_int64_t) * nextUBankOffsets->nOffsets);
    }

    // Collect every offset that has to be measured exactly once.
    u_int64_t *requiredOffsets = NULL;
    u_int64_t nRequiredOffsets = 0;

    for(u_int64_t idx = 0; idx < uBankOffsets->nOffsets; idx++) {
      requiredOffsets = uniqueAddToArray(requiredOffsets, &nRequiredOffsets, uBankOffsets->offsets[idx]);
    }
    requiredOffsets = uniqueAddToArray(requiredOffsets, &nRequiredOffsets, uBankOffsets->verificationOffset);

    if(nextUBankOffsets != NULL) {
      for(u_int64_t idx = 0; idx < nextUBankOffsets->nOffsets; idx++) {
        requiredOffsets = uniqueAddToArray(requiredOffsets, &nRequiredOffsets, nextUBankOffsets->offsets[idx]);
      }
      requiredOffsets = uniqueAddToArray(requiredOffsets, &nRequiredOffsets, nextUBankOffsets->verificationOffset);
    }

    // Offsets of the doubled block size may point behind the last address.
    // Stop instead of reading past the end of aInfo.
    int offsetsInRange = 1;
    for(u_int64_t idx = 0; idx < nRequiredOffsets; idx++) {
      if(requiredOffsets[idx] >= len - i) {
        offsetsInRange = 0;
        break;
      }
    }
    if(!offsetsInRange) {
      // assumes uniqueAddToArray hands out heap memory owned by the caller
      free(requiredOffsets);
      break;
    }

    // Bank of every required offset, indexed like requiredOffsets. Kept apart
    // from measuredBanks so that no value is overwritten before it is read.
    u_int64_t *requiredBanks = malloc(sizeof(u_int64_t) * nRequiredOffsets);

    for(u_int64_t idx = 0; idx < nRequiredOffsets; idx++) {
      int mustGroup = 1;
      // NOTE(review): groupIdx is signed; presumably mustGroup guarantees a
      // valid index - confirm against getGroupIdxForAddress.
      int64_t groupIdx = getGroupIdxForAddress(aGroups, maxGroupComparisons, threshold, (volatile char *)(aInfo[i + requiredOffsets[idx]]->hva), measurementsPerAddress, FL_NONE, mustGroup);
      requiredBanks[idx] = bInfo->groupToBankMapping[groupIdx];
    }

    for(u_int64_t idx = 0; idx < uBankOffsets->nOffsets; idx++) {
      measuredBanks[idx] = requiredBanks[findValue(requiredOffsets, nRequiredOffsets, uBankOffsets->offsets[idx])];
    }
    u_int64_t verificationBank = requiredBanks[findValue(requiredOffsets, nRequiredOffsets, uBankOffsets->verificationOffset)];

    u_int64_t nextVerificationBank = 0;
    if(nextUBankOffsets != NULL) {
      for(u_int64_t idx = 0; idx < nextUBankOffsets->nOffsets; idx++) {
        nextMeasuredBanks[idx] = requiredBanks[findValue(requiredOffsets, nRequiredOffsets, nextUBankOffsets->offsets[idx])];
      }
      nextVerificationBank = requiredBanks[findValue(requiredOffsets, nRequiredOffsets, nextUBankOffsets->verificationOffset)];
    }

    u_int64_t *partSequence = malloc(sizeof(u_int64_t) * blockSize);
    // Set when the current position has to be handled again.
    int retryBlock = 0;

    if(verifyBankSequence(bInfo, measuredBanks, uBankOffsets->nOffsets, verificationBank, blockSize, &partSequence, FL_NONE)) {
      // Add addresses based on the part sequence
      u_int64_t errCnt = 0;
      for(u_int64_t idx = 0; idx < blockSize; idx++) {
        u_int64_t groupIdx = bInfo->bankToGroupMapping[partSequence[idx]];
        aGroups->addressGroups[groupIdx] = addAddressToAddressGroup(aGroups->addressGroups[groupIdx], (volatile char *)(aInfo[i + idx]->hva), pageSize);

        // Evaluation only: compare against the bank derived from the physical
        // frame number.
        u_int64_t hybridBankIdx = partSequence[idx];
        u_int64_t pfnBankIdx = calculateBankFromAddress((volatile char *)(aInfo[i + idx]->pfn << 12), fns, nFns);

        if(hybridBankIdx != pfnBankIdx) {
          printf("\nThe calculated bank index (%ld) and real bank index (%ld) differ at offset %ld.\n", hybridBankIdx, pfnBankIdx, idx);
          errCnt++;
        }
      }

      if(errCnt != 0) {
        nErrors += errCnt;
        printf("\nThere were %ld errors in grouping.\n", errCnt);
      }
    } else if(blockSize > 1) {
      // Compensate the increment of the loop header so that the failed block
      // is handled again; after too many failures the block size is halved.
      // NOTE(review): when the block size is halved, the old size is
      // subtracted but the new size is added by the loop header, so the
      // position moves back by half a block - confirm this is intended.
      i -= blockSize;
      if(nFailedVerifications >= maxFailedVerifications) {
        nFailedVerifications = 0;
        blockSize /= 2;
        printf("\n\033[1m\033[38;5;9m[%0*ld] Decreasing block size to %ld\033[0m\n", (int)digits, i, blockSize);
      } else {
        nFailedVerifications++;
      }
      retryBlock = 1;
    } else {
      contCnt = 0;
    }

    if(!retryBlock) {
      // Check whether the doubled block size would have matched as well.
      if(nextUBankOffsets != NULL && verifyBankSequence(bInfo, nextMeasuredBanks, nextUBankOffsets->nOffsets, nextVerificationBank, blockSize * 2, NULL, FL_NONE)) {
        contCnt++;
        lastNextBankMatched = 1;
      } else {
        if(!lastNextBankMatched) {
          contCnt = 0;
        }
        lastNextBankMatched = 0;
      }

      if(contCnt >= 20 && lastNextBankMatched && blockSize < bInfo->nBankSequence) {
        // Switch to a bigger block size only after 20 counted matches of the
        // doubled block size and only when the most recent check matched.
        blockSize *= 2;
        printf("\n\033[1m\033[38;5;40m[%0*ld] Increasing block size to %ld\033[0m\n", (int)digits, i, blockSize);
        contCnt = 0;
      }
    }

    // Per-iteration buffers; the original leaked all of them except the last
    // partSequence.
    free(partSequence);
    free(requiredBanks);
    free(requiredOffsets);
  }

  u_int64_t perc =  nErrors * 10000 / (len - nAddrs + 512);
  printf("Grouped %ld addresses with %ld errors (%3ld.%02ld%%).\n", len - nAddrs + 512, nErrors, perc / 100, perc % 100);

  free(measuredBanks);
  free(nextMeasuredBanks);

  return 1;
}

/*******************************************************************************
 * main parses the command line, maps nPages locked anonymous pages, collects
 * the address information of every page and hands it to analyzePages. The
 * allocation and analysis are repeated until analyzePages reports success.
 *
 * Options: -n nPages, -s skipPerc, -d (debug flag passed to printMappings),
 * -h (print usage and exit).
 *
 * @return EXIT_SUCCESS after a successful analysis, -1 when the memory for the
 *         analysis could not be obtained.
 ******************************************************************************/
int main(int argc, char *const argv[]) {
  srand(time(NULL));

  u_int64_t nPages = 65536;
  u_int64_t debug = 0;
  u_int64_t skipPerc = 20;

  // getopt returns an int and signals the end of the options with -1.
  int opt;
  extern char *optarg;
  // "s:" has to be listed, otherwise the -s case below can never be reached.
  while((opt = getopt(argc, argv, "n:s:dh")) != -1) {
    switch(opt) {
      case 'n':
        nPages = parseNumber(optarg);
        if(nPages == 0) {
          printf("Invalid value %s for nPages.\n", optarg);
          printHelp(argv[0], EXIT_FAILURE);
        }
        break;
      case 's':
        skipPerc = parseNumber(optarg);
        break;
      case 'd':
        debug = 1;
        break;
      case 'h':
        printHelp(argv[0], EXIT_SUCCESS);
        break;
      default:
        printHelp(argv[0], EXIT_FAILURE);
        break;
    }
  }


  //printf("[INFO]: Allocating %d 4K Pages\n", nPages);
  pageSize = sysconf(_SC_PAGESIZE);

  u_int64_t analysisSuccessful = 0;
  while(analysisSuccessful == 0) {
    size_t mappingSize = pageSize * sizeof(char) * nPages;

    volatile char *mapping = mmap(NULL, mappingSize, PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS|MAP_LOCKED, -1, 0);
    if(mapping == MAP_FAILED) {
      printf("%s\n", strerror(errno));
      return -1;
    }

    // Allocated after the mapping succeeded so nothing leaks on the error
    // path above.
    int64_t *hvas = malloc(sizeof(int64_t) * nPages);
    if(hvas == NULL) {
      printf("%s\n", strerror(errno));
      munmap((void *)mapping, mappingSize);
      return -1;
    }

    // One virtual address per page of the mapping.
    for(u_int64_t i = 0; i < nPages; i++) {
      hvas[i] = (int64_t)(mapping + i * pageSize);
    }
    addrInfo **aInfo = getAddrInfoFromHva(hvas, nPages);
    free(hvas);

    updatePhysicalInformation(aInfo, nPages);
    printMappings(aInfo, nPages, debug);

    analysisSuccessful = analyzePages(aInfo, nPages, skipPerc);

    munmap((void *)mapping, mappingSize);

    for(u_int64_t i = 0; i < nPages; i++) {
      destructAddrInfoItem(aInfo[i]);
    }
    free(aInfo);
  }

  return EXIT_SUCCESS;
}
