////////////////////////////////////////////////////////////////////////////////
//
/// LSU EE 4755 Fall 2023 Homework 2 -- SOLUTION
//

 /// Assignment  https://www.ece.lsu.edu/koppel/v/2023/hw02.pdf


`default_nettype none

//////////////////////////////////////////////////////////////////////////////
///  Problem 1
//
  ///  Complete comp_p1 so that it computes (1-b/c)/a. See writeup.
 ///
//
//     [✔] Perform computation in order given by expression (1-b/c)/a.
//     [✔] Only modify comp_p1.
//
//     [✔] Use Chipware modules for floating point arithmetic and conversions.
//     [✔] Do not perform FP arithmetic with procedural code.
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Don't assume any particular parameter values.
//
//     [✔] Pay attention to cost. Don't grossly oversize things.
//     [✔] Code must be written clearly.


module fp_one
  #( int w_exp=5, w_sig=9, w_fp=1+w_exp+w_sig )( output uwire [w_fp-1:0] one );

   // Drives the constant 1.0 in a { sign, biased exponent, significand }
   // FP format. The fields are: sign 0, biased exponent equal to the bias
   // (true exponent 0), and an all-zero significand (only the implicit
   // leading 1 remains). The module is synthesizable.

   localparam int bias = ( 1 << ( w_exp - 1 ) ) - 1;

   uwire [w_exp-1:0] exp_field = (w_exp)'(bias);
   uwire [w_sig-1:0] sig_field = (w_sig)'(0);

   assign one = { 1'b0, exp_field, sig_field };

endmodule

// Rounding-mode encodings passed to the 3-bit rnd ports of the ChipWare
// floating-point units. In the code shown here, only Rnd_to_even and
// Rnd_to_0 are used.
//
// NOTE(review): in the DesignWare/ChipWare rounding-mode table, value 2 is
// round toward +infinity and value 4 is round to nearest (ties up). So
// Rnd_to_plus_if (2, presumably a typo for "inf") and Rnd_to_plus_inf (4)
// look mislabeled. Confirm against the ChipWare documentation before using
// either.
typedef enum logic [2:0]
  { Rnd_to_even = 0, Rnd_to_0 = 1, Rnd_to_plus_if = 2,
    Rnd_to_minus_inf = 3, Rnd_to_plus_inf = 4, Rnd_from_0 = 5 }
    Rnd;

module comp_p1
    #( int w = 5, w_exp = 5, w_sig = 5, wfp = 1 + w_exp + w_sig )
   ( output uwire [wfp-1:0] h,
     input uwire [w-1:0] a, b, c );

   // Computes h = (1 - b/c) / a in floating point. The operations run in
   // the order written: b/c first, then 1 - (b/c), then the division by a.
   // Inputs a, b, c are unsigned integers; h has w_exp exponent bits and
   // w_sig significand bits.

   localparam Rnd rnd = Rnd_to_even;

   // The FP constant 1.0. A single instance feeds the subtractor.
   uwire logic [wfp-1:0] one;
   fp_one #(w_exp,w_sig) o(one);

   /// SOLUTION

   uwire [wfp-1:0] af, bf, cf, boc, numer;
   uwire [7:0] sa, sb, sc, sboc, snumer, sh;   // ChipWare status outputs (unused).

   // Convert the unsigned integer inputs to floating point.
   //
   CW_fp_i2flt #( .sig_width(w_sig), .exp_width(w_exp), .isize(w), .isign(0) )
   coa( .z(af), .a(a), .rnd(rnd), .status(sa) );
   CW_fp_i2flt #( .sig_width(w_sig), .exp_width(w_exp), .isize(w), .isign(0) )
   cob( .z(bf), .a(b), .rnd(rnd), .status(sb) );
   CW_fp_i2flt #( .sig_width(w_sig), .exp_width(w_exp), .isize(w), .isign(0) )
   coc( .z(cf), .a(c), .rnd(rnd), .status(sc) );

   // Compute (1-b/c)/a.
   //
   CW_fp_div #( .sig_width(w_sig), .exp_width(w_exp) )
   d1( .z(boc),   .a(bf),    .b(cf),   .status(sboc),   .rnd(rnd) );  // b/c
   CW_fp_sub #( .sig_width(w_sig), .exp_width(w_exp) )
   d2( .z(numer), .a(one),   .b(boc),  .status(snumer), .rnd(rnd) );  // 1 - b/c
   CW_fp_div #( .sig_width(w_sig), .exp_width(w_exp) )
   d3( .z(h),     .a(numer), .b(af),   .status(sh),     .rnd(rnd) );  // (1-b/c)/a

endmodule


//////////////////////////////////////////////////////////////////////////////
///  Problem 2
//
  ///  Complete comp_p2 so that it computes (1-b/c)/a efficiently. See writeup.
 ///
//
//     [✔] Transform (1-b/c)/a for computation efficiency; implement that.
//     [✔] Only modify comp_p2.
//
//     [✔] Use Chipware modules for floating point arithmetic and conversions.
//     [✔] Do not perform FP arithmetic with procedural code.
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Don't assume any particular parameter values.
//
//     [✔] Pay attention to cost. Don't grossly oversize things.
//     [✔] Pay attention to performance (delay).
//     [✔] Code must be written clearly.


module comp_p2
    #( int w = 5, w_exp = 5, w_sig = 5, wfp = 1 + w_exp + w_sig )
   ( output uwire [wfp-1:0] h,
     input uwire [w-1:0] a, b, c );

   // Rounding mode shared by every ChipWare unit below. It is typed as Rnd,
   // as in comp_p1.
   localparam Rnd rnd = Rnd_to_0;

   /// SOLUTION
   //
   // Summary:
   //
   // - Transform  (1-b/c)/a  into  ( c - b ) / ( ac ).
   //
   // - Use integer arithmetic for  c - b  and for  ac.
   //     Take care to use enough bits in each expression.
   //
   // - Convert  c-b  and  ac  to floating point.
   //
   // - Compute  (c-b)/ac  with one extra bit of precision, then drop it.


   // Perform integer computations.

   // The product of two w-bit operands needs 2w bits. The 2w-bit context
   // widens a and c before the multiply, so the full product is kept.
   //
   localparam int wac = 2 * w;
   uwire [wac-1:0] ac = a * c;

   // c - b is evaluated in w+1 bits (two's complement) because it can be
   // negative. Its range -(2^w-1) .. 2^w-1 fits.
   //
   uwire [w:0] cmb = c - b;

   // Internal FP format: one extra significand bit (suffix _x) for the
   // intermediate results.
   //
   localparam int w_sig_x = w_sig + 1;
   localparam int wfp_x = 1 + w_exp + w_sig_x;

   uwire [wfp_x-1:0] acf, cmbf, h_x;
   uwire [7:0] s_acf, s_cmbf, s_h;   // ChipWare status outputs (unused).

   // Convert to floating point. cmb is signed, ac is unsigned.
   //
   CW_fp_i2flt #( .sig_width(w_sig_x), .exp_width(w_exp), .isize(wac), .isign(0) )
   coa( .z(acf),  .a(ac),  .rnd(rnd), .status(s_acf) );
   CW_fp_i2flt #( .sig_width(w_sig_x), .exp_width(w_exp), .isize(w+1), .isign(1) )
   cob( .z(cmbf), .a(cmb), .rnd(rnd), .status(s_cmbf) );

   // Compute quotient (c-b)/(ac).
   //
   CW_fp_div #( .sig_width(w_sig_x), .exp_width(w_exp) )
   di1( .z(h_x), .a(cmbf), .b(acf), .status(s_h), .rnd(rnd) );

   // Remove the extra significand bit. The format is sign-magnitude, so
   // dropping the LSB truncates the magnitude, consistent with Rnd_to_0.
   //
   assign h = h_x[wfp_x-1:wfp_x-wfp];

endmodule


//////////////////////////////////////////////////////////////////////////////
/// Testbench Code


// cadence translate_off

function automatic int unsigned rand_wid(int max_wid);
   // Returns a random value confined to a randomly chosen number of low-order
   // bits, where the number of bits is between 1 and max_wid. Small values
   // therefore show up far more often than with a plain $random().
   int n_bits;
   int unsigned mask;

   n_bits = 1 + {$random()} % max_wid;
   mask = ( 1 << n_bits ) - 1;
   return {$random()} & mask;
endfunction

function automatic real fabs(real val);
   // Absolute value of a real.
   if ( val < 0 ) return -val;
   return val;
endfunction

function int min( int a, b );
   // Smaller of a and b; a is returned on a tie.
   if ( a > b ) return b;
   return a;
endfunction
function int max( int a, b );
   // Larger of a and b; a is returned on a tie.
   if ( a < b ) return b;
   return a;
endfunction


virtual class conv #(int wexp=6, wsig=10);
   // Convert between real (IEEE 754 binary64) and a custom FP format with
   // wexp exponent bits and wsig significand bits, and measure how far two
   // values in that format are apart. Conversions to the custom format
   // truncate (no rounding).

   localparam int w = 1 + wexp + wsig;            // Width of custom format.
   localparam int w_sig_r = 52;                   // binary64 significand width.
   localparam int w_exp_r = 11;                   // binary64 exponent width.
   localparam int bias_r = ( 1 << w_exp_r - 1 ) - 1;
   localparam int bias_h = ( 1 << wexp - 1 ) - 1;
   localparam int exp_max_h = ( 1 << wexp ) - 1;  // Exponent field of inf/NaN.

   // real -> custom format. Zero maps to +0. Magnitudes too large for the
   // format map to infinity (so do real inf/NaN). Magnitudes below the
   // smallest normal become subnormals, or zero.
   static function logic [w-1:0] rtof( real r );
      logic [wsig-1:0] sig_f;
      logic [w_sig_r-wsig-2:0] sig_x;
      logic sig_x_msb;
      logic [w_exp_r-1:0] exp_r;
      logic sign_r;
      logic [wsig:0] sig_full;
      int exp_h;

      { sign_r, exp_r, sig_f, sig_x_msb, sig_x } = $realtobits(r);
      // The discarded significand bits (sig_x_msb, sig_x) are ignored:
      // the conversion truncates and no rounding mode is applied.

      if ( r == 0.0 ) return 0;

      // Re-bias the exponent, using signed arithmetic so under/overflow is
      // detected instead of wrapping.
      exp_h = int'(exp_r) + bias_h - bias_r;

      if ( exp_h >= exp_max_h ) return { sign_r, wexp'(exp_max_h), wsig'(0) };

      if ( exp_h <= 0 ) begin
         // Subnormal: shift the significand, including its implicit 1,
         // right so that the value is 0.sig * 2^(1-bias_h).
         sig_full = { 1'b1, sig_f };
         return { sign_r, wexp'(0), wsig'( sig_full >> ( 1 - exp_h ) ) };
      end

      return { sign_r, wexp'(exp_h), sig_f };
   endfunction

   // custom format -> real. Handles +/-0, subnormals, infinities and NaN.
   static function real ftor( logic [w-1:0] f );
      logic sign_h;
      logic [wexp-1:0] exp_h;
      logic [wsig-1:0] sig_h;
      real mag;

      { sign_h, exp_h, sig_h } = f;

      if ( exp_h == 0 ) begin
         // Zero (either sign) or subnormal: value is 0.sig_h * 2^(1-bias_h).
         mag = sig_h * 2.0 ** ( 1 - bias_h - wsig );
         return sign_h ? -mag : mag;
      end

      if ( exp_h == exp_max_h )
        // Infinity when sig_h == 0, otherwise a (quiet) NaN.
        return $bitstoreal
          ( { sign_h, {w_exp_r{1'b1}}, 1'( sig_h != 0 ), (w_sig_r-1)'(0) } );

      return $bitstoreal
        ( { sign_h,
            w_exp_r'( bias_r + exp_h - bias_h ),
            sig_h, (w_sig_r-wsig)'(0) } );
   endfunction

   // Discrepancy between a and b in bits. It is 0 when identical; otherwise
   // larger values mean a larger difference. When the exponents are within
   // one, it is the number of bits needed to hold |a-b| in units of the ulp
   // of the smaller exponent. Operands with x/z bits yield 1 << wexp.
   static function int err_bits( logic [w-1:0] a, b );

      logic [wsig-1:0] sig_a, sig_b;
      logic [wsig+2:0] frac_a, frac_b, frac_diff;
      logic [wexp-1:0] exp_a, exp_b;
      logic s_a, s_b;
      int delta_e;

      if ( $isunknown(a) || $isunknown(b) ) return 1 << wexp;
      if ( a == b ) return 0;

      { s_a, exp_a, sig_a } = a;
      { s_b, exp_b, sig_b } = b;

      // At least one operand is zero or subnormal. A heuristic based on the
      // significand bits set in either operand is used.
      // NOTE(review): the rationale of this heuristic is not evident here.
      if ( exp_a == 0 || exp_b == 0 ) begin
         logic [wsig-1:0] sig = ~ ( sig_a | sig_b );
         return 1 + wsig - $clog2( sig + 1 );
      end

      // Absolute exponent difference. The exponents are unsigned, so the
      // smaller one is subtracted from the larger one to avoid wrap-around.
      delta_e = exp_a > exp_b ? exp_a - exp_b : exp_b - exp_a;
      if ( delta_e > 1 ) return delta_e + wsig;

      // Exponents differ by at most one. Restore the implicit 1, shift the
      // operand with the larger exponent left by one so both are in units
      // of the smaller exponent's ulp, then take the (signed) difference.
      frac_a = exp_a > exp_b ? { 2'b1, sig_a, 1'b0 } : { 3'b1, sig_a };
      frac_b = exp_b > exp_a ? { 2'b1, sig_b, 1'b0 } : { 3'b1, sig_b };
      frac_diff =
        s_a != s_b ? frac_a + frac_b :
        frac_a > frac_b ? frac_a - frac_b : frac_b - frac_a;
      return $clog2( frac_diff + 1 );

   endfunction

endclass

// cadence translate_on

// cadence translate_off

// Identifies the module under test: M_p1 selects comp_p1 and M_p2 selects
// comp_p2. Used by testbench and testbench_n to choose, label and index
// results.
//
typedef enum { M_p1, M_p2 } M_Type;

module testbench;

   // Top-level testbench. It instantiates one testbench_n for every
   // combination of module (mset) and parameter set (pset). The instances
   // run one after another, and a summary is printed when simulation ends.

   localparam int n_tests = 10000;   // Random tests per testbench_n instance.

   localparam int npsets = 5; // This MUST be set to the size of pset.
   // { w_exp, w_sig, w_int }
   localparam int pset[npsets][3] =
              '{
                { 7,  6,  4 },
                { 7,  8,  4 },
                { 8, 10,  5 },
                { 8, 10, 10 },
                { 8, 12, 10 }
                };

   localparam int nmsets = 2;
   localparam M_Type mset[nmsets] = '{ M_p1, M_p2 };

   string mtype_str[M_Type] = '{ M_p1: "comp_p1", M_p2: "comp_p2" };
   string mtype_abbr[M_Type] = '{ M_p1: "p1", M_p2: "p2" };

   // Statistics written by the testbench_n instances.
   int t_errs_mod[M_Type];         // Errors per module.
   int t_errs_size[int];           // Errors per parameter-set index.
   int t_errs_each[M_Type][int];   // Errors per (module, pset); -1 = not run.
   int t_mub_each[M_Type][int];    // Max error bits per (module, pset).
   real t_aub_each[M_Type][int];   // Average error bits per (module, pset).

   localparam int nsets = npsets * nmsets;

   logic d[nsets:-1]; // Start / Done signals.

   int t_errs;     // Total number of errors.
   initial begin
      t_errs = 0;
      // Create every entry that testbench_n updates with "+=" and that the
      // final block reads. This way no nonexistent associative-array entry
      // is ever read.
      for ( int m=0; m<nmsets; m++ ) t_errs_mod[mset[m]] = 0;
      for ( int i=0; i<npsets; i++ ) t_errs_size[i] = 0;
      for ( int m=0; m<nmsets; m++ )
        for ( int i=0; i<npsets; i++ ) begin
           t_errs_each[mset[m]][i] = -1;
           t_mub_each[mset[m]][i] = -1;
           t_aub_each[mset[m]][i] = -1;
        end

      d[-1] = 1;   // Start the first testbench_n instance.
   end

   final begin
      $write("\nNumber of tests: %0d.\n", n_tests);
      for ( int i=0; i<npsets; i++ )
        $write("Total for exp=%2d, sig=%2d, w=%2d: %5d errors.\n",
               pset[i][0], pset[i][1], pset[i][2],
               t_errs_size[i]);
      for ( int i=0; i<nmsets; i++ )
        $write("Total for mod %4s: %5d errors.\n",
               mtype_str[mset[i]], t_errs_mod[mset[i]]);
      for ( int m=0; m<nmsets; m++ )
        for ( int i=0; i<npsets; i++ )
          $write("Total %4s exp=%2d, sig=%2d, w=%2d: %5d errors. Err bits: avg %6.2f, max %3d\n",
                 mtype_str[mset[m]],
                 pset[i][0], pset[i][1], pset[i][2],
                 t_errs_each[mset[m]][i],
                 t_aub_each[mset[m]][i], t_mub_each[mset[m]][i]);

      $write("Total number of errors: %0d\n",t_errs);
   end

   // Instance idx starts when instance idx-1 sets its done output. The
   // first instance is started by d[-1].
   for ( genvar m=0; m<nmsets; m++ )
     for ( genvar i=0; i<npsets; i++ ) begin
        localparam int idx = m * npsets + i;
        testbench_n
          #( .w_exp(pset[i][0]), .w_sig(pset[i][1]), .w_int(pset[i][2]),
             .pset(i), .mtype(mset[m]) )
        t2( .done(d[idx]), .tstart(d[idx-1]) );
     end

endmodule


// Tests one comp_p1 or comp_p2 instance (selected by mtype) using FP
// format (w_exp, w_sig) and integer input width w_int. After tstart is
// raised, it applies testbench.n_tests random input sets and checks each
// output against a reference computed with reals. It then records the
// results in the testbench module's arrays at index [mtype][pset] and
// raises done.
//
module testbench_n
  #( int w_exp = 5, w_sig = 8, w_int = 12, pset = 0, M_Type mtype = M_p1 )
   ( output logic done, input uwire tstart );

   localparam int w_fp = 1 + w_sig + w_exp;
   localparam int bias = ( 1 << w_exp-1 ) - 1;
   logic [w_int-1:0] a, b, c;
   uwire [w_fp-1:0] h;

   // Instantiate the module under test chosen by mtype.
   case ( mtype )
     M_p1: comp_p1 #( w_int, w_exp, w_sig ) c1(h, a, b, c);
     M_p2: comp_p2 #( w_int, w_exp, w_sig ) c2(h, a, b, c);
   endcase

   initial begin

      automatic int n_tests = testbench.n_tests;
      automatic int n_err = 0;
      // ub_max: largest error (in bits) over all tests.
      // ub_emax: largest error among failing tests; only failures that set
      //   a new maximum are printed.
      // ub_sum: sum of errors, used for the average.
      automatic int ub_max = 0, ub_emax = 0, ub_sum = 0;

      // In testbench, tstart is the previous instance's done (or d[-1]).
      wait( tstart );

      $write("Starting tests for mod %4s exp=%2d, sig=%2d, w=%2d\n",
             testbench.mtype_str[mtype], w_exp, w_sig, w_int);

      for (int i=0; i<n_tests; i++ ) begin

         // When set, b and c are full-width random values (truncated to
         // w_int bits). Otherwise each is limited to a random width.
         automatic bit choose_close_bc = $random() & 1'b1;

         real mut_h, shadow_h, shadow_hr, boc;
         logic [w_fp-1:0] shadow_hf;
         int ub, bit_loss, tol;

         // Random inputs; a and c are forced nonzero (they are divisors).
         a = rand_wid(w_int);
         if ( a == 0 ) a = 1;
         b = choose_close_bc ? $random() : rand_wid(w_int);
         c = choose_close_bc ? $random() : rand_wid(w_int);
         if ( c == 0 ) c = 1;

         // Tolerance in error bits. For comp_p1 (when b != c), extra bits
         // are allowed, growing with log2(1/|1-b/c|): 1-b/c loses precision
         // to cancellation when b/c is near 1. comp_p2 gets no extra bits.
         bit_loss = mtype == M_p2 || b == c ? 0
           : $clog2( 1 + int'($ceil( 1 / fabs( 1 - real'(b)/c ) ) ) );
         tol = 1 + bit_loss;

         // Reference result, computed with reals and converted to the FP
         // format (shadow_hf). shadow_h is converted back, for display.
         shadow_hr = ( 1 - real'(b)/c ) / a;
         shadow_hf = conv#(w_exp,w_sig)::rtof( shadow_hr );
         shadow_h = conv#(w_exp,w_sig)::ftor( shadow_hf );

         // Let the module under test settle.
         #1;

         // Compare the module's output with the reference.
         mut_h = conv#(w_exp,w_sig)::ftor(h);
         ub = conv#(w_exp,w_sig)::err_bits( shadow_hf, h );
         if ( ub > 0 ) ub_sum += ub;

         if ( ub > tol ) begin
            n_err++;
            if ( ub > ub_emax ) begin
               ub_emax = ub;
               $write( "Error %s #(%0d,%0d,%0d) a=%d b=%d c=%d:  Err bits %0d (tol %0d)\n",
                       testbench.mtype_abbr[mtype],
                       w_exp, w_sig, w_int,
                       a, b, c,  ub, tol );
               $write( "  Output %.4e != %.4e (correct).\n",
                       mut_h, shadow_h );
               $write( "  Output 'h%h * 2^(%d-%0d) != 'h%h * 2^(%d-%0d) (correct)\n",
                       h[w_sig-1:0], h[w_sig+w_exp-1:w_sig], bias,
                       shadow_hf[w_sig-1:0], shadow_hf[w_sig+w_exp-1:w_sig],
                       bias );

            end

         end

         if ( ub > ub_max ) ub_max = ub;

      end

      $write("Finished tests for mod %4s exp=%2d, sig=%2d, w=%2d. %0d errors.\n",
             testbench.mtype_str[mtype], w_exp, w_sig, w_int, n_err);


      // Record this instance's statistics in the top-level testbench.
      testbench.t_errs += n_err;
      testbench.t_errs_each[mtype][pset] = n_err;
      testbench.t_mub_each[mtype][pset] = ub_max;
      testbench.t_aub_each[mtype][pset] = real'(ub_sum) / n_tests;
      testbench.t_errs_mod[mtype] += n_err;
      testbench.t_errs_size[pset] += n_err;

      // Signal completion; in testbench this starts the next instance.
      done = 1;
   end

endmodule


`define SIMULATION_ON

// cadence translate_on

`default_nettype wire

`ifdef SIMULATION_ON

`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/sim/verilog/CW/CW_fp_mult.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/sim/verilog/CW/CW_fp_add.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/sim/verilog/CW/CW_fp_sub.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/sim/verilog/CW/CW_fp_div.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/sim/verilog/CW/CW_fp_i2flt.v"

`else

`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/syn/CW/CW_fp_mult.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/syn/CW/CW_fp_add.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/syn/CW/CW_fp_sub.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/syn/CW/CW_fp_i2flt.v"
`include "/apps/linux/cadence/GENUS211/share/synth/lib/chipware/syn/CW/CW_fp_div.v"

`endif