/// LSU EE 4702-1 (Fall 2014), GPU Programming
//
 /// Homework 4
 //
 ///  SOLUTION

 /// Updated for Vulkan Fall 2021

 /// Instructions
 //
 //  Read the assignment: http://www.ece.lsu.edu/koppel/gpup/2014/hw04.pdf


// Specify version of OpenGL Shading Language.
//
#version 460

#extension GL_GOOGLE_include_directive : enable
#include <light.h>
#include <transform.h>
#include "shdr-common.h"

// Uniform block shared by all stages. Its contents (Shdr_Uni_Common) are
// declared in shdr-common.h, which is not visible here. This file reads
// debug_bool, debug_float, opt_segments, chain_length, color_edge,
// color_front and color_back from it.
//
layout ( binding = BIND_UNI_COMMON ) uniform UC
{
  Shdr_Uni_Common uc;
};

// Use this variable to debug your code. Press 'd' to toggle
// debug_bool.x and 'D' to toggle debug_bool.y (between true and
// false).
//
// It is a per-invocation copy of uc.debug_bool converted to bool.
// NOTE(review): the underlying type of uc.debug_bool is declared in
// shdr-common.h; confirm that a nonzero value means "true".
//
bvec2 debug_bool = bvec2( bool(uc.debug_bool.x), bool(uc.debug_bool.y) );

// Use this to debug your code. Press TAB until "debug_float"
// appears, then press +/- to adjust its value.
//
// This is a per-invocation copy of uc.debug_float.
//
float debug_float = uc.debug_float;

// Array of ball positions. The shaders in this file read only the .xyz
// components (vs_main reads entries bidx-1 and bidx).
//
layout ( binding = BIND_BALLS_POS ) buffer Balls_Pos { vec4 balls_pos[]; };


#ifdef _VERTEX_SHADER_

// Redefine this vertex shader input to be an integer vector.
//
// Components as used by vs_main:
//   .x  ball index bidx (the spiral runs from ball bidx-1 to ball bidx)
//   .y  segment index ti along the spiral
//   .z  1 for a vertex on the inner edge of the ribbon, else outer edge
//   .w  unused
//
layout ( location = LOC_IN_INT4 ) in ivec4 in_indices;

// Interface block for vertex shader output / geometry shader input.
//
layout ( location = 0 ) out Data_to_GS
{
  vec4 vertex_c;    // Lower-surface vertex, clip space.
  vec3 normal_e;    // Surface normal, eye space (not normalized here).
  vec4 vertex_e;    // Lower-surface vertex, eye space.
  vec2 tcoor;       // Texture coordinate.

  // Any changes here must also be made to the geometry shader's
  // Data_to_GS input block (In[]); members must match in order and type.

  /// SOLUTION - Problem 2
  //
  vec4 vertex_e_upper;  // Vertex displaced along the ball axis, eye space.
  vec4 vertex_c_upper;  // Same displaced vertex, clip space.
  vec3 radial_e;        // Eye-space radial vector; normal for edge strips.
  ivec3 indices;        // in_indices.xyz, passed through unchanged.
};


