////////////////////////////////////////////////////////////////////////////////
//
/// LSU EE 4755 Fall 2020 Homework 3 -- SOLUTION
//

/// Assignment  https://www.ece.lsu.edu/koppel/v/2020/hw03.pdf

`default_nettype none

//////////////////////////////////////////////////////////////////////////////
///  Problem 1
//
/// Modify nnOxI and nn1xI so they compute the same output as nnOxIbe.
///
//
//     [✔] Make sure that the testbench does not report errors.
//     [✔] Module must be synthesizable. Use command: genus -files syn.tcl
//
//     [✔] Don't assume any particular parameter value.
//     [✔] Don't make ports wider than necessary.
//
//     [✔] Code must be written clearly.

/// Module Connection Names
//
//  no:  Number of elements in Output array.
//  ni:  Number of elements in Input array.
//  wo:  Width (number of bits) in each element of output array.
//  wi:  Width (number of bits) in each element of input array.
//  ww:  Width (number of bits) in each element of weight array.
//  sat: If 0, on overflow use low wo bits of result.
//       If 1, on overflow set result to maximum possible value,
//             do this in each arithmetic unit (nnMult, nnAdd) where
//             overflow occurs.
//       If 2, on overflow set result to maximum possible value,
//             do this in nnOxI (not in nn1xI nor in nnAdd, nnMult, etc.)
//  tr:  If 0, generate a linear connection of nnMADD modules in nn1xI.
//       If 1, generate a tree connection of arithmetic units by
//             recursively defining nn1xI.
//  ao:  Activation (neuron) Output array.
//  ai:  Activation (neuron) Input array.
//  wht: Weights.

module nnOxI
  #( int no = 4, ni = 2, wo = 10, wi = 4, ww = 5, tr = 0, sat = 0 )
   ( output uwire [wo-1:0] ao[no],
     input uwire [wi-1:0] ai[ni],
     input uwire [ww-1:0] wht[no][ni] );

   //   [✔] Instantiate nn1xI modules here.
   //   [✔] If sat == 2 replace overflow values with max possible value.
   //   [✔] Don't forget to set appropriate parameter values.

   /// SOLUTION
   //
   //  Each output ao[i] is the dot product of ai with row wht[i], computed
   //  by one nn1xI.  For sat == 2 (when overflow is possible) each row is
   //  computed at full width without saturation, and the clamp to the
   //  wo-bit maximum is applied here, once per output.

   // Largest possible value of an ao element: every input and weight at
   // its maximum.  It is computed at wt bits, which is wide enough that
   // the product cannot overflow for any parameter values.  A 32-bit int
   // would already overflow at wi = ww = 16.
   //
   localparam int wt = wi + ww + 32;
   localparam logic [wt-1:0] one = 1;
   localparam logic [wt-1:0] ao_max =
     ( ( one << wi ) - 1 ) * ( ( one << ww ) - 1 ) * ni;

   // Number of bits needed to represent ao_max.  The + 1 matters:
   // $clog2(v) is one bit short when v is an exact power of two
   // (e.g. wi = ww = 1, ni = 2 gives ao_max = 2, which needs 2 bits).
   //
   localparam int wr = $clog2( ao_max + 1 );

   if ( sat < 2 || wr <= wo ) begin

      // Either saturation (if requested) is done inside nn1xI, or overflow
      // is impossible.  In the latter case, turn off the saturation check
      // so that no saturation hardware is generated.
      //
      localparam int satp = wr <= wo ? 0 : sat;

      for ( genvar i = 0;  i < no;  i++ )
        nn1xI #( .wo(wo), .wi(wi), .ww(ww), .ni(ni), .tr(tr), .sat(satp) )
          row( ao[i], ai, wht[i] );

   end else begin

      // sat == 2 and overflow is possible: compute each row at the full
      // width wr without saturation, then clamp.
      //
      for ( genvar i = 0;  i < no;  i++ ) begin

         uwire [wr-1:0] ar;
         nn1xI #( .wo(wr), .wi(wi), .ww(ww), .ni(ni), .tr(tr), .sat(0) )
           row( ar, ai, wht[i] );

         // Any nonzero bit above bit wo-1 means the sum does not fit in
         // wo bits: substitute the maximum value.
         //
         assign ao[i] = ar[wr-1:wo] ? ~wo'(0) : ar[wo-1:0];

      end

   end

endmodule

module nn1xI
  #( int wo = 10, wi = 4, ww = 5, ni = 2, tr = 0, sat = 0 )
   ( output uwire [wo-1:0] ao,
     input uwire [wi-1:0] ai[ni],
     input uwire [ww-1:0] wht[ni] );

   //   [✔] If tr == 0 use generate loop to instantiate nnMADD modules.
   //   [✔] If tr == 1 use recursion to describe a tree structure ..
   //   [✔] .. and use nnMADD, nnMult, and nnAdd where appropriate.
   //   [✔] Don't forget to set appropriate parameter values.

   /// SOLUTION
   //
   //  ao = sum over i of ai[i] * wht[i], computed at wo bits.  The sat
   //  parameter is passed to every nnMult / nnAdd / nnMADD, so with
   //  sat == 0 the result keeps the low wo bits, and with sat != 0 each
   //  unit saturates where its own result overflows.

   if ( tr ) begin

      // Tree structure: split the inputs into two halves, compute each
      // half's dot product recursively, then add the two partial sums.

      if ( ni == 1 ) begin

         // Leaf: a single product.
         nnMult #(wi,ww,wo,sat) mult( ao, ai[0], wht[0] );

      end else begin

         localparam int nlo = ni / 2;
         localparam int nhi = ni - nlo;
         uwire [wo-1:0] aolo, aohi;
         nn1xI #(wo,wi,ww,nlo,1,sat) nnlo( aolo, ai[0:nlo-1], wht[0:nlo-1] );
         nn1xI #(wo,wi,ww,nhi,1,sat) nnhi( aohi, ai[nlo:ni-1], wht[nlo:ni-1] );

         // Combine the two halves.  This adder is what drives ao.
         nnAdd #(wo,sat) add( ao, aolo, aohi );

      end

   end else begin

      // Linear chain: s[i] = s[i-1] + wht[i] * ai[i], with s[-1] = 0.

      uwire [wo-1:0] s[ni-1:-1];
      assign s[-1] = 0;
      assign ao = s[ni-1];

      for ( genvar i = 0;  i < ni;  i++ )
        nnMADD #(ww,wi,wo,sat) madd( s[i], wht[i], ai[i], s[i-1] );

   end

endmodule

/// nnMADD: multiply-add, so = a * b + si, computed at ws bits.
//
//  The product is formed by nnMult and the sum by nnAdd, both with the
//  same sat setting.  With sat != 0 each stage saturates on its own
//  overflow; with sat == 0 each keeps its low ws bits.
//
module nnMADD
  #( int wa = 10, wb = 5, ws = wa + wb, sat = 0 )
   ( output uwire [ws-1:0] so,
     input uwire [wa-1:0] a, input uwire [wb-1:0] b, input uwire [ws-1:0] si);

   /// DO NOT MODIFY THIS MODULE.

   uwire [ws-1:0] p;
   nnMult #(wa,wb,ws,sat) mu(p, a, b);
   nnAdd #(ws,sat) ad(so, p, si);

endmodule

/// nnAdd: w-bit adder of two unsigned operands.
//
//  sat == 0:  so is the low w bits of a + b.
//  sat != 0:  if a + b does not fit in w bits, so is all ones (the
//             largest w-bit value).
//
module nnAdd
  #( int w = 5, sat = 0 )
   ( output uwire [w-1:0] so,
     input uwire [w-1:0] a, b );

   /// DO NOT MODIFY THIS MODULE.

   // One extra bit holds the carry out; it is set exactly on overflow.
   uwire [w:0] s = a + b;
   localparam logic [w-1:0] smax = ~w'(0);
   assign so = sat && s[w] ? smax : s[w-1:0];

endmodule

module nnMult
  #( int wa = 5, wb = 6, wp = wa + wb, sat = 0 )
   ( output uwire [wp-1:0] p,
     input uwire [wa-1:0] a, input uwire [wb-1:0] b );

   /// DO NOT MODIFY THIS MODULE.

   // Unsigned multiplier with a wp-bit result.  When sat is nonzero and
   // the true product does not fit in wp bits, p is forced to all ones.

   // Internal product width: one more than the larger of wp and wa+wb,
   // so the multiplication itself never drops bits.
   localparam int wfull = 1 + ( wa + wb < wp ? wp : wa + wb );

   // Overflow is only possible when the output is narrower than wa+wb.
   localparam bit may_overflow = wp < wa + wb;

   localparam logic [wp-1:0] p_sat = ~wp'(0);

   // a * b is context-determined, so it is evaluated at wfull bits here.
   uwire [wfull-1:0] prod = a * b;

   // Any nonzero bit above bit wp-1 means the product is too large.
   uwire [wfull-1:wp] prod_hi = prod[wfull-1:wp];

   assign p = sat && may_overflow && prod_hi ? p_sat : prod[wp-1:0];

endmodule

// Synthesizing at effort level "medium"

// Module Name                               Area   Delay   Delay
//                                                 Actual  Target

// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr0    588304   6.972  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr1    588304   6.972  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr0    753136  63.864  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr1    631611   7.043  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr0    594261   7.450  90.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr1    594261   7.450  90.000 ns

// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr0    783094   4.828   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat0_tr1    779386   4.852   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr0    951332   9.503   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat1_tr1    916787   5.136   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr0    800554   4.980   1.000 ns
// nnOxI_no2_ni16_wo12_wi5_ww4_sat2_tr1    771789   4.981   1.000 ns

// Normal exit.

module nnOxIbe
  #( int no = 4, ni = 4, wo = 10, wi = 4, ww = 5, sat = 0 )
   ( output logic [wo-1:0] ao[no],
     input uwire  [wi-1:0] ai[ni], input uwire  [ww-1:0] wht[no][ni] );

   /// DO NOT MODIFY THIS MODULE
   //
   //  Behavioral reference: study it to see what nnOxI should output.
   //
   //  ao[o] is the sum over i of ai[i] * wht[o][i], accumulated in a
   //  32-bit unsigned int.  With sat == 0 the low wo bits are kept.  With
   //  any nonzero sat, a sum above the wo-bit maximum becomes that maximum.

   // All ones: the largest value an element of ao can hold.
   localparam logic [wo-1:0] smax = ~wo'(0);

   always_comb
     foreach ( ao[oi] ) begin
        automatic int unsigned total = 0;
        foreach ( ai[ii] ) total = total + ai[ii] * wht[oi][ii];
        if ( sat != 0 && total > smax ) ao[oi] = smax;
        else                            ao[oi] = total;
     end

endmodule

//////////////////////////////////////////////////////////////////////////////
/// Testbench Code

// Test configuration record.  NOTE(review): not referenced anywhere in the
// visible source; the testbench uses the int array `configs` instead.
typedef struct {
   int ni;
   int no;
   int wa;
   int ww;
} Config;

module testbench;

   // Configurations to test, each { ni, no, wo, ww }: the first four
   // parameters of testbench_x, in order.
   //
   localparam int nc = 2;
   localparam int configs[nc][4] = '{ '{ 4,4, 16,9 }, '{ 5,3, 15,8 } };
`ifdef grrr
   initial if ( nc != $size(configs) )
     $fatal(1, "Constant nc should be %0d.\n", $size(configs) );
