/* LSU EE 7700-2 Fall 2003
   Classroom Example Program

   Parallel Radix Sort using pthreads

   To compile on Solaris: cc -mt rsortp.c -o rsortp -lpthread -lrt -fast

   See "sorter" function for radix sort.

 */

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <malloc.h>
#include <pthread.h>
#include <strings.h>

/*
 * time_fp: current time in seconds, as a double, for timing intervals.
 *
 * Uses Solaris' CLOCK_HIGHRES when it exists, otherwise the POSIX
 * monotonic clock, otherwise C11 timespec_get (wall-clock time).  Only the
 * difference between two calls is meaningful.  If the clock cannot be read
 * the result is 0.0 rather than garbage.
 */
double
time_fp(void)
{
  struct timespec tp = {0};

#if defined(CLOCK_HIGHRES)
  clock_gettime(CLOCK_HIGHRES,&tp);
#elif defined(CLOCK_MONOTONIC)
  clock_gettime(CLOCK_MONOTONIC,&tp);
#else
  timespec_get(&tp,TIME_UTC);
#endif

  return ((double)tp.tv_sec)+((double)tp.tv_nsec) * 0.000000001;
}

/*
 * One of the two alternating halves of a Barrier_Info.  Threads entering a
 * barrier episode count themselves in num_in_barrier while holding mutex,
 * and the non-final arrivals sleep on cond until the last one broadcasts.
 */
typedef struct _barrier_lock_info {
  pthread_cond_t cond;       /* Broadcast by the last thread to arrive. */
  pthread_mutex_t mutex;     /* Guards num_in_barrier. */
  int num_in_barrier;        /* Threads currently inside this half. */
} Barrier_Lock_Info;

/*
 * Reusable barrier for a fixed group of nprocs threads.  Consecutive
 * barrier episodes alternate between the two barrier_lock_info entries, so
 * threads hurrying into the next episode do not touch the counter that
 * slower threads are still decrementing on their way out of this one.
 */
typedef struct barrier_info {
  Barrier_Lock_Info barrier_lock_info[2];  /* The two alternating halves. */
  int nprocs;                              /* Threads that must arrive. */
  int umm;                                 /* Index (0 or 1) of the half the next episode uses. */
} Barrier_Info;
typedef struct app_data App_Data;

/*
 * Per-thread state handed to sorter().  The thread owns the slice
 * [start, stop) of the shared array; bins is its digit histogram, which
 * thread 0 later overwrites with starting offsets.
 */
typedef struct thread_data {
  pthread_t tid;      /* Handle filled in by pthread_create. */
  int start;          /* First index of this thread's slice. */
  int stop;           /* One past the last index of the slice. */
  int id;             /* Thread number; 0 also computes the prefix sums. */
  int *bins;          /* Per-digit counts / offsets, allocated by sorter(). */
  App_Data *app;      /* State shared by all threads. */
} Thread_Data;

/*
 * State shared by every sorter thread.
 */
struct app_data {
  int *array;         /* Data to sort; holds the sorted result afterwards. */
  int *cpy;           /* Scratch buffer of the same length. */
  int radix_lg;       /* Bits per radix digit. */
  int bin_count;      /* Digit values per pass: 1 << radix_lg. */
  int mask;           /* bin_count - 1, extracts one digit. */
  int amt;            /* Number of elements in array and cpy. */
  int nprocs;         /* Number of sorter threads. */
  Barrier_Info bi;    /* Barrier used between sorting phases. */
  Thread_Data *td;    /* One entry per thread. */
};


/* Set nonzero to make main() print extra diagnostics: the result of each
   pthread_join and the first few elements of the sorted array. */
#define debug 0

/*
 * barrier_init: prepare BI for use by NPROCS threads.
 *
 * Both halves get a zero arrival count plus a freshly initialized condition
 * variable and mutex, and umm (the half the next barrier() call uses)
 * starts at 0.  Any pthread initialization failure is fatal: a message is
 * printed to stderr and the process exits with status 1.
 */
void
barrier_init(Barrier_Info *bi,int nprocs)
{
  const int halves = (int)( sizeof bi->barrier_lock_info
                            / sizeof bi->barrier_lock_info[0] );
  int half;

  bi->nprocs = nprocs;
  bi->umm = 0;

  for(half = 0; half < halves; half++)
    {
      Barrier_Lock_Info *lock = bi->barrier_lock_info + half;
      int err;

      lock->num_in_barrier = 0;

      err = pthread_cond_init(&lock->cond,NULL);
      if( err != 0 )
        {
          fprintf(stderr,"Could not initialize cond (%d)\n",err);
          exit(1);
        }

      err = pthread_mutex_init(&lock->mutex,NULL);
      if( err != 0 )
        {
          fprintf(stderr,"Could not initialize mutex (%d)\n",err);
          exit(1);
        }
    }
}

