#include<inttypes.h>
#include<stdio.h>
#include<stdlib.h>

#include "group.h"
#include "memlib/memoryInspect.h"
#include "helper.h"
#include "measure.h"

/*******************************************************************************
 * getRandomIndices randomly selects up to maxGroupComparisons distinct indices
 * from the range [0, len).
 *
 * The indices 0..len-1 are shuffled with shuffleUInt64Array and the first
 * min(len, maxGroupComparisons) of them are returned. The returned array always
 * has room for maxGroupComparisons entries. When len < maxGroupComparisons, the
 * unused trailing entries are set to len (an out-of-range index), so callers
 * never read uninitialized memory.
 *
 * @param len: Length of the list for which the random indices should be
 *        calculated
 * @param maxGroupComparisons: Number of entries that should be returned (if
 *        there are more than maxGroupComparisons items in the list)
 * @return Heap-allocated array of maxGroupComparisons indices that the caller
 *         must free, or NULL if an allocation failed (may also be NULL when
 *         maxGroupComparisons is 0)
 ******************************************************************************/
u_int64_t *getRandomIndices(u_int64_t len, u_int64_t maxGroupComparisons) {
  u_int64_t *indices = malloc(sizeof(u_int64_t) * maxGroupComparisons);
  // malloc(0) may legitimately return NULL, so NULL only counts as a failure
  // when memory was actually requested.
  if(indices == NULL && maxGroupComparisons > 0) {
    return NULL;
  }

  u_int64_t *sequence = malloc(sizeof(u_int64_t) * len);
  if(sequence == NULL && len > 0) {
    free(indices);
    return NULL;
  }

  for(u_int64_t i = 0; i < len; i++) {
    sequence[i] = i;
  }

  sequence = shuffleUInt64Array(sequence, len);

  u_int64_t nSelected = len < maxGroupComparisons ? len : maxGroupComparisons;
  for(u_int64_t i = 0; i < nSelected; i++) {
    indices[i] = sequence[i];
  }
  // Mark slots that cannot be filled with an out-of-range index instead of
  // leaving them uninitialized.
  for(u_int64_t i = nSelected; i < maxGroupComparisons; i++) {
    indices[i] = len;
  }

  free(sequence);
  return indices;
}


/*******************************************************************************
 * timeAddressToGroup measures alternating access between an address group (by
 * using randomly chosen addresses from that group) and a single address.
 *
 * @param aGroup: Address group that should be compared
 * @param maxGroupComparisons: Maximum number of addresses from that group that
 *        should be used for comparison (if there are fewer addresses in the
 *        group, all addresses are used)
 * @param addr: Address that should be compared
 * @param measurementsPerAddress: Number of the measurements that should be
 *        performed for each address comparison (e.g. addr against each address
 *        from the group up to maxGroupComparisons addresses)
 * @return Median of all times returned by measure_time (access time, e.g. as
 *         measured using rdtscp), or 0 if nothing was measured (empty group,
 *         maxGroupComparisons == 0 or measurementsPerAddress == 0). Exits the
 *         process if memory cannot be allocated.
 ******************************************************************************/
u_int64_t timeAddressToGroup(addressGroup *aGroup, u_int64_t maxGroupComparisons, volatile char *addr, u_int64_t measurementsPerAddress) {
  u_int64_t nCompared = aGroup->nAddresses < maxGroupComparisons ? aGroup->nAddresses : maxGroupComparisons;
  u_int64_t nTimes = nCompared * measurementsPerAddress;
  // Nothing to measure: return early instead of reading times[0] from a
  // zero-sized buffer.
  if(nTimes == 0) {
    return 0;
  }

  u_int64_t *times = malloc(sizeof(u_int64_t) * nTimes);
  u_int64_t *indices = getRandomIndices(aGroup->nAddresses, maxGroupComparisons);
  if(times == NULL || indices == NULL) {
    fprintf(stderr, "timeAddressToGroup: out of memory\n");
    exit(EXIT_FAILURE);
  }

  // Samples are stored contiguously, so the first realGroupComparisons
  // entries of times are exactly the valid measurements.
  u_int64_t realGroupComparisons = 0;
  for(u_int64_t idx = 0; idx < nCompared; idx++) {
    u_int64_t addressIdx = indices[idx];
    if(addressIdx >= aGroup->nAddresses) {
      continue; // defensive: never index outside the group
    }
    for(u_int64_t nMeasurement = 0; nMeasurement < measurementsPerAddress; nMeasurement++) {
      times[realGroupComparisons++] = measure_time(addr, aGroup->addresses[addressIdx]);
    }
  }

  free(indices);

  u_int64_t time = 0;
  if(realGroupComparisons > 0) {
    qsort(times, realGroupComparisons, sizeof(u_int64_t), compareInt);
    time = times[realGroupComparisons/2];
  }
  free(times);

  return time;
}