`endif

   // Result accumulators, updated by the testbench_x instances through
   // hierarchical references.  They are initialized in their declarations
   // rather than in an initial block: declaration initializers take effect
   // before any initial procedure starts.  The first testbench_x may append
   // to test_summary at time 0, and an initial block could race with it
   // and erase that text.
   //
   int t_errs = 0;          // Total number of errors.
   int t_errs_cat[string];  // Total errors by configuration category.
   string test_summary = "";

   final begin

      automatic int mlen = 0;
      foreach ( t_errs_cat[key] ) if ( mlen < key.len() ) mlen = key.len();

      $write("\n** Summary of Results **\n%s", test_summary);
      foreach ( t_errs_cat[key] )
        $write("%0s%0s %5d errors.\n",
               key, { mlen - key.len() {" "}}, t_errs_cat[key]);
      $write("Total number of errors: %0d\n",t_errs);

   end

   localparam int maxsat = 3;

   // Start / done chain.  Instance idx starts when instance idx-1 sets its
   // done output, so the instances run one after another.  d[-1] starts
   // the first one; the last instance drives d[maxsat*nc-1].
   //
   uwire d[maxsat*nc-1:-1];
   assign d[-1] = 1;  // Initialize first at true.

   // Instantiate a testbench at each size and each sat setting.
   //
   for ( genvar i=0; i<nc; i++ ) begin
      localparam int c[4] = configs[i];
      for ( genvar sat=0; sat<maxsat; sat++ ) begin
         localparam int idx = maxsat*i + sat;
         testbench_x #(c[0],c[1],c[2],c[3],sat)
           t2( .done(d[idx]), .start(d[idx-1]) );
      end
   end

endmodule

module testbench_x
  #( int ni = 4, no = 4, wo = 16, ww = 8, sat = 0 )
   ( output logic done, input uwire start );

   // Test one (ni, no, wo, ww, sat) configuration.  The linear and tree
   // variants of nnOxI are compared against the behavioral nnOxIbe.  The
   // test waits for start, and sets done when finished.  Error counts and
   // a summary are accumulated in the enclosing testbench module.

   localparam int wi = ww + 1;

   localparam int nmut = 3;      // Number of modules: behavioral + 2 under test.

   localparam int ntests = 10;   // Random input sets per test set.

   logic [wi-1:0] ai[ni];
   uwire [wo-1:0] ao[nmut][no];
   logic [ww-1:0] wht[no][ni];

   string mname[] = { "Behav", "Linear", "Tree" };

   // A test set restricts nonzero stimulus to the first no outputs and
   // ni inputs.
   typedef struct { string name; int no, ni; } Test_Set;
   Test_Set ts[] = '{ '{ "n12", 1, 2 }, '{ "n1*", 1, ni }, '{ "n**", no, ni } };

   nnOxIbe #(no,ni,wo,wi,ww,sat) nn0(ao[0],ai,wht);
   nnOxI #(no,ni,wo,wi,ww,0,sat) nn1(ao[1],ai,wht);
   nnOxI #(no,ni,wo,wi,ww,1,sat) nn2(ao[2],ai,wht);

   initial begin

      automatic string config_label =
        $sformatf("no=%0d, ni=%0d, wo=%0d, wi=%0d, ww=%0d, sat=%0d",
                  no, ni, wo, wi, ww, sat );

      wait( start );

      $write("\n** Starting tests for %s\n", config_label);
      testbench.test_summary =
        { testbench.test_summary, $sformatf("Results from %s\n",config_label) };

      // Module 0 is the reference; test the others against it.
      for ( int mut = 1;  mut < nmut;  mut++ ) begin

         $write("Testing module %s\n", mname[mut]);

         foreach ( ts[ti] ) begin

            automatic Test_Set tinfo = ts[ti];
            automatic int n_err = 0;

            $write("\n** Starting test set %s  (%0d outputs, %0d inputs) for %s **\n",
                   tinfo.name, tinfo.no, tinfo.ni, mname[mut] );

            for ( int tnum=0; tnum < ntests;  tnum++ ) begin

               for ( int io=0; io<no; io++ )
                 for ( int ii=0; ii<ni; ii++ )
                   wht[io][ii] = io < tinfo.no && ii < tinfo.ni ? $random : 0;

               for ( int ii=0; ii<ni; ii++ )
                 ai[ii] = ii < tinfo.ni ? $random : 0;

               #1;

               for ( int io=0; io<tinfo.no; io++ ) begin

                  if ( ao[0][io] !== ao[mut][io] ) begin
                     n_err++;
                     if ( n_err < 4 )
                       $write
                         ("Error test # %0d, output %0d: %0d != %0d (correct)\n",
                          tnum, io, ao[mut][io], ao[0][io] );
                  end
               end

            end

            begin
               automatic string sat_key =
                 $sformatf("Sat %0d",sat);
               automatic string long_key =
                 $sformatf("%s %s",mname[mut],sat_key);
               automatic string msg =
                 $sformatf("%0d %s tests on %s: %0d errors found.",
                           ntests, tinfo.name, mname[mut], n_err);
               $write("Done with %s\n",msg);
               testbench.test_summary =
                 { testbench.test_summary, $sformatf("Results of %s\n", msg) };
               testbench.t_errs += n_err;
               testbench.t_errs_cat[{"All ",sat_key}] += n_err;
               testbench.t_errs_cat[long_key] += n_err;
               testbench.t_errs_cat[mname[mut]] += n_err;
            end

         end

      end

      done = 1;

   end

endmodule