`default_nettype none
// nnOxI: a layer of `no` neurons sharing the input vector ai. Neuron o
// computes the unsigned dot product  ao[o] = sum_i ai[i] * wht[o][i].
// Parameters:
//   no, ni      number of outputs / inputs
//   wo, wi, ww  bit widths of outputs, inputs, weights
//   tr          0: linear multiply-accumulate chain; nonzero: adder tree
//               (both implemented by nn1xI)
//   sat         0: result wraps modulo 2**wo
//               1: every multiplier/adder stage clamps to all-ones on overflow
//               >=2: each neuron is computed exactly in wr bits and clamped
//                   once at the output
module nnOxI
#( int no = 4, ni = 2, wo = 10, wi = 4, ww = 5, tr = 0, sat = 0 )
( output uwire [wo-1:0] ao[no],
  input uwire [wi-1:0] ai[ni],
  input uwire [ww-1:0] wht[no][ni] );

  // Largest possible dot product. Evaluated in 64-bit unsigned arithmetic
  // so that wide wi/ww/ni cannot overflow a 32-bit int.
  localparam longint unsigned vmax =
    ( ( 64'd1 << wi ) - 1 ) * ( ( 64'd1 << ww ) - 1 ) * ni;

  // Number of bits needed to represent vmax exactly. The +1 matters when
  // vmax is a power of two (e.g. wi = ww = 1, ni = 2 -> vmax = 2 needs 2 bits,
  // but $clog2(2) = 1); without it saturation could be wrongly disabled.
  localparam int wr = $clog2( vmax + 1 );

  if ( sat < 2 || wr <= wo ) begin
    // Any saturation happens inside each neuron's arithmetic stages. It is
    // switched off entirely when wo bits can never overflow.
    localparam int satp = wr <= wo ? 0 : sat;
    for ( genvar i = 0; i < no; i++ )
      nn1xI #(wo,wi,ww,ni,tr,satp) row( ao[i], ai, wht[i] );
  end else begin
    // sat >= 2 and overflow is possible: compute each neuron exactly in wr
    // bits, then clamp to all-ones if any bit above wo-1 is set.
    for ( genvar i = 0; i < no; i++ ) begin
      uwire [wr-1:0] ar;
      nn1xI #(wr,wi,ww,ni,tr,0) row( ar, ai, wht[i] );
      assign ao[i] = ar[wr-1:wo] ? ~wo'(0) : ar[wo-1:0];
    end
  end
endmodule
// nn1xI: a single neuron -- the wo-bit unsigned dot product of ni inputs
// (wi bits each) with ni weights (ww bits each).
//   tr == 0 : linear chain of nnMADD stages, acc[k] = acc[k-1] + wht[k]*ai[k]
//   tr != 0 : recursive balanced tree; the inputs are split into a lower and
//             an upper half, each half is reduced, and the halves are added.
//   sat     : forwarded to every nnMult / nnAdd; nonzero makes them clamp to
//             all-ones instead of wrapping.
module nn1xI
#( int wo = 10, wi = 4, ww = 5, ni = 2, tr = 0, sat = 0 )
( output uwire [wo-1:0] ao,
  input uwire [wi-1:0] ai[ni],
  input uwire [ww-1:0] wht[ni] );

  if ( !tr ) begin
    // Linear accumulation; acc[-1] is the zero seed of the chain.
    uwire [wo-1:0] acc[ni-1:-1];
    assign acc[-1] = 0;
    for ( genvar k = 0; k < ni; k++ )
      nnMADD #(ww,wi,wo,sat) madd( acc[k], wht[k], ai[k], acc[k-1] );
    assign ao = acc[ni-1];
  end else if ( ni == 1 ) begin
    // Tree leaf: a single product.
    nnMult #(wi,ww,wo,sat) mult( ao, ai[0], wht[0] );
  end else begin
    // Tree node: reduce both halves recursively, then add them.
    localparam int nlo = ni / 2;
    localparam int nhi = ni - nlo;
    uwire [wo-1:0] sum_lo, sum_hi;
    nn1xI #(wo,wi,ww,nlo,1,sat) nnlo( sum_lo, ai[0:nlo-1], wht[0:nlo-1] );
    nn1xI #(wo,wi,ww,nhi,1,sat) nnhi( sum_hi, ai[nlo:ni-1], wht[nlo:ni-1] );
    nnAdd #(wo,sat) add( ao, sum_lo, sum_hi );
  end
endmodule
// nnMADD: one multiply-accumulate stage, so = si + a*b, evaluated in ws bits.
// sat is forwarded to both the multiplier and the adder, so with sat != 0
// each of them clamps to all-ones instead of wrapping.
module nnMADD
#( int wa = 10, wb = 5, ws = wa + wb, sat = 0 )
( output uwire [ws-1:0] so,
  input uwire [wa-1:0] a, input uwire [wb-1:0] b, input uwire [ws-1:0] si );

  // Product of a and b, resized (and optionally clamped) to ws bits.
  uwire [ws-1:0] prod;

  nnMult #( .wa(wa), .wb(wb), .wp(ws), .sat(sat) )
    mu( .p(prod), .a(a), .b(b) );

  nnAdd #( .w(ws), .sat(sat) )
    ad( .so(so), .a(si), .b(prod) );
endmodule
// nnAdd: w-bit unsigned adder, so = a + b.
// With sat == 0 the sum wraps modulo 2**w. With sat != 0 a carry out of the
// top bit makes the output all-ones instead.
module nnAdd
#( int w = 5, sat = 0 )
( output uwire [w-1:0] so,
  input uwire [w-1:0] a, b );

  localparam logic [w-1:0] smax = ~w'(0);

  uwire          carry;
  uwire [w-1:0]  sum;

  // The LHS is w+1 bits wide, so the addition keeps its carry.
  assign {carry, sum} = a + b;
  assign so = ( sat != 0 && carry ) ? smax : sum;
endmodule
// nnMult: unsigned multiplier producing a wp-bit result, p = a * b.
// When wp is narrower than the exact product width wa+wb, the high bits are
// dropped (sat == 0) or the output is clamped to all-ones (sat != 0).
// When wp >= wa+wb the product is zero-extended and can never overflow.
module nnMult
#( int wa = 5, wb = 6, wp = wa + wb, sat = 0 )
( output uwire [wp-1:0] p,
  input uwire [wa-1:0] a, input uwire [wb-1:0] b );

  localparam int            wf   = wa + wb;    // exact product width
  localparam logic [wp-1:0] pmax = ~wp'(0);

  // The target is wf bits wide, so the multiplication is not truncated.
  uwire [wf-1:0] prod = a * b;

  // prod > pmax exactly when some bit of prod above wp-1 is set.
  assign p = ( sat != 0 && prod > pmax ) ? pmax : wp'(prod);
endmodule
588588753631594594
783779951916800771
// nnOxIbe: behavioral reference model for nnOxI (same ports and widths).
// Output o is the dot product of ai with wht[o], reduced to wo bits either by
// wrapping (sat == 0) or by clamping to all-ones (sat != 0).
module nnOxIbe
#( int no = 4, ni = 4, wo = 10, wi = 4, ww = 5, sat = 0 )
( output logic [wo-1:0] ao[no],
  input uwire [wi-1:0] ai[ni], input uwire [ww-1:0] wht[no][ni] );

  localparam logic [wo-1:0] smax = ~wo'(0);

  // Accumulator width that always holds the exact sum: each product is
  // < 2**(wi+ww), so a sum of ni products is < 2**(wi+ww+$clog2(ni)).
  // A fixed 32-bit accumulator would silently truncate wide configurations.
  localparam int wacc = wi + ww + $clog2(ni);

  always_comb
    for ( int o = 0; o < no; o++ ) begin
      automatic logic [wacc-1:0] acc = 0;
      // The expression width is wacc (>= wi+ww), so each product is exact.
      for ( int i = 0; i < ni; i++ ) acc += ai[i] * wht[o][i];
      ao[o] = sat && acc > smax ? smax : acc;
    end
endmodule
cadence
// Config: one testbench configuration record (ni, no, wa, ww).
// NOTE(review): not referenced anywhere in the visible code -- the testbench
// keeps its configurations as plain int arrays instead. Confirm whether this
// typedef is still needed.
typedef struct { int ni, no, wa, ww; } Config;
// testbench: top-level driver. It instantiates testbench_x once per
// (configuration, saturation mode) pair and runs the instances one after
// another through a start/done chain. At the end of simulation it prints the
// accumulated summary and per-category error counts. The testbench_x
// instances write test_summary, t_errs and t_errs_cat through hierarchical
// references.
module testbench;
  localparam int nc = 2;
  // Each row is { ni, no, wo, ww }, the first four testbench_x parameters.
  localparam int configs[nc][4] = '{ '{ 4,4, 16,9 }, '{ 5,3, 15,8 } };
  // Optional consistency check, compiled only when grrr is defined.
`ifdef grrr
  initial if ( nc != configs.size() )
    $fatal(1, "Constant nc should be %0d.\n", configs.size() );