/*******************************************************************************
 * getGroupIdxForAddress takes a list of address groups and an address, and
 * compares the address to the address groups. If the address group with the
 * longest time has a bigger access time than the threshold, that group is
 * considered a match. Otherwise, no match was found.
 *
 * @param aGroups: List of address groups the address should be compared to
 *        (NULL entries are skipped)
 * @param maxGroupComparisons: Maximum addresses that are compared for each
 *        group
 * @param threshold: Threshold value between row hit and row conflict
 * @param addr: Address that should be compared with the groups
 * @param measurementsPerAddress: Number of measurements that is done for each
 *        address
 * @param flags: Additional flags (FL_DEBUG prints every candidate group)
 * @param mustGroup: If set to 1, the group with the biggest access time is used
 *        ignoring the threshold value. On ties, the lowest group index wins.
 * @return Index of the group that matches or -1 if no group matches.
 ******************************************************************************/
int64_t getGroupIdxForAddress(addressGroups *aGroups, u_int64_t maxGroupComparisons, u_int64_t threshold, volatile char *addr, u_int64_t measurementsPerAddress, flag_t flags, int mustGroup) {
  int64_t retGroupIdx = -1;
  u_int64_t maxTime = 0;
  for(u_int64_t groupIdx = 0; groupIdx < aGroups->nAddressGroups; groupIdx++) {
    if(aGroups->addressGroups[groupIdx] == NULL) {
      continue;
    }

    u_int64_t time = timeAddressToGroup(aGroups->addressGroups[groupIdx], maxGroupComparisons, addr, measurementsPerAddress);
    if(time <= threshold && mustGroup != 1) {
      continue;
    }

    if(flags & FL_DEBUG) {
      printf("[DEBUG]: Address 0x%" PRIx64 " matches group %" PRIu64 " with a time of %" PRIu64 "\n", (u_int64_t)(uintptr_t)addr, groupIdx, time);
    }
    // Checking retGroupIdx too means that with mustGroup a group is still
    // selected when every measured time is 0.
    if(retGroupIdx < 0 || time > maxTime) {
      retGroupIdx = (int64_t)groupIdx;
      maxTime = time;
    }
  }

  return retGroupIdx;
}

/*******************************************************************************
 * groupAddresses takes a list of addrInfo structs and groups them according to
 * the submitted parameters. The grouping is initial, i.e. there are no existing
 * groups in the beginning. Each address is added to the group selected by
 * getGroupIdxForAddress. If no group exceeds the threshold, a new group is
 * started for it.
 *
 * @param aInfo: List of addrInfo structs that should be grouped
 * @param len: Number of addrInfo elements in the list
 * @param threshold: Threshold between row hit and row conflict
 * @param maxGroupComparisons: Maximum number of addresses compared for each
 *        group
 * @param measurementsPerAddress: Number of measurements done for each address
 *        comparison
 * @param flags: Additional flags
 * @return Address Groups that contain the submitted addresses
 ******************************************************************************/
addressGroups *groupAddresses(addrInfo ** aInfo, u_int64_t len, u_int64_t threshold, u_int64_t maxGroupComparisons, u_int64_t measurementsPerAddress, flag_t flags) {
  addressGroups *aGroups = initializeAddressGroups();

  u_int64_t pageSize = getPageSize();
  // Never force a match: an address below the threshold starts its own group.
  const int mustGroup = 0;

  for(u_int64_t pageIdx = 0; pageIdx < len; pageIdx++) {
    printf("Comparing page %" PRIu64 " of %" PRIu64 "\r", pageIdx + 1, len);
    fflush(stdout);

    volatile char *addr = (volatile char *)(aInfo[pageIdx]->hva);
    int64_t groupIdx = getGroupIdxForAddress(aGroups, maxGroupComparisons, threshold, addr, measurementsPerAddress, flags, mustGroup);
    if(groupIdx < 0) {
      // No match: append an empty (NULL) slot at the end; the group itself is
      // then built there by addAddressToAddressGroup below.
      aGroups = addAddressGroupToAddressGroups(aGroups, NULL);
      groupIdx = (int64_t)(aGroups->nAddressGroups - 1);
    }
    aGroups->addressGroups[groupIdx] = addAddressToAddressGroup(aGroups->addressGroups[groupIdx], addr, pageSize);
  }
  printf("\n");

  return aGroups;
}