/*
 * barrier: block until all BI->nprocs threads have called barrier(BI).
 *
 * Each episode uses the half selected by BI->umm.  Arriving threads count
 * themselves in that half.  The last arrival flips BI->umm to the other
 * half and wakes the rest, so the next episode cannot disturb threads still
 * leaving this one.  Every thread, the last one included, decrements the
 * count on its way out, which leaves it at zero for the next episode that
 * reuses this half.
 *
 * Waiters re-test BI->umm in a loop because pthread_cond_wait may return
 * spuriously.  Without the loop, a spurious wakeup would let a thread leave
 * before everyone had arrived.
 */
void
barrier(Barrier_Info *bi)
{
  /* Read without the lock: every caller has already synchronized with the
     last change to umm, either by making it or by waking from the previous
     episode under the same mutex. */
  int half = bi->umm;
  Barrier_Lock_Info *bli = &bi->barrier_lock_info[half];

  pthread_mutex_lock(&bli->mutex);
  if( ++bli->num_in_barrier == bi->nprocs )
    {
      /* Last arrival: release everyone waiting on this half. */
      bi->umm = 1 - half;
      bli->num_in_barrier--;
      pthread_mutex_unlock(&bli->mutex);
      pthread_cond_broadcast(&bli->cond);
    }
  else
    {
      /* umm cannot return to HALF until this thread reaches the next
         barrier, so this predicate is stable once it becomes false. */
      while( bi->umm == half )
        pthread_cond_wait(&bli->cond,&bli->mutex);
      bli->num_in_barrier--;
      pthread_mutex_unlock(&bli->mutex);
    }
}


/*
 * sorter: thread body of the parallel LSD radix sort.
 *
 * ARG points at this thread's Thread_Data; the thread owns the slice
 * [td->start, td->stop) of app->array.  For each radix_lg-bit digit,
 * starting at the least significant, it:
 *   1. counts the digit values in its slice into td->bins,
 *   2. waits while thread 0 turns every thread's counts into global starting
 *      offsets, ordered by digit and then by thread id (with slices assigned
 *      in increasing id order, as main() does, each pass is stable),
 *   3. scatters its slice into the other buffer at those offsets.
 * app->array and app->cpy swap roles after every pass.  If the result ends
 * up in app->cpy, each thread copies its own slice back.  This relies on
 * the slices partitioning [0, app->amt).
 *
 * Keys are compared as signed ints: each key is viewed as unsigned with its
 * sign bit flipped, which maps signed order onto unsigned order and avoids
 * right-shifting negative values.
 *
 * All app->nprocs threads must run this together.  Exits the process if
 * the bin array cannot be allocated.  Returns NULL.
 */
void*
sorter(void *arg)
{
  Thread_Data *td = (Thread_Data*) arg;
  App_Data *app = td->app;

  int *from = app->array;
  int *to = app->cpy;
  int start = td->start;
  int stop = td->stop;
  unsigned mask = (unsigned) app->mask;
  int radix_lg = app->radix_lg;
  int bin_count = app->bin_count;
  int nprocs = app->nprocs;
  int key_bits = (int)( sizeof(int) * CHAR_BIT );
  unsigned sign_flip = (unsigned) INT_MIN;
  size_t bins_size = (size_t) bin_count * sizeof(int);
  int *bins = td->bins = (int*) malloc( bins_size );

  int shift;

  if( bins == NULL )
    {
      fprintf(stderr,"Thread %d could not allocate %d bins.\n",
              td->id, bin_count);
      exit(1);
    }

  for(shift=0; shift<key_bits; shift+=radix_lg)
    {
      int i;
      int *swap;

      memset(bins,0,bins_size);

      /*
       *  Compute Histogram of Digit Values
       */

      for(i=start; i<stop; i++)
        {
          unsigned key = (unsigned) from[i] ^ sign_flip;
          bins[ ( key >> shift ) & mask ]++;
        }

      /* Every histogram must be complete before thread 0 reads them. */
      barrier(&app->bi);

      /*
       *  Compute Prefix Sum
       */

      if( td->id == 0 )
        {
          /* Note, this could be parallelized. */
          int accumulator = 0;
          int d, p;
          for(d=0; d<bin_count; d++)
            for(p=0; p<nprocs; p++)
              {
                int this_bin = app->td[p].bins[d];
                app->td[p].bins[d] = accumulator;
                accumulator += this_bin;
              }
        }

      /* Offsets must be final before any thread scatters. */
      barrier(&app->bi);

      /*
       *  Permute Array Elements
       */

      for(i=start; i<stop; i++)
        {
          unsigned key = (unsigned) from[i] ^ sign_flip;
          to[ bins[ ( key >> shift ) & mask ]++ ] = from[i];
        }

      swap = to;  to = from;  from = swap;

      /* All scatters into the new FROM must finish before the next pass
         reads it or overwrites the old buffer. */
      barrier(&app->bi);
    }

  /* The last barrier guarantees the sorted data in FROM is complete, and
     nothing writes it afterwards, so slices can be copied back in parallel. */
  if( from != app->array && stop > start )
    memcpy(app->array + start, from + start,
           (size_t)( stop - start ) * sizeof(app->array[0]));

  /* No thread reads any bins after the final barrier. */
  free(bins);
  td->bins = NULL;

  return NULL;
}