`endif
  int t_errs; int t_errs_cat[string]; string test_summary;
  initial begin
    t_errs = 0;
    test_summary = "";
  end
  final begin
    automatic int mlen = 0;
    // Length of the longest category name, used to align the counts.
    foreach ( t_errs_cat[key] ) if ( mlen < key.len() ) mlen = key.len();
    $write("\n** Summary of Results **\n%s", test_summary);
    foreach ( t_errs_cat[key] )
      $write("%0s%0s %5d errors.\n",
             key, { mlen - key.len() {" "}}, t_errs_cat[key]);
    $write("Total number of errors: %0d\n",t_errs);
  end
  localparam int maxsat = 3;  // saturation modes 0 .. maxsat-1
  // Start/done chain: d[-1] is tied high to start instance 0. Instance idx
  // waits on d[idx-1] and drives d[idx]. There are exactly maxsat*nc
  // instances, so the highest used index is maxsat*nc-1 (no undriven spare).
  uwire d[maxsat*nc-1:-1]; assign d[-1] = 1;
  for ( genvar i=0; i<nc; i++ ) begin
    localparam int c[4] = configs[i];
    for ( genvar sat=0; sat<maxsat; sat++ ) begin
      localparam int idx = maxsat*i + sat;
      testbench_x #(c[0],c[1],c[2],c[3],sat)
        t2( .done(d[idx]), .start(d[idx-1]) );
    end
  end
endmodule
// testbench_x: tests one (ni, no, wo, ww, sat) configuration. It compares the
// linear (tr=0) and tree (tr=1) nnOxI implementations against the behavioral
// model nnOxIbe, using $random stimulus. Results are added to the summary
// variables in module testbench. done is raised when all test sets finish.
module testbench_x
#( int ni = 4, no = 4, wo = 16, ww = 8, sat = 0 )
( output logic done, input uwire start );
// Inputs are one bit wider than weights.
localparam int wi = ww + 1;
// Modules under test: index 0 is the reference, 1 and 2 are checked against it.
localparam int nmut = 3;
localparam int ntests = 10;
logic [wi-1:0] ai[ni];
uwire [wo-1:0] ao[nmut][no];
logic [ww-1:0] wht[no][ni];
string mname[] = { "Behav", "Linear", "Tree" };
// A test set: label, plus how many outputs and inputs receive nonzero
// stimulus. The remaining weights and inputs are held at 0.
typedef struct { string name; int no, ni; } Test_Set;
Test_Set ts[] = '{ '{ "n12", 1, 2 }, '{ "n1*", 1, ni }, '{ "n**", no, ni } };
nnOxIbe #(no,ni,wo,wi,ww,sat) nn0(ao[0],ai,wht);
nnOxI #(no,ni,wo,wi,ww,0,sat) nn1(ao[1],ai,wht);
nnOxI #(no,ni,wo,wi,ww,1,sat) nn2(ao[2],ai,wht);
initial begin
automatic string config_label =
$sformatf("no=%0d, ni=%0d, wo=%0d, wi=%0d, ww=%0d, sat=%0d",
no, ni, wo, wi, ww, sat );
// Block until the previous instance in the chain has finished.
wait( start );
$write("\n** Starting tests for %s\n", config_label);
testbench.test_summary =
{ testbench.test_summary, $sformatf("Results from %s\n",config_label) };
// Start at 1: ao[0] (Behav) is the reference, not a module under test.
for ( int mut = 1; mut < nmut; mut++ ) begin
$write("Testing module %s\n", mname[mut]);
foreach ( ts[ti] ) begin
automatic Test_Set tinfo = ts[ti];
automatic int n_err = 0;
$write("\n** Starting test set %s (%0d outputs, %0d inputs) for %s **\n",
tinfo.name, tinfo.no, tinfo.ni, mname[mut] );
for ( int tnum=0; tnum < ntests; tnum++ ) begin
// Randomize active weights and inputs; zero the rest. The order of the
// $random calls fixes the stimulus sequence.
for ( int io=0; io<no; io++ )
for ( int ii=0; ii<ni; ii++ )
wht[io][ii] = io < tinfo.no && ii < tinfo.ni ? $random : 0;
for ( int ii=0; ii<ni; ii++ )
ai[ii] = ii < tinfo.ni ? $random : 0;
// Let the combinational outputs settle before comparing.
#1;
// Only the active outputs are compared, using 4-state inequality.
for ( int io=0; io<tinfo.no; io++ ) begin
if ( ao[0][io] !== ao[mut][io] ) begin
n_err++;
// Print only the first three mismatches of each test set.
if ( n_err < 4 )
$write
("Error test # %0d, output %0d: %0d != %0d (correct)\n",
tnum, io, ao[mut][io], ao[0][io] );
end
end
end
// Report this test set and add its errors to the global counters
// (total, per sat mode, per module+sat mode, per module).
begin
automatic string sat_key =
$sformatf("Sat %0d",sat);
automatic string long_key =
$sformatf("%s %s",mname[mut],sat_key);
automatic string msg =
$sformatf("%0d %s tests on %s: %0d errors found.",
ntests, tinfo.name, mname[mut], n_err);
$write("Done with %s\n",msg);
testbench.test_summary =
{ testbench.test_summary, $sformatf("Results of %s\n", msg) };
testbench.t_errs += n_err;
testbench.t_errs_cat[{"All ",sat_key}] += n_err;
testbench.t_errs_cat[long_key] += n_err;
testbench.t_errs_cat[mname[mut]] += n_err;
end
end
end
// Release the next instance in the start/done chain.
done = 1;
end
endmodule
cadence