////////////////////////////////////////////////////////////////////////////////
//
/// LSU EE 4755 Fall 2020 Homework 3 -- SOLUTION
//

 /// Assignment  https://www.ece.lsu.edu/koppel/v/2020/hw03.pdf


`default_nettype none

//////////////////////////////////////////////////////////////////////////////
///  Problem 1
//
 /// Modify nnOxI and nn1xI so that they compute the same output as nnOxIbe.
 ///
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Don't assume any particular parameter value.
//     [✔] Don't make ports wider than necessary.
//
//     [✔] Code must be written clearly.


   /// Module Connection Names
   //
   //  no:  Number of elements in Output array.
   //  ni:  Number of elements in Input array.
   //  wo:  Width (number of bits) of each element of the output array.
   //  wi:  Width (number of bits) of each element of the input array.
   //  ww:  Width (number of bits) of each element of the weight array.
   //  sat: If 0, on overflow use low wo bits of result.
   //       If 1, on overflow set result to maximum possible value,
   //             do this where overflow occurs.
   //       If 2, on overflow set result to maximum possible value,
   //             do this in nnOxI (not in nn1xI nor in nnAdd, nnMult, etc.)
   //  tr:  If 0, generate a linear connection of nnMADD modules in nn1xI.
   //       If 1, generate a tree connection of arithmetic units by
   //             recursively defining nn1xI.
   //  ao:  Activation (neuron) Output array.
   //  ai:  Activation (neuron) Input array.
   //  wht: Weights.


module nnOxI
  #( int no = 4, ni = 2, wo = 10, wi = 4, ww = 5, tr = 0, sat = 0 )
   ( output uwire [wo-1:0] ao[no],
     input uwire [wi-1:0] ai[ni],
     input uwire [ww-1:0] wht[no][ni] );

   // Computes no outputs; output i is the sum over ni inputs of
   // ai[j] * wht[i][j], produced by one nn1xI instance per output.
   //
   //   [✔] Instantiate nn1xI modules here.
   //   [✔] If sat == 2 replace overflow values with max possible value.
   //   [✔] Don't forget to set appropriate parameter values.

   /// SOLUTION

   // Largest value that can appear on an ao before truncation: every
   // input and every weight at its maximum, summed over ni products.
   //
   // NOTE(review): evaluated in 32-bit int arithmetic, so this assumes
   // the product fits in an int -- confirm for large wi, ww, or ni.
   //
   localparam int amax = ( 2**wi - 1 ) * ( 2**ww - 1 ) * ni;

   // Number of bits needed to represent amax.
   //
   // This must be $clog2(amax+1), not $clog2(amax): when amax is an
   // exact power of two (for example wi=1, ww=1, ni=4 gives amax=4)
   // $clog2(amax) is one bit short (2 rather than 3), which would
   // cause overflow to be treated as impossible when it is not.
   //
   localparam int wr = $clog2( amax + 1 );

   if ( sat < 2 || wr <= wo ) begin

      // Saturation, if any, is handled inside nn1xI (sat == 1).
      //
      // If overflow is not possible turn off check for saturation.
      //
      localparam int satp = wr <= wo ? 0 : sat;

      for ( genvar i = 0;  i < no;  i++ )
        nn1xI #(wo,wi,ww,ni,tr,satp) row( ao[i], ai, wht[i] );

   end else begin

      // sat >= 2 and overflow is possible: compute each sum at full
      // precision (wr bits, no saturation inside nn1xI) and saturate
      // once, here.
      //
      for ( genvar i = 0;  i < no;  i++ ) begin

         uwire [wr-1:0] ar;
         nn1xI #(wr,wi,ww,ni,tr,0) row( ar, ai, wht[i] );

         // If there is an overflow (any bit above the low wo bits is
         // set) substitute maximum value.
         //
         assign ao[i] = ar[wr-1:wo] ? ~wo'(0) : ar[wo-1:0];

      end

   end

endmodule

module nn1xI
  #( int wo = 10, wi = 4, ww = 5, ni = 2, tr = 0, sat = 0 )
   ( output uwire [wo-1:0] ao,
     input uwire [wi-1:0] ai[ni],
     input uwire [ww-1:0] wht[ni] );

   // Computes one output: the sum over ni inputs of ai[j] * wht[j],
   // wo bits wide. The sat parameter is passed unchanged to every
   // arithmetic unit instantiated here.
   //
   //   tr != 0:  Tree. Split the inputs into two halves, recursively
   //             instantiate nn1xI on each half, and add the results.
   //             A single input is handled by one nnMult.
   //   tr == 0:  Linear. A chain of ni nnMADD units, each adding one
   //             product to the running sum.

   /// SOLUTION
   //
   if ( tr ) begin

      if ( ni == 1 ) begin

         // Base case of the recursion: a single product.
         //
         nnMult #( .wa(wi), .wb(ww), .wp(wo), .sat(sat) )
         mult( .p(ao), .a(ai[0]), .b(wht[0]) );

      end else begin

         // Lower half gets floor(ni/2) inputs, upper half the rest.
         //
         localparam int n_lo = ni / 2;
         localparam int n_hi = ni - n_lo;

         uwire [wo-1:0] sum_lo, sum_hi;

         nn1xI #( .wo(wo), .wi(wi), .ww(ww), .ni(n_lo), .tr(1), .sat(sat) )
         nnlo( .ao(sum_lo), .ai(ai[0:n_lo-1]), .wht(wht[0:n_lo-1]) );

         nn1xI #( .wo(wo), .wi(wi), .ww(ww), .ni(n_hi), .tr(1), .sat(sat) )
         nnhi( .ao(sum_hi), .ai(ai[n_lo:ni-1]), .wht(wht[n_lo:ni-1]) );

         nnAdd #( .w(wo), .sat(sat) )
         add( .so(ao), .a(sum_lo), .b(sum_hi) );

      end

   end else begin

      // Running sums. Element -1 is the initial value (zero), element
      // k is the sum after products 0..k have been added.
      //
      uwire [wo-1:0] psum[ni-1:-1];
      assign psum[-1] = 0;

      for ( genvar k = 0;  k < ni;  k++ )
        nnMADD #( .wa(ww), .wb(wi), .ws(wo), .sat(sat) )
        madd( .so(psum[k]), .a(wht[k]), .b(ai[k]), .si(psum[k-1]) );

      // The last running sum is the module output.
      //
      assign ao = psum[ni-1];

   end

endmodule

// Multiply-add unit: so = si + a * b.
//
//   wa, wb:  Widths of multiplicands a and b.
//   ws:      Width of the incoming sum si, the product, and the result so.
//   sat:     Passed unchanged to both the multiplier and the adder.
//
module nnMADD
  #( int wa = 10, wb = 5, ws = wa + wb, sat = 0 )
   ( output uwire [ws-1:0] so,
     input uwire [wa-1:0] a, input uwire [wb-1:0] b, input uwire [ws-1:0] si);

   /// DO NOT MODIFY THIS MODULE.

   // Product of a and b, already reduced to ws bits by nnMult.
   uwire [ws-1:0] p;
   nnMult #(wa,wb,ws,sat) mu(p, a, b);
   nnAdd #(ws,sat) ad(so, si, p);

endmodule

// Adder with optional saturation: so = a + b, w bits wide.
//
//   sat == 0:  so is the low w bits of a + b.
//   sat != 0:  if the sum carries out of bit w-1, so is all ones (smax);
//              otherwise so is the low w bits of the sum.
//
module nnAdd
  #( int w = 5, sat = 0 )
   ( output uwire [w-1:0] so,
     input uwire [w-1:0] a, b );

   /// DO NOT MODIFY THIS MODULE.

   // One bit wider than the operands so that s[w] holds the carry out.
   uwire [w:0] s = a + b;
   // Largest value representable in w bits.
   localparam logic [w-1:0] smax = ~w'(0);
   assign so = sat && s[w] ? smax : s[w-1:0];

endmodule

// Multiplier with optional saturation: p = a * b, wp bits wide.
//
//   sat == 0:  p is the low wp bits of the product.
//   sat != 0:  if wp is narrower than wa + wb and any product bit above
//              the low wp bits is set, p is all ones (pmax); otherwise
//              p is the low wp bits of the product.
//
module nnMult
  #( int wa = 5, wb = 6, wp = wa + wb, sat = 0 )
   ( output uwire [wp-1:0] p,
     input uwire [wa-1:0] a, input uwire [wb-1:0] b );

   /// DO NOT MODIFY THIS MODULE.

   // Largest value representable in wp bits.
   localparam logic [wp-1:0] pmax = ~wp'(0);
   // The larger of the output width and the full product width.
   localparam int wmx = wp > wa+wb ? wp : wa+wb;
   // phi receives the product bits above the low wp bits. It is at least
   // one bit wide even when wp >= wa + wb.
   uwire [wmx-wp:0] phi;
   // plo receives the low wp bits of the product.
   uwire [wp-1:0] plo;
   assign {phi,plo} = a * b;
   assign p = sat && wp < wa + wb && phi ? pmax : plo;

endmodule


// Synthesizing at effort level "medium"

// Module Name                               Area   Delay   Delay
//                                                 Actual  Target

// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr0    588304   6.972  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr1    588304   6.972  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr0    753136  63.864  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr1    631611   7.043  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr0    594261   7.450  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr1    594261   7.450  90.000 ns

// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr0    783094   4.828   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr1    779386   4.852   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr0    951332   9.503   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr1    916787   5.136   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr0    800554   4.980   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr1    771789   4.981   1.000 ns

// Normal exit.



// Behavioral reference model. The testbench compares the outputs of
// nnOxI against the outputs of this module.
//
// For each output o, accumulates ai[i] * wht[o][i] over all ni inputs
// in an unsigned int. If sat is nonzero and the accumulated value
// exceeds the largest wo-bit value, the output is that largest value;
// otherwise the accumulated value is assigned to the wo-bit output.
//
module nnOxIbe
  #( int no = 4, ni = 4, wo = 10, wi = 4, ww = 5, sat = 0 )
   ( output logic [wo-1:0] ao[no],
     input uwire  [wi-1:0] ai[ni], input uwire  [ww-1:0] wht[no][ni] );

   /// DO NOT MODIFY THIS MODULE
   //
   //  Study the code in this module to get a better understanding
   //  of what the output of nnOxI should be.

   // Determine the maximum possible value of each element of ao.
   //
   localparam logic [wo-1:0] smax = ~wo'(0);


   always_comb
     for ( int o = 0;  o < no;  o++ ) begin
        // Sum of products for output o, restarted at zero for each output.
        automatic int unsigned acc = 0;
        for ( int i=0; i<ni; i++ ) acc += ai[i] * wht[o][i];
        // Any nonzero sat saturates here, once, on the final sum.
        ao[o] = sat && acc > smax ? smax : acc;
     end

endmodule

//////////////////////////////////////////////////////////////////////////////
/// Testbench Code


// cadence translate_off

// Record describing one testbench configuration. Not referenced by the
// code shown in this file; the testbench module stores its
// configurations in an int array (configs) instead.
typedef struct { int ni, no, wa, ww; } Config;

module testbench;

   // Top-level testbench. Instantiates one testbench_x for each
   // combination of configuration and sat value, chains them with
   // start/done signals so that they run one after another, and prints
   // a summary of the error counts at the end of simulation.

   // Number of configurations, and the configurations themselves.
   // Each row is passed, in order, as the first four parameters of
   // testbench_x:  ni, no, wo, ww.
   //
   localparam int nc = 2;
   localparam int configs[nc][4] = '{ '{ 4,4, 16,9 }, '{ 5,3, 15,8 } };
   `ifdef grrr
   initial if ( nc != configs.size() )
     $fatal(1, "Constant nc should be %0d.\n", configs.size() );
   `endif

   // Result accumulators, updated by the testbench_x instances through
   // hierarchical references (testbench.t_errs, etc.).
   //
   // These are initialized in their declarations rather than in an
   // initial block: declaration initializers are applied before any
   // procedure starts, so they cannot race with (and wipe out) text
   // that a testbench_x initial block appends at time zero.
   //
   int t_errs = 0;          // Total number of errors.
   int t_errs_cat[string];  // Total errors by configuration category.
   string test_summary = "";

   final begin

      // Length of the longest category name, used to align the counts.
      //
      automatic int mlen = 0;
      foreach ( t_errs_cat[key] ) if ( mlen < key.len() ) mlen = key.len();

      $write("\n** Summary of Results **\n%s", test_summary);
      foreach ( t_errs_cat[key] )
        $write("%0s%0s %5d errors.\n",
               key, { mlen - key.len() {" "}}, t_errs_cat[key]);
      $write("Total number of errors: %0d\n",t_errs);

   end

   // Values of sat tested are 0, 1, ..., maxsat-1.
   //
   localparam int maxsat = 3;

   // Start / Done signals. Element idx is the done output of instance
   // idx and the start input of instance idx+1. Element -1 starts the
   // first instance.
   //
   uwire d[maxsat*nc-1:-1];
   assign d[-1] = 1;  // The first instance may start immediately.

   // Instantiate a testbench at each size and each sat value.
   //
   for ( genvar i=0; i<nc; i++ ) begin
      localparam int c[4] = configs[i];
      for ( genvar sat=0; sat<maxsat; sat++ ) begin
         localparam int idx = maxsat*i + sat;
         testbench_x #(c[0],c[1],c[2],c[3],sat)
         t2( .done(d[idx]), .start(d[idx-1]) );
      end
   end

endmodule

// Testbench for one configuration (ni, no, wo, ww, sat).
//
// Waits for start, then applies random weights and inputs to the
// behavioral model (nnOxIbe) and to two nnOxI instances (tr = 0 and
// tr = 1), compares the nnOxI outputs against the behavioral outputs,
// records error counts in the top-level testbench module, and finally
// sets done.
//
module testbench_x
  #( int ni = 4, no = 4, wo = 16, ww = 8, sat = 0 )
   ( output logic done, input uwire start );

   // Input element width is always one bit more than the weight width.
   localparam int wi = ww + 1;

   // Number of modules instantiated below, including the behavioral one.
   localparam int nmut = 3;

   // Number of random stimulus sets applied per test set.
   localparam int ntests = 10;

   // Shared stimulus, and one output array per instantiated module.
   logic [wi-1:0] ai[ni];
   uwire [wo-1:0] ao[nmut][no];
   logic [ww-1:0] wht[no][ni];

   // Display names, indexed the same way as the first dimension of ao.
   string mname[] = { "Behav", "Linear", "Tree" };

   // A test set exercises only the first no outputs and first ni inputs;
   // the remaining weights and inputs are held at zero.
   typedef struct { string name; int no, ni; } Test_Set;
   Test_Set ts[] = '{ '{ "n12", 1, 2 }, '{ "n1*", 1, ni }, '{ "n**", no, ni } };

   // ao[0]: behavioral reference.  ao[1]: nnOxI with tr = 0 (linear).
   // ao[2]: nnOxI with tr = 1 (tree).
   nnOxIbe #(no,ni,wo,wi,ww,sat) nn0(ao[0],ai,wht);
   nnOxI #(no,ni,wo,wi,ww,0,sat) nn1(ao[1],ai,wht);
   nnOxI #(no,ni,wo,wi,ww,1,sat) nn2(ao[2],ai,wht);

   initial begin

      automatic string config_label =
        $sformatf("no=%0d, ni=%0d, wo=%0d, wi=%0d, ww=%0d, sat=%0d",
                  no, ni, wo, wi, ww, sat );

      // Do nothing until start is asserted.
      wait( start );

      $write("\n** Starting tests for %s\n", config_label);
      testbench.test_summary =
        { testbench.test_summary, $sformatf("Results from %s\n",config_label) };

      // Start at 1: module 0 is the reference that the others are
      // compared against.
      for ( int mut = 1;  mut < nmut;  mut++ ) begin

         $write("Testing module %s\n", mname[mut]);

         foreach ( ts[ti] ) begin

            automatic Test_Set tinfo = ts[ti];
            automatic int n_err = 0;

            $write("\n** Starting test set %s  (%0d outputs, %0d inputs) for %s **\n",
                   tinfo.name, tinfo.no, tinfo.ni, mname[mut] );

            for ( int tnum=0; tnum < ntests;  tnum++ ) begin

               // Random weights for the outputs and inputs in this test
               // set, zero everywhere else.
               for ( int io=0; io<no; io++ )
                 for ( int ii=0; ii<ni; ii++ )
                   wht[io][ii] = io < tinfo.no && ii < tinfo.ni ? $random : 0;

               // Random values for the inputs in this test set, zero for
               // the rest.
               for ( int ii=0; ii<ni; ii++ )
                 ai[ii] = ii < tinfo.ni ? $random : 0;

               // Advance time so that the module outputs reflect the new
               // stimulus before they are checked.
               #1;

               for ( int io=0; io<tinfo.no; io++ ) begin

                  if ( ao[0][io] !== ao[mut][io] ) begin
                     n_err++;
                     // Print details for only the first three errors of
                     // each test set.
                     if ( n_err < 4 )
                       $write
                         ("Error test # %0d, output %0d: %0d != %0d (correct)\n",
                          tnum, io, ao[mut][io], ao[0][io] );
                  end
               end

            end

            begin
               // Record the results of this test set in the top-level
               // testbench: overall, per sat value, per module and sat
               // value, and per module.
               automatic string sat_key =
                 $sformatf("Sat %0d",sat);
               automatic string long_key =
                 $sformatf("%s %s",mname[mut],sat_key);
               automatic string msg =
                 $sformatf("%0d %s tests on %s: %0d errors found.",
                           ntests, tinfo.name, mname[mut], n_err);
               $write("Done with %s\n",msg);
               testbench.test_summary =
                 { testbench.test_summary, $sformatf("Results of %s\n", msg) };
               testbench.t_errs += n_err;
               testbench.t_errs_cat[{"All ",sat_key}] += n_err;
               testbench.t_errs_cat[long_key] += n_err;
               testbench.t_errs_cat[mname[mut]] += n_err;
            end

         end

      end

      // Signal completion. In the top-level testbench this output is
      // connected to the start input of the next instance.
      done = 1;

   end


endmodule

// cadence translate_on