/*******************************************************************************
 * removeSmallGroups takes address groups and removes all groups that contain
 * fewer than median / fractionOfMedian addresses. Here median is the (upper)
 * median of the group sizes, and the division is integer division.
 *
 * @param aGroups Address Groups that should be reduced
 * @param fractionOfMedian: all groups with less than 1/fractionOfMedian
 *        addresses of the median number of addresses over all groups are
 *        removed. A value of 0 leaves the groups unchanged.
 * @param freeAddressesInRemovedGroups: If set to 1, the addresses in the groups
 *        that are removed are freed. Otherwise, they are not modified.
 * @return Reduced list of address groups. Exits the process if memory cannot
 *         be allocated.
 ******************************************************************************/
addressGroups *removeSmallGroups(addressGroups *aGroups, u_int64_t fractionOfMedian, u_int64_t freeAddressesInRemovedGroups) {
  // Without groups there is no median, and a zero fraction would divide by zero.
  if(aGroups->nAddressGroups == 0 || fractionOfMedian == 0) {
    return aGroups;
  }

  u_int64_t *sizes = malloc(sizeof(u_int64_t) * aGroups->nAddressGroups);
  if(sizes == NULL) {
    fprintf(stderr, "removeSmallGroups: out of memory\n");
    exit(EXIT_FAILURE);
  }
  for(u_int64_t i = 0; i < aGroups->nAddressGroups; i++) {
    sizes[i] = aGroups->addressGroups[i]->nAddresses;
  }

  qsort(sizes, aGroups->nAddressGroups, sizeof(u_int64_t), compareInt);

  u_int64_t minSize = sizes[aGroups->nAddressGroups/2] / fractionOfMedian;
  free(sizes);

  u_int64_t i = 0;
  while(i < aGroups->nAddressGroups) {
    if(aGroups->addressGroups[i]->nAddresses < minSize) {
      // Assumes the removal moves later groups down one slot, so index i is
      // examined again.
      aGroups = removeAddressGroupFromAddressGroups(aGroups, i, 1, freeAddressesInRemovedGroups);
    } else {
      i++;
    }
  }
  return aGroups;
}

/*******************************************************************************
 * regroupAddresses takes address groups and removes one group after another.
 * Each address from the currently removed group is grouped in the other groups.
 * If there is no match, a new group is created. When the groups are stable,
 * the resulting address groups should contain the same addresses as the
 * submitted ones. However, the indices can differ due to the implementation.
 * Only the groups present on entry are processed; NULL slots are skipped.
 *
 * @param aGroups: List of address groups that should be regrouped
 * @param threshold: Threshold between row hit and row conflict
 * @param maxGroupComparisons: Maximum number of addresses taken from the
 *        existing address groups for comparison
 * @param measurementsPerAddress: Number of measurements for each address
 *        combination
 * @param flags: Various flags (passed on to getGroupIdxForAddress)
 * @return Regrouped address groups, as returned by cleanAddressGroups
 ******************************************************************************/
addressGroups *regroupAddresses(addressGroups *aGroups, u_int64_t threshold, u_int64_t maxGroupComparisons, u_int64_t measurementsPerAddress, flag_t flags) {
  u_int64_t pageSize = getPageSize();
  // Groups created while regrouping are appended after this count and are not
  // regrouped themselves.
  u_int64_t nAddressGroups = aGroups->nAddressGroups;
  const int mustGroup = 0;

  for(u_int64_t groupIdx = 0; groupIdx < nAddressGroups; groupIdx++) {
    addressGroup *aGroup = aGroups->addressGroups[groupIdx];
    if(aGroup == NULL) {
      continue;
    }

    printf("Regrouping %" PRIu64 " of %" PRIu64 " (%" PRIu64 " addresses)\r", groupIdx + 1, nAddressGroups, aGroup->nAddresses);
    fflush(stdout);

    // Detach the group so its addresses are only compared against the other
    // groups (presumably leaves slot groupIdx empty without shifting others).
    removeAddressGroupFromAddressGroupsWithoutRelocate(aGroups, groupIdx, 0, 0);

    for(u_int64_t addressIdx = 0; addressIdx < aGroup->nAddresses; addressIdx++) {
      volatile char *addr = aGroup->addresses[addressIdx];
      int64_t iGroupIdx = getGroupIdxForAddress(aGroups, maxGroupComparisons, threshold, addr, measurementsPerAddress, flags, mustGroup);
      if(iGroupIdx < 0) {
        // No match: append an empty (NULL) slot and build a new group there.
        aGroups = addAddressGroupToAddressGroups(aGroups, NULL);
        iGroupIdx = (int64_t)(aGroups->nAddressGroups - 1);
      }
      aGroups->addressGroups[iGroupIdx] = addAddressToAddressGroup(aGroups->addressGroups[iGroupIdx], addr, pageSize);
    }

    // The addresses now live in other groups; the second argument 0 presumably
    // releases only the group container (confirm in group.c).
    aGroup = freeAddressGroup(aGroup, 0);
  }
  printf("\n");

  aGroups = cleanAddressGroups(aGroups);

  return aGroups;
}