// Vertex shader for the spiral ribbon wound around the line from ball
// bidx-1 to ball bidx.
//
// Inputs (in_indices):
//   .x  bidx -- index of the second ball; the first ball is bidx-1.
//   .y  ti   -- segment index along that line; t = ti / uc.opt_segments.
//   .z       -- 1 for a vertex on the inner edge of the ribbon, anything
//               else for a vertex on the outer edge.
//
// Outputs (Data_to_GS):
//   - the lower-surface vertex in clip and eye space;
//   - its eye-space surface normal and texture coordinates;
//   - the same vertex displaced 0.1 object-space units along the
//     ball-to-ball axis (the upper surface), in clip and eye space;
//   - the eye-space radial vector (used as the edge-strip normal);
//   - the raw indices.
//
// NOTE(review): balls_pos[bidx-1] is read, so bidx is assumed to be >= 1.
// Confirm this against the host code that fills in_indices.
void
vs_main()
{
  const float spiral_radius = 0.5;
  const float omega = 10;

  const int bidx = in_indices.x;
  const int ti = in_indices.y;
  const bool inner = in_indices.z == 1;

  const int radial_idx = bidx * uc.opt_segments + ti;
  const float delta_t = 1.0 / uc.opt_segments;
  const float t = float(ti) * delta_t;
  const float theta = delta_t * radial_idx * omega;

  vec3 pos1 = balls_pos[bidx-1].xyz;
  vec3 pos2 = balls_pos[bidx].xyz;

  vec3 v12 = pos2 - pos1;

  // Find a vector that's orthogonal to v12.
  //
  vec3 ax =
    normalize(v12.x == 0 ? vec3(0,v12.z,-v12.y) : vec3(v12.y,-v12.x,0));

  // Find a vector that's orthogonal to v12 and ax.
  //
  vec3 ay = normalize(cross(v12,ax));

  vec3 vx = ax * spiral_radius;
  vec3 vy = ay * spiral_radius;

  // Point on line between ball1 and ball2.
  //
  vec3 p = pos1 + t * v12;

  // Vector from p to spiral outer edge.
  //
  vec3 radial = vx * cos(theta) + vy * sin(theta);

  // Fraction of the radial vector at which this vertex lies. The inner
  // edge is at inner_frac and the outer edge is at the full radius.
  // Only the edge this vertex is on is computed. Multiplying by 1.0 is
  // exact, so outer-edge results equal the plain p + radial form.
  //
  const float inner_frac = 0.5;
  const float rfrac = inner ? inner_frac : 1.0;

  // Surface normal. The spiral is P(t,r) = p + r * radial(theta) with
  // dtheta/dt = omega, so dP/dt = v12 + r * tangial and dP/dr = radial.
  //
  vec3 tangial = -omega * vx * sin(theta) + omega * vy * cos(theta);
  vec3 tang = v12 + rfrac * tangial;

  vec4 vertex_o = vec4( p + rfrac * radial, 1 );
  vec3 normal_o = normalize(cross(radial,tang));

  // Write position and normal to shader output variables. Position
  // is written in both eye and clip space.
  //
  vertex_c = gl_ModelViewProjectionMatrix * vertex_o;
  normal_e = gl_NormalMatrix * normal_o;
  vertex_e = gl_ModelViewMatrix * vertex_o;

  // Amount by which to zoom the texture.
  //
  float tex_zoom = 0.5;

  // Uncomment the line below to use "debug_float" to zoom the texture.
  // tex_zoom /= debug_float;

  const float du = 0.5 * tex_zoom / uc.chain_length;
  const float u = float(bidx) * du;

  tcoor.x = tex_zoom * t;
  tcoor.y = 0.18 + u + (inner ? du : 0 );

  /// SOLUTION - Problem 2
  //
  //  Compute the position of the matching point on the second (upper)
  //  spiral, offset along the ball-to-ball axis, and write it and the
  //  radial to the new vertex shader outputs. The radial is used as the
  //  surface normal for the edge triangles.

  vec3 v12n = normalize(v12);
  vec3 depth_vector = 0.1f * v12n;
  vec4 vertex_o_upper = vertex_o + vec4(depth_vector,0);
  vertex_c_upper = gl_ModelViewProjectionMatrix * vertex_o_upper;
  vertex_e_upper = gl_ModelViewMatrix * vertex_o_upper;
  indices = in_indices.xyz;
  radial_e = gl_NormalMatrix * radial;
}

#endif



#ifdef _GEOMETRY_SHADER_

// Geometry shader input: one element per vertex of the input triangle.
// Must match the vertex shader's Data_to_GS output block in member
// order and type.
//
layout ( location = 0 ) in Data_to_GS
{
  vec4 vertex_c;
  vec3 normal_e;
  vec4 vertex_e;
  vec2 tcoor;

  /// SOLUTION - Problem 2
  //
  vec4 vertex_e_upper;
  vec4 vertex_c_upper;
  vec3 radial_e;
  ivec3 indices;

} In[];

// Geometry shader output / fragment shader input. Must match the
// fragment shader's Data_to_FS input block.
//
layout ( location = 0 ) out Data_to_FS
{
  vec3 normal_e;
  vec4 vertex_e;
  vec2 tcoor;

  /// SOLUTION  - Problem 2
  //
  flat int is_edge;  // 1 for inner/outer edge-strip vertices, 0 otherwise.
};

// Type of primitive at geometry shader input.
//
layout ( triangles ) in;

