////////////////////////////////////////////////////////////////////////////////
//
/// LSU EE 4755 Fall 2024 Homework 5 -- SOLUTION
//

 /// Assignment  https://www.ece.lsu.edu/koppel/v/2024/hw05.pdf

`default_nettype none


//////////////////////////////////////////////////////////////////////////////
///  Problem 1
//
  ///  Complete dot_seq_2_base.
//
//     [✔] The module can use procedural code.
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Don't assume any particular parameter values.
//
//     [✔] As always, code must be written clearly.
//     [✔] As always, pay attention to cost and performance.

module dot_seq_2
  #( int w = 5, wi = 4 )
   ( output logic [w-1:0] dp,
     output logic [wi-1:0] first_id, last_id,
     input uwire [w-1:0] a[2], b[2],
     input uwire [wi-1:0] in_id,
     input uwire reset, first, last,
     input uwire clk );

   // [✔] Critical path can't include more than one arithmetic operation.
   // [✔] Don't change outputs until operation is complete.

   /// SOLUTION
   //
   // Sequential dot product that consumes two element pairs per
   // cycle. The work is pipelined so that each stage contains at most
   // one arithmetic operation:
   //
   //   Stage 0: Capture the module inputs in registers.
   //   Stage 1: Multiply a[i] * b[i] for i = 0, 1.
   //   Stage 2: Add the two products.
   //   Stage 3: Add the pair sum to the running total and, when the
   //            segment is marked last, update the module outputs.

   /// Pipeline Registers
   //
   // A register with prefix sN_ is written by the hardware of stage
   // N-1 and read by the hardware of stage N. The ID and the
   // first/last flags are needed in stage 3 and so a copy travels
   // through every stage. The vector data changes form as it goes:
   // elements (s1_a, s1_b) -> products (s2_prod) -> pair sum (s3_sum).
   //
   // Read by stage 1:
   logic [w-1:0] s1_a[2], s1_b[2];   // Vector elements as they arrived.
   logic [wi-1:0] s1_id;             // ID that arrived with the elements.
   logic [1:0] s1_fl;                // Flags: bit 1 is last, bit 0 is first.
   //
   // Read by stage 2:
   logic [w-1:0] s2_prod[2];         // Element-wise products.
   logic [wi-1:0] s2_id;
   logic [1:0] s2_fl;
   //
   // Read by stage 3:
   logic [w-1:0] s3_sum;             // Dot product of the two-element segment.
   logic [wi-1:0] s3_id;
   logic [1:0] s3_fl;

   /// Accumulator Registers
   //
   // These stay put rather than moving down the pipeline. Both are
   // written using stage-3 values.
   //
   logic [w-1:0] acc_sum;            // Running total for the current vector.
   logic [wi-1:0] acc_id;            // ID of the current vector's first segment.

   /// Stage 3 Combinational Values
   //
   // The flags used here are the ones that traveled with the data. The
   // values of first and last currently at the module inputs belong to
   // a segment that is three stages behind.
   //
   uwire s3_first = s3_fl[0];
   uwire s3_last = s3_fl[1];
   //
   // Start a new total on the first segment of a vector, otherwise
   // add the segment's sum to the total accumulated so far.
   //
   uwire [w-1:0] s3_total = s3_first ? s3_sum : s3_sum + acc_sum;

   // All register updates below are nonblocking, so the order in which
   // the stages are listed has no effect on behavior.
   //
   always_ff @( posedge clk ) begin

      /// Stage 3
      //
      // The running total is written unconditionally. When reset or
      // last is asserted the value written is never used, so logic to
      // suppress the write would only add cost.
      //
      acc_sum <= s3_total;

      if ( reset ) begin

         // On reset the ID outputs are cleared. Output dp is left
         // alone; this code does not clear it.
         //
         first_id <= 0;
         last_id <= 0;

      end else begin

         // Remember the ID of the segment that starts a vector.
         //
         if ( s3_first ) acc_id <= s3_id;

         // The outputs change only when the last segment of a vector
         // reaches this stage. A one-segment vector has first and
         // last set together, in which case acc_id has not yet been
         // written and the arriving ID is used directly.
         //
         if ( s3_last ) begin
            dp <= s3_total;
            last_id <= s3_id;
            first_id <= s3_first ? s3_id : acc_id;
         end

      end

      /// Stage 2
      //
      // Add the two products. The ID passes through unchanged and the
      // flags are cleared on reset.
      //
      s3_sum <= s2_prod[0] + s2_prod[1];
      s3_id <= s2_id;
      s3_fl <= reset ? 2'b00 : s2_fl;

      /// Stage 1
      //
      // Multiply corresponding elements. The ID passes through
      // unchanged and the flags are cleared on reset.
      //
      s2_prod[0] <= s1_a[0] * s1_b[0];
      s2_prod[1] <= s1_a[1] * s1_b[1];
      s2_id <= s1_id;
      s2_fl <= reset ? 2'b00 : s1_fl;

      /// Stage 0
      //
      // Only capture the inputs. The assignment states that inputs
      // arrive late in the cycle, so no computation is done on them
      // until the next cycle.
      //
      s1_a <= a;   // Whole-array copy, both elements.
      s1_b <= b;
      s1_id <= in_id;
      s1_fl <= reset ? 2'b00 : {last,first};

   end

endmodule

`ifdef xxxx
Synthesizing at effort level "high"

