////////////////////////////////////////////////////////////////////////////////
//
/// LSU EE 4755 Fall 2018 Homework 5 -- SOLUTION
//

 /// Assignment  https://www.ece.lsu.edu/koppel/v/2018/hw05.pdf


`default_nettype none
//////////////////////////////////////////////////////////////////////////////
///  Problem 1
//
 /// Complete batcher_sort so that it recursively implements a Batcher
 /// sorter using a merge module.
//
//     [✔] Assume that n is a power of 2.
//     [✔] Use implicit and explicit structural code only.
//     [✔] Use recursion as described in the handout.
//     [✔] Use behav_merge initially and when it's done, batcher_merge.
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Use SimVision for debugging.
//     [✔] Modify testbench to facilitate solution ..
//         .. but code must pass original testbench.
//
//     [✔] As always, code should be efficient and clearly written.

module batcher_sort
  #( int n = 4, int w = 8 )
   ( output uwire [w-1:0] x[n], input uwire [w-1:0] a[n] );

   // Recursive Batcher sorter.
   //
   // Sorts the n elements of a into non-decreasing order on x.
   // Expects n to be a power of 2.
   //
   // The lower and upper halves of a are each sorted by a half-size
   // batcher_sort. Then a batcher_merge combines the two sorted halves.
   // Recursion bottoms out at a single element, which is already sorted.

   if ( n > 1 ) begin : recurse

      localparam int half = n / 2;

      // Outputs of the two half-size sorters.
      uwire [w-1:0] sorted_lo[half], sorted_hi[half];

      batcher_sort #( .n(half), .w(w) )
        sort_lo( .x(sorted_lo), .a(a[0:half-1]) );

      batcher_sort #( .n(half), .w(w) )
        sort_hi( .x(sorted_hi), .a(a[half:n-1]) );

      // Merge the two sorted halves into the full sorted output.
      batcher_merge #( .n(half), .w(w) )
        combine( .x(x), .a(sorted_lo), .b(sorted_hi) );

   end else begin : leaf

      // A single element is trivially sorted.
      assign x = a;

   end

endmodule




module behav_merge
  #( int n = 4, int w = 8 )
   ( output logic [w-1:0] x[2*n], input uwire [w-1:0] a[n], b[n] );

   // Behavioral two-way merge, used as a reference.
   //
   // Expects a and b to each be sorted (non-decreasing). Writes all 2n
   // elements to x by repeatedly taking the smaller head element.
   // Ties go to a, because the comparison is a <= b.

   always_comb begin : merge_walk

      int ai, bi;   // Next unconsumed position in a and in b.

      ai = 0;
      bi = 0;

      for ( int k = 0;  k < 2*n;  k++ ) begin
         // Take from a when b is exhausted, or when a still has elements
         // and its head is no larger than b's head. Short-circuit
         // evaluation keeps a[ai] from being read once a is exhausted.
         if ( bi == n || ( ai < n && a[ai] <= b[bi] ) ) begin
            x[k] = a[ai];
            ai++;
         end else begin
            x[k] = b[bi];
            bi++;
         end
      end

   end

endmodule


//////////////////////////////////////////////////////////////////////////////
///  Problem 2
//
 /// Modify batcher_merge so that it recursively implements a Batcher
 /// odd/even merge module.
//
//     [✔] Recursively implement a Batcher Odd / Even merge module.
//
//     [✔] Assume that n is a power of 2.
//     [✔] Use sort2 to swap the values.
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Use SimVision for debugging.
//     [✔] Modify testbench to facilitate solution ..
//         .. but code must pass original testbench.

module batcher_merge
  #( int n = 4, int w = 8 )
   ( output uwire [w-1:0] x[2*n], input uwire [w-1:0] a[n], b[n] );

   // Recursive Batcher odd/even merge.
   //
   // Expects a and b to each be sorted and n to be a power of 2.
   // Writes the 2n elements of a and b to x in sorted order.
   //
   // Two half-size merges run on interleaved subsequences:
   //    even(a) with odd(b), and odd(a) with even(b).
   // Output k of the first merge and output k of the second are then
   // compare-exchanged into x[2k], x[2k+1].

   // Outputs of the two sub-merges, or the raw inputs when n == 1.
   uwire [w-1:0] lo_merged[n], hi_merged[n];

   if ( n > 1 ) begin : recurse

      localparam int half = n / 2;

      // Even- and odd-indexed elements of each sorted input. Each of
      // these subsequences is itself sorted.
      uwire [w-1:0] a_even[half], a_odd[half], b_even[half], b_odd[half];

      for ( genvar k = 0; k < half; k++ ) begin : split
         assign a_even[k] = a[2*k];
         assign a_odd[k]  = a[2*k+1];
         assign b_even[k] = b[2*k];
         assign b_odd[k]  = b[2*k+1];
      end

      batcher_merge #( .n(half), .w(w) )
        merge_lo( .x(lo_merged), .a(a_even), .b(b_odd) );

      batcher_merge #( .n(half), .w(w) )
        merge_hi( .x(hi_merged), .a(a_odd), .b(b_even) );

      // With this pairing, the two smallest elements overall are
      // lo_merged[0] and hi_merged[0], in some order.

   end else begin : base

      // One element per input: nothing to merge recursively.
      assign lo_merged[0] = a[0];
      assign hi_merged[0] = b[0];

   end

   // Final compare-exchange stage produces the merged output.
   for ( genvar k = 0; k < n; k++ ) begin : finish
      sort2 #( .w(w) )
        cx( .x0(x[2*k]), .x1(x[2*k+1]), .a0(lo_merged[k]), .a1(hi_merged[k]) );
   end

endmodule


// Correctly functioning 2-input sorter.
module sort2
  #( int w = 8 )( output uwire [w-1:0] x0, x1, input uwire [w-1:0] a0, a1 );
   // x0 gets the smaller input and x1 the larger. When a0 == a1,
   // the inputs pass through in their original order.
   uwire in_order = a0 <= a1;
   assign x0 = in_order ? a0 : a1;
   assign x1 = in_order ? a1 : a0;
endmodule


//////////////////////////////////////////////////////////////////////////////
/// Testbench Code
//
//  The code below instantiates some of the modules above,
//  provides test inputs, and verifies the outputs.
//
//  The testbench may be modified to facilitate your solution. For
//  example, one might modify the testbench so that the first tests it
//  performs are those that make it easier to pinpoint the problem,
//  such as test inputs that are all 0's or all 1's.
//
//  Of course, the removal of tests which your module fails is not a
//  method of fixing a broken module.  The TA-bot will test your
//  code using a fresh copy of the testbench, not the one below.



// cadence translate_off

// Test wrapper around one module under test (MUT).
//
// The MUT is selected by modnum:
//   0: batcher_merge, given the lower half of a as its first input and
//      the upper half as its second.
//   1: batcher_sort, given all of a.
//   2, 3 (or any other value): placeholders. No module is instantiated,
//      so x is left undriven.
//
// The MUT's n outputs are routed to xlong[mut_idx][0:n-1]. Generate
// block A also defines localparams name and merge. The testbench reads
// these hierarchically to label results and to decide whether to
// pre-sort each half of the input.
module sortx
  #( int n = 5,
     int modnum = 0,
     int mut_idx = 0,
     int w = 10,
     int max_muts = 3,
     int max_n = n)
   ( output uwire [w-1:0] xlong[max_muts][max_n],
     input uwire [w-1:0] a[n] );

   // Split the input into a lower half (nlo elements) and an upper
   // half (nhi elements).
   // NOTE(review): for odd n, nhi == nlo + 1. ahi is then larger than
   // the nlo-element b port of batcher_merge #(nlo) below. The testbench
   // only uses powers of 2, but the default n = 5 is odd.
   localparam int nlo = n/2;
   localparam int nhi = n - nlo;
   uwire [w-1:0] x[n];
   assign xlong[mut_idx][0:n-1] = x;
   uwire [w-1:0] alo[nlo] = a[0:nlo-1];
   uwire [w-1:0] ahi[nhi] = a[nlo:n-1];

   if ( modnum == 0 ) begin:A

      localparam string name = "Batcher Merge";
      localparam bit merge = 1;
      batcher_merge #(nlo,w) s(x,alo,ahi);

   end else if ( modnum == 1 ) begin:A

      localparam string name = "Batcher Sort";
      localparam bit merge = 0;
      batcher_sort #(n,w) s(x,a);

   end else if ( modnum == 2 ) begin:A

      // Placeholder: no module instantiated; x stays undriven.
      localparam string name = "sort3";
      localparam bit merge = 0;

   end else begin:A

      // Placeholder: no module instantiated; x stays undriven.
      localparam string name = "sort4";
      localparam bit merge = 0;

   end

endmodule


module testbench;

   // Self-checking testbench. It instantiates batcher_merge (modnum 0)
   // and batcher_sort (modnum 1) for n = 2, 4, 8, 16, 32. Each is driven
   // with n_tests random inputs, and its outputs are compared against a
   // software sort of the same inputs.

   localparam int w = 8;
   localparam int n_tests = 10;
   localparam int max_n = 32;
   localparam int max_muts = 12;

   logic [w-1:0] a[max_n];
   uwire [w-1:0] x[max_muts][max_n];

   // One entry per instantiated module under test, registered at time 0
   // by the generate loop below.
   typedef struct { int idx; string name; bit merge; int n; } Info;
   Info pi[$];

   // idx = i*6 + nlg ranges from 1 to 11, so it fits within max_muts.
   for ( genvar i=0; i<2; i++ ) begin
      for ( genvar nlg = 1; nlg < 6; nlg++ ) begin
         localparam int n = 1 << nlg;
         localparam int idx = i * 6 + nlg;
         sortx #(n,i,idx,w,max_muts,max_n) s(x,a[0:n-1]);
         initial pi.push_back( '{ idx, s.A.name, s.A.merge, s.n } );
      end
   end

   initial begin

      automatic int g_elt_err_count = 0;
      automatic int g_sort_err_count = 0;

      $write("Starting testbench.\n");

      // The registration initial blocks above also start at time 0, and
      // the order in which initial blocks run is not defined. Yield once
      // (#0 moves this process to the inactive region) so that every
      // push_back has completed before pi is read. Otherwise pi could
      // still be empty and all tests would be silently skipped.
      #0;

      if ( pi.size() == 0 )
        $write("Error: no modules were registered for testing.\n");

      // Fill the input with a recognizable pattern. Every element that is
      // used should be overwritten. If one is not, the pattern stands out
      // when the value is printed in hex.
      for ( int e = 0; e < max_n; e++ ) a[e] = 'haaaaaaaa;

      foreach ( pi[idx] ) begin

         automatic Info p = pi[idx];
         automatic string mut = p.name;
         automatic int n = p.n;
         automatic int s_size = n;
         automatic int nlo = n/2;
         automatic int nhi = n - nlo;
         automatic logic [w-1:0] shadow[] = new[s_size];
         automatic logic [w-1:0] alo[] = new[nlo];
         automatic logic [w-1:0] ahi[] = new[nhi];
         automatic int this_sort_err_count = 0;

         for ( int i = 0;  i < n_tests;  i++ ) begin

            automatic int this_elt_err_count = 0;

            // To make sure that the comparison is correct restrict the
            // key to a subset of bits.
            automatic int n_bits = {$random} % w + 1;
            automatic int mask = ( 1 << n_bits ) - 1;

            // Randomly permute the low w bits of mask, so the retained
            // key bits are scattered across the word. The loop index is j
            // so it does not shadow the test index i.
            for ( int j=0; j<w; j++ ) begin
               automatic int b = {$random} % w;
               {mask[b],mask[j]} = {mask[j],mask[b]};
            end

            for ( int e = 0; e < s_size; e++ )
              begin
                 a[e] = {$random} & mask;
                 shadow[e] = a[e];
                 if ( e < nlo ) alo[e] = a[e]; else ahi[e-nlo] = a[e];
              end

            // A merge module requires each half of its input to be sorted.
            if ( p.merge ) begin
               alo.sort();
               ahi.sort();
               for ( int e=0; e<nlo; e++ ) a[e] = alo[e];
               for ( int e=nlo; e<n; e++ ) a[e] = ahi[e-nlo];
            end

            #1;

            shadow.sort();

            for ( int e = 0; e < s_size; e++ ) begin
               automatic logic [w-1:0] elt = x[p.idx][e];
               if ( shadow[e] === elt ) continue;
               this_elt_err_count++;
               g_elt_err_count++;
               // Count every mismatch, but print only the first five.
               if ( g_elt_err_count > 5 ) continue;
               $write
                 ("Mod %s, n=%0d, sort %2d idx %2d, wrong elt %d != %d (correct)\n",
                  mut, n, i, e, elt, shadow[e]);
            end

            if ( this_elt_err_count ) this_sort_err_count++;

         end

         if ( this_sort_err_count ) g_sort_err_count++;

         $write("Tests for %s (idx %0d)  n=%0d done, errors in %0d of %0d sorts.\n",
                mut, p.idx, n,  this_sort_err_count, n_tests);

      end

      $write("Done with all tests, errors on %0d sorters.\n",
             g_sort_err_count);

   end

endmodule

// cadence translate_on