/// SOLUTION - Problem 2
//
// gs_main_simple emits 3 vertices. gs_main_solution emits 10: two
// 3-vertex surface triangles plus a 4-vertex edge strip.
//
// layout ( triangle_strip, max_vertices = 3 ) out;
layout ( triangle_strip, max_vertices = 12 ) out;

// Pass-through geometry shader ("Method 1"). It re-emits the input
// triangle unchanged, using only the lower-surface vertex data.
//
// NOTE(review): this routine does not write is_edge, so a fragment
// shader that reads is_edge (such as fs_main) would see an undefined
// value. It is presumably paired with fs_main_orig; confirm against the
// pipeline setup.
void
gs_main_simple()
{
  // DO NOT MODIFY THIS ROUTINE.
  // This shader used with Method 1.

  // Pass the triangle unchanged.

  for ( int i=0; i<3; i++ )
    {
      normal_e = In[i].normal_e;
      vertex_e = In[i].vertex_e;
      gl_Position = In[i].vertex_c;
      tcoor = In[i].tcoor;
      EmitVertex();
    }
  EndPrimitive();
  // DO NOT MODIFY THIS ROUTINE.
}


// Geometry shader for Problem 2. It turns each input triangle of the
// spiral ribbon into a slab with thickness.
//
// Emits three primitives, 10 vertices in total (within max_vertices):
//   1. The input triangle on the lower surface (In[i].vertex_e/_c).
//   2. The same triangle on the upper surface (In[i].vertex_*_upper).
//   3. A 4-vertex strip (two triangles) closing the side of the slab
//      along the edge whose two input vertices share the same
//      indices.z (both inner or both outer).
// Surface vertices get is_edge = 0. Edge-strip vertices get is_edge = 1
// and use the radial vector as normal, negated on the inner edge.
//
// All outputs are written before every EmitVertex, because GLSL leaves
// output variables undefined after each EmitVertex call.
void
gs_main_solution()
{
  /// SOLUTION - Problem 2

  // Emit the triangles on the lower (level 0) and upper (level 1) spirals.
  //
  for ( int level=0; level<2; level++ )
    {
      const bool upper = level == 1;

      for ( int i=0; i<3; i++ )
        {
          normal_e = In[i].normal_e;
          vertex_e = upper ? In[i].vertex_e_upper : In[i].vertex_e;
          gl_Position = upper ? In[i].vertex_c_upper : In[i].vertex_c;
          tcoor = In[i].tcoor;
          is_edge = 0;
          EmitVertex();
        }
      EndPrimitive();
    }

  //
  // Emit the triangles on the edge.
  //

  // First, find two vertices that are both on the outer edge or both
  // on the inner edge.
  //
  int idx[2];
  if ( In[0].indices.z == In[1].indices.z )       { idx[0] = 0;  idx[1] = 1; }
  else if ( In[0].indices.z == In[2].indices.z )  { idx[0] = 0;  idx[1] = 2; }
  else                                            { idx[0] = 1;  idx[1] = 2; }

  const bool is_inner = In[idx[0]].indices.z == 1;

  // Emit the edge strip: for each chosen vertex, its lower then its
  // upper position. This forms a quad spanning the two surfaces.
  //
  for ( int i=0; i<2; i++ )
    {
      const int v = idx[i];

      // Normal for the side face: +radial on the outer edge and -radial
      // on the inner edge, so it points away from the ribbon.
      const vec3 edge_normal = is_inner ? -In[v].radial_e : In[v].radial_e;

      // tcoor is not sampled for edge fragments (is_edge = 1). It is still
      // written so no undefined value is emitted and interpolated.
      normal_e = edge_normal;
      vertex_e = In[v].vertex_e;
      gl_Position = In[v].vertex_c;
      tcoor = In[v].tcoor;
      is_edge = 1;
      EmitVertex();

      normal_e = edge_normal;
      vertex_e = In[v].vertex_e_upper;
      gl_Position = In[v].vertex_c_upper;
      tcoor = In[v].tcoor;
      is_edge = 1;
      EmitVertex();
    }
  EndPrimitive();
}

#endif


#ifdef _FRAGMENT_SHADER_

// Fragment shader input. Must match the geometry shader's Data_to_FS
// output block in member order and type.
//
layout ( location = 0 ) in Data_to_FS
{
  vec3 normal_e;   // Interpolated eye-space normal (normalized in fs_main).
  vec4 vertex_e;   // Interpolated eye-space position.
  vec2 tcoor;      // Texture coordinate.

  /// SOLUTION - Problem 2
  //
  flat int is_edge;  // Nonzero for fragments of an edge-strip primitive.
};