/*
 * check: verify that ARRAY[0 .. SIZE) is in non-decreasing order.
 *
 * On success prints "Array correctly sorted." to stdout.  Otherwise reports
 * the first element that is smaller than its predecessor, along with both
 * values, on stderr.  Arrays with fewer than two elements count as sorted.
 */
void
check(int *array, int size)
{
  int pos = 1;

  while( pos < size && array[pos] >= array[pos-1] )
    pos++;

  if( pos < size )
    fprintf(stderr,"Error, element %d too small: %d %d\n",
            pos, array[pos-1], array[pos] );
  else
    printf("Array correctly sorted.\n");
}



int
main(int argv, char **argc)
{
  int per_child;
  int nprocs = 8;
  int amt;
  int print_amt;
  int i;
  int pos;
  double amtm = ((double)(1 << 20))/1000000;
  double start_time;
  Thread_Data *td;
  App_Data app;
  pthread_attr_t attr;

  app.radix_lg = 4;

  /*
   *  Read Command-Line Arguments
   */

  if( argv > 1 ) nprocs = atoi(argc[1]);
  if( argv > 2 ) amtm = atof(argc[2]);
  if( argv > 3 ) app.radix_lg = atoi(argc[3]);

  per_child = amtm * 1000000 / nprocs;
  amt = per_child * nprocs;
  print_amt = amt < 20 ? amt : 20;

  printf("Running radix sort for %d threads, %d (%d per proc), %d radix_lg.\n",
         nprocs, amt, per_child, app.radix_lg);

  app.nprocs = nprocs;
  app.bin_count = 1 << app.radix_lg;
  app.mask = app.bin_count - 1;

  app.amt = amt;
  app.array = (int*) malloc( amt * sizeof(app.array[0]) );
  app.cpy = (int*) malloc( amt * sizeof(app.array[0]) );

  for(i=0; i<amt; i++) app.array[i] = random();

  start_time = time_fp();

  /*
   *  Initialize and start child threads.
   */

  pthread_attr_init(&attr);
  pthread_attr_setscope(&attr, PTHREAD_SCOPE_SYSTEM);
  app.td = td = (Thread_Data*) malloc( sizeof(*td) * nprocs );
  barrier_init(&app.bi,nprocs);

  for(pos=0, i=0; i<nprocs; i++)
    {
      int rv;
      td[i].id = i;
      td[i].app = &app;
      td[i].start = pos;
      td[i].stop = pos += per_child;
      rv = pthread_create(&td[i].tid,&attr,sorter,(void*)&td[i]);
      if( rv )
        {
          fprintf(stderr,"Could not create thread, rv %d.\n",rv);
          exit(1);
        }
    }

  /*
   *  Wait for each thread to finish.
   */

  for(i=0; i<nprocs; i++)
    {
      int status;
      pthread_join( td[i].tid, (void**) &status );
      if( debug )
        printf("Thread %d (%d) returned with status %d.\n",
               td[i].tid, i, status);
    }

  if( debug )
    for(i=0; i<print_amt; i++) printf(" 0x%08x\n",app.array[i]);
  
  {
    double et = time_fp() - start_time;
    printf("Took %0.3f seconds\n",et);
  }

  check(app.array,amt);

  return 0;
}