Module Name                           Area   Delay   Delay     Synth
                                            Actual  Target      Time
dot_seq_2_w8_wi8                    148802    3.58   900.0 ns      8 s
dot_seq_2_w8_wi8_17                 219405    2.02     0.1 ns     43 s
`endif


//////////////////////////////////////////////////////////////////////////////
/// Testbench Code
//
// It is okay to modify the testbench code to facilitate the coding
// and debugging of your modules. Keep in mind that your submission
// will be tested using a different testbench, so no one will be
// accused of dishonesty for modifying the testbench below. However,
// be sure to undo any such changes and verify that your code passes
// the original testbench.


// cadence translate_off

program reactivate
   (output uwire clk_reactive, output int cycle_reactive,
    input uwire clk, input var int cycle);
   // Re-drive the clock and the cycle count from inside a program
   // block. Under SystemVerilog scheduling semantics program code
   // runs in the reactive region set, so a process waiting on
   // clk_reactive wakes after the design's nonblocking updates for
   // that time step have been applied.
   assign cycle_reactive = cycle;
   assign clk_reactive = clk;
endprogram


// Wrapper that lets the testbench select a dot-product module by the
// number of element pairs, n, delivered per cycle.
//
//   n == 1: Instantiates dot_seq, connecting the single element of a and b.
//   n == 2: Instantiates dot_seq_2, connecting the two-element arrays.
//
// Any other value of n is reported as an elaboration error. Without
// that branch no module would be instantiated and the outputs would
// be silently left undriven.
//
module dot_seq_m
  #( int n = 1, w = 5, wi = 4 )
   ( output logic [w-1:0] dp,
     output logic [wi-1:0] first_id, last_id,
     input uwire [w-1:0] a[n], b[n],
     input uwire [wi-1:0] in_id,
     input uwire reset, first, last,
     input uwire clk );

   if ( n == 1 ) begin

      dot_seq #(w,wi) d(dp,first_id,last_id,a[0],b[0],in_id,reset,first,last,clk);

   end else if ( n == 2 ) begin

      dot_seq_2 #(w,wi) d(dp,first_id,last_id,a,b,in_id,reset,first,last,clk);

   end else begin

      // Elaboration-time check: fail loudly for unsupported sizes.
      $error("dot_seq_m: unsupported n = %0d; only n = 1 and n = 2 are implemented.", n);

   end

endmodule

// Self-checking testbench for the sequential dot-product modules.
//
// Three initial blocks cooperate:
//
//   Clock generator:  Toggles clk every time unit and counts cycles
//                     until done is set or the cycle limit is reached.
//   Stimulus:         At each negative clock edge applies new inputs,
//                     randomly asserting reset, first, and last, and
//                     appends the expected outputs to exp_info.
//   Checker:          At each positive edge of clk_reactive compares
//                     the module outputs against the entries of
//                     exp_info, counts errors, and prints a trace.
//
module testbench;

   logic done;

   localparam int m = 2;            // Number of element pairs per cycle.
   localparam int n_tests = 10000;  // Number of complete vectors to send.
   localparam int wi = 8;           // Width of ID values.
   localparam int w = 8;            // Width of vector elements and of dp.

   // Maximum number of cycles from "last" signal to arrival of outputs.
   localparam int latency_max = 5;
   // Minimum latency that will be considered correct.
   localparam int latency_min = 2 + $clog2(m);

   localparam int cyc_max = n_tests * 1000;

   int seed;
   initial seed = 475501;

   // Return 1 with probability 1/period.
   function automatic bit rand_bern( int period );
      rand_bern = $dist_uniform(seed,1,period) == 1;
   endfunction

   bit clk;
   int cycle, cycle_limit;
   logic clk_reactive;
   int cycle_reactive;
   reactivate ra(clk_reactive,cycle_reactive,clk,cycle);
   string event_trace;
   string ev_trace[$];   // Most recent trace lines, printed on demand.

   /// Clock Generator
   //
   initial begin
      clk = 0;
      cycle = 0;
      event_trace = "";

      done = 0;
      cycle_limit = cyc_max;
      //  wait( tstart );

      fork
         // Toggle clk each time unit. The post-increment yields the
         // old value of clk, so cycle advances when clk falls.
         while ( !done ) #1 cycle += clk++;
         wait( cycle >= cycle_limit )
           $write("Exit from clock loop at cycle %0d, limit %0d.  %s\n %s\n",
                  cycle, cycle_limit, "** CYCLE LIMIT EXCEEDED **",
                  event_trace);
      join_any;

      done = 1;
   end

   // Expected outputs for one vector.
   typedef struct
     {
      logic [w-1:0] dp;
      logic [wi-1:0] first_id, last_id;
      int cyc; // Cycle that last input asserted.
      int latency; // Number of cycles from last to dp to appearing at output.
      bit correct;
      } Info;

   uwire [w-1:0] dp;
   uwire [wi-1:0] out_first_id, out_last_id;
   logic [w-1:0] a[m], b[m];
   logic [wi-1:0] in_id;
   logic reset, first, last;
   bit done_tests, start_check;

   dot_seq_m #(m,w,wi)
   ds(dp,out_first_id,out_last_id,a,b,in_id,reset,first,last,clk);

   // Stimulus states:
   //   S_reset:    Reset was asserted in the current cycle.
   //   S_gap:      Between vectors, waiting to assert first.
   //   S_continue: Within a vector, waiting to assert last.
   enum { S_reset, S_gap, S_continue } State;
   Info exp_info[$];     // Expected outputs, oldest at the front.
   int n_tests_completed;
   Info info_null;       // Expected outputs following a reset: IDs are zero.

   /// Stimulus
   //
   initial begin

      automatic int state = S_reset;
      logic [w-1:0] shadow_sum;   // Dot product of the vector in progress.
      logic [wi-1:0] shadow_id;   // ID sent with first for that vector.

      n_tests_completed = 0;
      info_null.first_id = 0;
      info_null.last_id = 0;
      info_null.dp = 0;
      info_null.cyc = cycle;
      info_null.correct = 0;

      exp_info.push_back(info_null);

      start_check = 0;
      done_tests = 0;
      first = 0;
      last = 0;
      for ( int j=0; j<m; j++ ) begin a[j]=0; b[j]=0; end
      in_id = $dist_uniform(seed,1,(1<<wi)-1);
      reset = 1;
      @( negedge clk ); @( negedge clk );
      start_check = 1;

      // Each iteration prepares the inputs for one clock cycle.
      //
      while ( n_tests_completed < n_tests ) begin

         first = 0;
         last = 0;
         reset = 0;

         // Use simple element values until 20 tests have completed,
         // random values after that.
         //
         for ( int j=0; j<m; j++ ) begin
            if ( n_tests_completed < 20 ) begin
               a[j] = 1;          b[j] = 1 << ( j * 4 );
            end else begin
               a[j] = {$random};  b[j]= {$random};
            end
         end

         // For error detection recent input IDs must be distinct.
         in_id += $dist_uniform(seed,1,3);
         if ( in_id == 0 ) in_id = 1;

         // Occasionally assert reset. Expected outputs collected so
         // far are discarded and replaced by the null entry.
         //
         if ( state == S_reset ) state = S_gap;
         if ( rand_bern(100) ) begin
            reset = 1;
            in_id = $dist_uniform(seed,1,(1<<wi)-1);
            exp_info.delete();
            info_null.cyc = cycle;
            exp_info.push_back(info_null);
            state = S_reset;
         end

         // Possibly start a new vector.
         //
         if ( state == S_gap ) begin
            first = rand_bern(3);
            if ( first ) begin
               state = S_continue;
               shadow_sum = 0;
               shadow_id = in_id;
            end
         end

         for ( int j=0; j<m; j++ )
           shadow_sum += a[j] * b[j];

         // Possibly end the vector. This can occur in the same cycle
         // in which first was asserted.
         //
         if ( state == S_continue ) begin
            last = rand_bern(4);
            if ( last ) begin
               Info info;
               state = S_gap;
               n_tests_completed++;
               info.dp = shadow_sum;
               info.first_id = shadow_id;
               info.last_id = in_id;
               info.cyc = cycle;
               info.correct = 0;
               exp_info.push_back(info);
            end
         end

         // Retire the oldest expected output once the one after it is
         // more than latency_max cycles old.
         //
         if ( exp_info.size() > 1 && exp_info[1].cyc + latency_max < cycle )
           void'(exp_info.pop_front());

         @( negedge clk );

      end

      first = 0;
      last = 0;

      done_tests = 1;

      $write("Done with inputs.\n");

      // Run for 20 more cycles before stopping the clock.
      //
      begin
         automatic int cyc_limit = cycle + 20;
         wait ( cycle == cyc_limit );
         done = 1;

      end

   end

   /// Checker
   //
   initial begin

      automatic int sum_latency = 0;
      automatic int n_correct = 0;
      automatic int n_err_fid = 0;
      automatic int n_err_lid = 0;
      automatic int n_err_sum = 0;
      automatic int n_err_time = 0;
      automatic int n_err_early = 0;
      // Value of the newest expected entry's cyc when the respective
      // error was last counted. Used to count an error at most once
      // for each such value.
      automatic int err_fid_cyc = 0;
      automatic int err_lid_cyc = 0;
      automatic int err_dp_cyc = 0;

      #0;

      wait( start_check );

      while ( !done ) begin
         Info info, info_dp;
         automatic string err_text[$];
         string tr_entry;
         int fid_idx, lid_idx, dp_idx, age;
         bit err_time, err_fid, err_lid, err_early;

         @( posedge clk_reactive );

         if ( done ) break;

         fid_idx = -1;  lid_idx = -1;  dp_idx = -1;
         info_dp = info_null; age = -1;
         err_early = 0;

         // Look for expected entries matching each module output. The
         // indices recorded are distances from the back of the queue.
         //
         for ( int i=0; i<exp_info.size(); i++ ) begin
            automatic Info info = exp_info[i];
            automatic int bi = exp_info.size() -1 -i;
            automatic int nm = 0;
            if ( info.first_id === out_first_id ) begin fid_idx = bi; nm++; end
            if ( info.last_id === out_last_id ) begin lid_idx = bi; nm++; end
            // The value of dp is not checked when both the expected
            // and the actual IDs are zero, as they are after a reset.
            if ( info.dp === dp ||
                 info.first_id == 0 && out_first_id === 0 &&
                 info.last_id == 0 && out_last_id === 0 )
              begin
                 dp_idx = bi; nm++;
                 if ( !info.correct && info.first_id ) begin
                    // First appearance of this result at the outputs.
                    automatic int latency = cycle - info.cyc;
                    exp_info[i].correct = 1;
                    exp_info[i].latency = latency;
                    sum_latency += latency;
                    err_early = latency < latency_min;
                    n_correct++;
                 end
                 info_dp = exp_info[i];
              end
            if ( nm == 3 ) break;
         end

         info = exp_info[$];

         if ( dp_idx >= 0 ) age = cycle - info_dp.cyc;

         // An ID is in error if it matches no expected entry or if it
         // matches a different entry than dp does.
         //
         err_fid = fid_idx == -1 || dp_idx >= 0 && fid_idx != dp_idx;
         err_lid = lid_idx == -1 || dp_idx >= 0 && lid_idx != dp_idx;

         if ( err_fid && info.cyc != err_fid_cyc ) begin
            n_err_fid++;
            err_fid_cyc = info.cyc;
            if ( n_err_fid < 4 )
              err_text.push_back
                ( $sformatf("Error first ID: %0h != %0h (correct)\n",
                           out_first_id, info_dp.first_id ) );
         end
         if ( err_lid && info.cyc != err_lid_cyc ) begin
            n_err_lid++;
            err_lid_cyc = info.cyc;
            if ( n_err_lid < 4 )
              err_text.push_back
                ( $sformatf("Error  last ID: %0h != %0h (correct)\n",
                           out_last_id, info_dp.last_id ) );
         end

         if ( err_early ) begin
            n_err_early++;
            if ( n_err_early < 4 )
              err_text.push_back
                ( $sformatf
                  ("Error dp timing: latency %0d cyc < %0d cyc (minimum)\n",
                   info_dp.latency, latency_min) );
         end

         if ( dp_idx == -1 && info.cyc != err_dp_cyc ) begin
            n_err_sum++;
            err_dp_cyc = info.cyc;
            if ( n_err_sum < 4 )
              err_text.push_back
              ( $sformatf("Error dp: %h != %h (correct)\n", dp, info.dp ) );
         end

         // Note: The leading 0 disables this check.
         //
         err_time = 0 && fid_idx >= 0 && lid_idx >= 0 && dp_idx >= 0
           && ( fid_idx != lid_idx || lid_idx != dp_idx );

         if ( err_time ) begin
            n_err_time++;
            if ( n_err_time < 4 )
              err_text.push_back
              ( $sformatf("Error timing: %0d %0d %0d (should all be same)\n",
                     fid_idx, lid_idx, dp_idx) );
         end

         // Build a trace line showing this cycle's inputs, the newest
         // expected outputs, and the module outputs.
         //
         begin
            automatic string a_str, b_str;
            for ( int j=0; j<m; j++ ) begin
               a_str = { a_str, j?",":"", $sformatf("%h",a[j]) };
               b_str = { b_str, j?",":"", $sformatf("%h",b[j]) };
            end

         tr_entry =
           { $sformatf("%4d %s%s%s ID %h  A %s  B %s",
                       cycle,
                       reset ? "R" : "_",
                       first ? "F" : "_",
                       last ? "L" : "_",
                       in_id, a_str, b_str ),
             "  ",
             $sformatf("exp: fid %h  lid %h  dp %h",
                       exp_info[$].first_id, exp_info[$].last_id,
                       exp_info[$].dp),
             "  ",
             $sformatf("MOD: %s %h  %s %h  %s %h",
                       err_fid ? "ERR" : err_time ? "err" : "FID",
                       out_first_id,
                       err_lid ? "ERR" : err_time ? "err" : "LID",
                       out_last_id,
                       dp_idx == -1 ? "ER"
                       : age < latency_min ? "EY" : err_time ? "er" : "DP",
                       dp )
             };
         end

         ev_trace.push_back(tr_entry);

         // Keep only the ten most recent trace lines.
         while ( ev_trace.size() > 10 ) tr_entry = ev_trace.pop_front();

         // Print the saved trace during the first 20 cycles and
         // whenever there is a new error message.
         //
         if ( err_text.size() || cycle < 20 ) begin

            while ( ev_trace.size() ) $write("%s\n",ev_trace.pop_front());

            //  foreach ( ev_trace[i] ) $write("%s\n",ev_trace[i]);
            foreach ( err_text[i] ) $write("%s",err_text[i]);

         end

      end

      /// Summary
      //
      begin

         automatic real avg_latency =
           n_correct ? sum_latency / real'(n_correct) : 0;

         $write("Done with %0d tests. %0d dp errs ( %0d correct )\n",
                n_tests_completed, n_err_sum, n_correct);

         //  $write("Done with %0d tests. %0d ID errs,  %0d FID errs, %0d LID errs.\n",
         $write("Done with %0d tests. %0d FID errs, %0d LID errs.\n",
                n_tests_completed, n_err_fid, n_err_lid);
         $write("Done with %0d tests. Correct %0d, avg latency %.1f cyc %0s\n",
                n_tests_completed, n_correct, avg_latency,
                avg_latency < latency_min ?
                $sformatf("Error, too low, %0d cyc minimum.",latency_min) : "Okay");

      end

   end


endmodule


// cadence translate_on