// Texture sampled by the fragment shaders.
layout ( binding = BIND_TEXUNIT ) uniform sampler2D tex_unit_0;

// Fragment color output.
layout ( location = 0 ) out vec4 frag_color;

// Forward declaration. The definition is at the end of this section.
vec4 generic_lighting(vec4 vertex_e, vec4 color, vec3 normal_e);


// Fragment shader for the slab-shaped spiral (Problems 1 and 2).
//
// Edge-strip fragments (is_edge != 0) use uc.color_edge and no texture.
// Surface fragments use uc.color_front or uc.color_back depending on
// facing and are modulated by the texture. A surface fragment whose
// texel r+g+b falls below 0.05 is discarded, which cuts holes where the
// texture is dark. The result is lit by generic_lighting.
void
fs_main()
{
  /// SOLUTION - Problem 1 (mostly).

  // is_edge is flat, so this branch is uniform across each primitive.
  const bool edge_frag = is_edge != 0;

  vec4 base_color;
  vec4 texel;

  if ( edge_frag )
    {
      base_color = uc.color_edge;
      texel = vec4(1);
    }
  else
    {
      base_color = gl_FrontFacing ? uc.color_front : uc.color_back;
      texel = texture( tex_unit_0, tcoor );
    }

  // Too-dark texels become holes.
  if ( texel.r + texel.g + texel.b < 0.05 ) discard;

  frag_color =
    texel * generic_lighting( vertex_e, base_color, normalize(normal_e) );
}

// Original (pre-solution) fragment shader. It picks the front or back
// color by facing, lights it with generic_lighting, and modulates the
// result by the texture. It does not read is_edge.
void
fs_main_orig()
{
  const vec4 surface_color = gl_FrontFacing ? uc.color_front : uc.color_back;
  const vec3 unit_normal_e = normalize( normal_e );

  frag_color =
    texture( tex_unit_0, tcoor )
    * generic_lighting( vertex_e, surface_color, unit_normal_e );
}


///
/// Lighting routine usable by either the vertex or the fragment shader.
/// (In this file it is defined only in the fragment-shader section.)
///

// Return the color of a point lit by light 0.
//
//   vertex_e -- eye-space position; the viewer is at the eye-space origin.
//   color    -- material color; its alpha is passed through unchanged.
//   normal_e -- eye-space normal, expected to be normalized by the caller.
//
// Lighting is two-sided. The diffuse term is |n . l| when the viewer and
// the light are on the same side of the surface, and zero otherwise.
// Light ambient and diffuse are divided by the usual constant + linear +
// quadratic distance attenuation. Global ambient (gl_LightModel) is not
// attenuated.
//
// NOTE(review): gl_LightSource and gl_LightModel are assumed to be
// supplied by light.h. The light position is used as .xyz directly,
// which assumes it is already in eye space with w == 1.
vec4
generic_lighting(vec4 vertex_e, vec4 color, vec3 normal_e)
{
  // Vector from the vertex to the light, and its length.
  const vec3 to_light = gl_LightSource[0].position.xyz - vertex_e.xyz;
  const float light_dist = length(to_light);

  // Signs of n.(eye - v) and n.l determine which side each is on.
  const float n_dot_eye = -dot(normal_e, vertex_e.xyz);
  const float n_dot_light = dot(normal_e, normalize(to_light));
  const bool opposite_sides = ( n_dot_eye > 0 ) != ( n_dot_light > 0 );
  const float diffuse_factor = opposite_sides ? 0.0 : abs(n_dot_light);

  const float atten_denom =
    gl_LightSource[0].constantAttenuation
    + gl_LightSource[0].linearAttenuation * light_dist
    + gl_LightSource[0].quadraticAttenuation * ( light_dist * light_dist );

  const vec3 incident =
    gl_LightSource[0].ambient.rgb
    + diffuse_factor * gl_LightSource[0].diffuse.rgb;

  const vec3 lit_rgb =
    color.rgb * gl_LightModel.ambient.rgb
    + color.rgb * incident / atten_denom;

  return vec4( lit_rgb, color.a );
}

#endif