// Two-input multiplexor: drives x with a[s].
//   w - bit width of each data input and of the output.
module mux2
  #( int w = 5 )
   ( output uwire [w-1:0] x,       // selected input
     input  uwire         s,       // 0 selects a[0], 1 selects a[1]
     input  uwire [w-1:0] a[2] );  // the two candidate inputs

   assign x = a[s];

endmodule
// Parameterized n-input multiplexor, built recursively from mux2 instances.
//
// Parameters:
//   n    - number of data inputs (n >= 1).
//   w    - bit width of each data input and of the output.
//   swid - width of the select port. Defaults to $clog2(n). A parent may
//          pass a wider value (muxhi below does this); only the low
//          $clog2(n) bits of s then take part in the selection.
// Ports:
//   x - the selected input, a[s]. Not meaningful for s >= n.
//   s - select.
//   a - the n data inputs.
module mux
  #( int n = 4,
     int w = 5,
     int swid = $clog2(n) )
   ( output uwire [w-1:0] x,
     input uwire [swid-1:0] s,
     input uwire [w-1:0] a[n] );
   if ( n == 1 ) begin
      assign x = a[0];
   end else if ( n == 2 ) begin
      assign x = a[s];
   end else begin
      // Split into a power-of-two low part and the remainder.
      // Since 2**srec < n <= 2**(srec+1), we have 1 <= nhi <= nlo.
      localparam int srec = $clog2(n) - 1;  // select bits used by each part
      localparam int nlo = 1 << srec;       // a[0 .. nlo-1]
      localparam int nhi = n - nlo;         // a[nlo .. n-1]
      uwire [w-1:0] x2[2];
      mux #(nlo,w     ) muxlo( x2[0], s[srec-1:0], a[0:nlo-1] );
      mux #(nhi,w,srec) muxhi( x2[1], s[srec-1:0], a[nlo:n-1] );
      // Bit srec chooses between the two parts. Do not use s[swid-1]:
      // when swid > $clog2(n) (e.g. muxhi of an 11- or 12-input mux),
      // that bit is always 0 for valid selects.
      mux2 #(w)         mux21( x,     s[srec],     x2 );
   end
endmodule
// Parameterized n-input multiplexor (alternate recursive construction).
// Each sub-mux receives a select port exactly as wide as it needs.
// Leaves with one input are replaced by plain wires.
//
// Parameters:
//   n    - number of data inputs (n >= 1).
//   w    - bit width of each data input and of the output.
//   swid - width of the select port. Defaults to $clog2(n); if wider,
//          only the low $clog2(n) bits take part in the selection.
// Ports:
//   x - the selected input, a[s]. Not meaningful for s >= n.
//   s - select.
//   a - the n data inputs.
module muxs
  #( int n = 4,
     int w = 5,
     int swid = $clog2(n) )
   ( output uwire [w-1:0] x,
     input uwire [swid-1:0] s,
     input uwire [w-1:0] a[n] );
   if ( n == 1 ) begin
      // Nothing to select. The split below would need 1 << -1 here.
      assign x = a[0];
   end else begin
      localparam int slo = $clog2(n) - 1;  // select bits for the low part
      localparam int nlo = 1 << slo;       // a[0 .. nlo-1]
      localparam int nhi = n - nlo;        // a[nlo .. n-1]
      localparam int shi = $clog2(nhi);    // select bits for the high part
      uwire [w-1:0] x2[2];
      if ( nlo == 1 )
        assign x2[0] = a[0];
      else
        muxs #(nlo,w) muxlo( x2[0], s[slo-1:0], a[0:nlo-1] );
      if ( nhi == 1 )
        assign x2[1] = a[nlo];
      else
        muxs #(nhi,w) muxhi( x2[1], s[shi-1:0], a[nlo:n-1] );
      // Bit slo chooses between the parts. This equals s[swid-1] only
      // when swid == $clog2(n).
      mux2 #(w) mux21 (x, s[slo], x2);
   end
endmodule
cadence
// Self-checking testbench for mux and muxs.
// For every select value it applies fresh random data, then compares each
// instance's output with a direct lookup into the input array.
module testbench();
   localparam int w = 10;          // data width of every instance under test
   localparam int n_in_max = 16;   // size of the shared input array a
   localparam int n_size = 7;      // number of entries in nin
   localparam int n_mod = 2;       // number of implementations (mux, muxs)
   localparam int n_mut = n_size * n_mod;
   // Select width able to address all n_in_max inputs.
   localparam int s_wid = $clog2(n_in_max);
   uwire [w-1:0] x[n_mut];
   logic [s_wid-1:0] s;
   logic [w-1:0] a[n_in_max];
   // Input counts under test. 11 and 12 produce a high part whose select
   // port is wider than $clog2 of its input count.
   localparam    int nin[n_size] = { 2,4,8,7,5,11,12 };
   string descr[] = { "mux", "muxs" };
   // x[i] is driven by mux and x[n_size+i] by muxs, both with nin[i] inputs.
   for ( genvar i=0; i<n_size; i++ ) begin
      localparam int n = nin[i];
      localparam int swid = $clog2(n);
      mux  #(n,w) mi( x[i], s[swid-1:0], a[0:n-1] );
      muxs #(n,w) mni( x[n_size+i], s[swid-1:0], a[0:n-1] );
   end
   initial begin
      automatic int n_test = 0;   // comparisons performed
      automatic int n_err = 0;
      for ( int i=0; i < n_in_max; i++ ) begin
         #1;
         s = i;
         // The low 4 bits of each value are its index, to make error output readable.
         for ( int j=0; j<n_in_max; j++ ) a[j] = {$random,j[3:0]};
         #1;
         for ( int m=0; m<n_mut; m++ ) begin
            automatic int ni = m % n_size;
            automatic int mod_idx = m / n_size;
            automatic int n_in = nin[ni];
            automatic int sm = i;
            // Select values >= n_in are out of range for this instance.
            if ( i >= n_in ) continue;
            n_test++;
            if ( x[m] !== a[sm] ) begin
               n_err++;
               $write
                 ("Error in %0d-input %s for s=%0d, 0x%0x != 0x%0x (correct)\n",
                  n_in, descr[mod_idx], sm, x[m], a[sm]);
            end
         end
      end
      $write("Done with %0d tests, %0d errors found.\n",n_test,n_err);
   end
endmodule
cadence