/*******************************************************************************
 * verifyGroups takes a list of address groups and verifies if they are correct
 * by measuring the timings again.
 *
 * For every address, the median time against maxGroupComparisons other
 * addresses of its own group is measured. Small groups reuse addresses to fill
 * all slots. If that median is below the threshold, the address is reported,
 * and the groups it would match instead are printed. Results are only printed;
 * the groups are not modified.
 *
 * @param aGroups: List of address groups that should be evaluated
 * @param threshold: Threshold between row hit and row conflict
 * @param maxGroupComparisons: Maximum number of addresses taken from the
 *        address groups for comparison (0 verifies nothing)
 ******************************************************************************/
void verifyGroups(addressGroups *aGroups, u_int64_t threshold, u_int64_t maxGroupComparisons) {
  // Without comparisons there is no median; also avoids reading times[0] from
  // a zero-sized buffer.
  if(maxGroupComparisons == 0) {
    return;
  }

  u_int64_t *times = malloc(sizeof(u_int64_t) * maxGroupComparisons);
  if(times == NULL) {
    fprintf(stderr, "verifyGroups: out of memory\n");
    exit(EXIT_FAILURE);
  }

  for(u_int64_t groupIdx = 0; groupIdx < aGroups->nAddressGroups; groupIdx++) {
    if(aGroups->addressGroups[groupIdx] == NULL) {
      continue;
    }

    u_int64_t nAddresses = aGroups->addressGroups[groupIdx]->nAddresses;
    if(nAddresses < 2) {
      printf("Ignoring address group %" PRIu64 " because there are only %" PRIu64 " addresses.\n", groupIdx, nAddresses);
      continue;
    }

    for(u_int64_t addressIdx = 0; addressIdx < nAddresses; addressIdx++) {
      volatile char *addr = aGroups->addressGroups[groupIdx]->addresses[addressIdx];

      // Walk the following addresses (wrapping around) and skip addr itself.
      u_int64_t offset = 0;
      for(u_int64_t cnt = 0; cnt < maxGroupComparisons; cnt++) {
        u_int64_t idx = (addressIdx + cnt + offset) % nAddresses;
        if(idx == addressIdx) {
          offset++;
          idx = (addressIdx + cnt + offset) % nAddresses;
        }
        times[cnt] = measure_time(addr, aGroups->addressGroups[groupIdx]->addresses[idx]);
      }

      qsort(times, maxGroupComparisons, sizeof(u_int64_t), compareInt);

      u_int64_t ownGroupTime = times[maxGroupComparisons/2];
      if(ownGroupTime >= threshold) {
        continue;
      }

      printf("Address %p (%" PRIu64 ") does not match in Group %" PRIu64 " anymore. Time: %" PRIu64 ", Threshold: %" PRIu64, (void *)addr, addressIdx, groupIdx, ownGroupTime, threshold);
      for(u_int64_t cnt = 0; cnt < maxGroupComparisons; cnt++) {
        printf(" %" PRIu64, times[cnt]);
      }
      printf("\n");

      int foundGroup = 0;
      for(u_int64_t iGroupIdx = 0; iGroupIdx < aGroups->nAddressGroups; iGroupIdx++) {
        if(aGroups->addressGroups[iGroupIdx] == NULL) {
          continue;
        }

        // Compare against the first addresses of the candidate group.
        u_int64_t nOther = aGroups->addressGroups[iGroupIdx]->nAddresses;
        u_int64_t realGroupComparisons = nOther < maxGroupComparisons ? nOther : maxGroupComparisons;
        if(realGroupComparisons == 0) {
          continue; // an empty group can never exceed the threshold
        }
        for(u_int64_t iAddressIdx = 0; iAddressIdx < realGroupComparisons; iAddressIdx++) {
          times[iAddressIdx] = measure_time(addr, aGroups->addressGroups[iGroupIdx]->addresses[iAddressIdx]);
        }

        qsort(times, realGroupComparisons, sizeof(u_int64_t), compareInt);

        u_int64_t otherGroupTime = times[realGroupComparisons/2];
        if(otherGroupTime > threshold) {
          printf("Address would match in Group %" PRIu64 ".\n", iGroupIdx);
          foundGroup = 1;
        }
      }

      if(foundGroup == 0) {
        printf("Address does not match any Group.\n");
      }
    }
  }

  free(times);
}
