#version 460

#extension GL_NV_ray_tracing : require
#extension GL_GOOGLE_include_directive : enable
#include <light.h>

#include <rt-shdr-common.h>

// Perspective divide: scale homogeneous coordinate uh by 1/w so that its
// xyz become ordinary (Cartesian) coordinates.
vec4
homogenize(vec4 uh)
{
  const float w = uh.w;
  return uh / w;
}

#ifdef _RT_RAYGEN_

layout ( binding = BIND_UNI_FB, rgba8 ) uniform image2D fb_image;
layout ( location = 0 ) rayPayloadNV vec4 rp_color;

/// Ray Generation Shader
//
//  One invocation per launch ID (one per frame-buffer pixel). Casts a
//  ray from the eye through the pixel's location on the clip-space plane
//  z = 0. If a closest-hit shader ran, the color it left in the payload
//  is written to fb_image.
//
void
main()
{
  // Window-space coordinate of the pixel this invocation owns.
  const uvec2 px_w = gl_LaunchIDNV.xy;

  // Pixel center mapped to [-1,1] clip coordinates (with y flipped),
  // placed on the clip-space plane z = 0.
  vec2 uv = ( vec2(px_w) + vec2(0.5) ) / vec2(gl_LaunchSizeNV.xy);
  uv.y = 1.0 - uv.y;
  const vec4 px_c = vec4( 2.0 * uv - vec2(1.0), 0.0, 1.0 );

  // Pixel and eye locations in global (object) space.
  const vec4 px_g = homogenize( ut.object_from_clip * px_c );
  const vec4 eye_g = ut.object_from_eye * vec4(0,0,0,1);

  // The direction is deliberately not normalized: t = 1 lands exactly on
  // px_g, so a tmin of 1 ignores anything between the eye and that plane.
  const vec3 dir_g = px_g.xyz - eye_g.xyz;

  const uint ray_flags = gl_RayFlagsOpaqueNV;
  const uint cull_mask = 0xff;   // Consider instances with any mask bit.
  const float t_min = 1;
  const float t_max = 10000.0;  // This does not match far plane. Future exam question?

  // Cleared payload: its alpha stays 0 unless a closest-hit shader runs.
  rp_color = vec4(0);

  // The sbtRecordOffset, sbtRecordStride and missIndex values (all 0
  // here) select shaders; their meaning is determined by how the shader
  // binding table and acceleration structure were set up on the host.
  traceNV
   ( topLevelAS,            // Acceleration structure holding the geometry.
     ray_flags, cull_mask,  // Which objects to consider (or ignore).
     0,                     // Intersect/hit sbtRecordOffset.
     0,                     // Intersect/hit sbtRecordStride.
     0,                     // missIndex.
     eye_g.xyz,             // Ray origin: the eye.
     t_min,                 // Ignore intersections with t < t_min.
     dir_g,                 // Ray direction.
     t_max,                 // Ignore intersections with t > t_max.
     0 );                   // Payload location; matches rp_color's layout.

  // Closest-hit shaders store a nonzero recursion depth in rp_color.a,
  // so a zero alpha means nothing was hit and the pixel is left as is.
  if ( rp_color.a != 0 )
    imageStore( fb_image, ivec2(px_w), vec4( rp_color.rgb, 0.0 ) );
}

#endif

#ifdef _RT_MISS_EYE_GEO_

layout ( location = 0 ) rayPayloadInNV vec4 rp_color;

// Miss shader for eye-geometry rays (presumably bound at missIndex 0).
// It intentionally does nothing, so the payload keeps whatever the caster
// stored in it before calling traceNV.
void
main()
{
}

#endif

#ifdef _RT_MISS_GEO_LIGHT_

layout ( location = 2 ) rayPayloadInNV bool rp_shadowed;

// Miss shader for geometry-to-light (shadow) rays: reaching here means
// nothing blocked the path to the light, so clear the shadowed flag.
void
main()
{
  rp_shadowed = false;
}

#endif

#ifdef _RT_CLOSEST_HIT_

#extension GL_EXT_nonuniform_qualifier : enable

// Payload of the ray that reached this hit (color + recursion depth).
layout ( location = 0 ) rayPayloadInNV vec4 rp_color;
// Payload for shadow rays cast from this shader.
layout ( location = 2 ) rayPayloadNV bool rp_shadowed;

// Barycentric weights of the hit triangle's second and third vertices.
hitAttributeNV vec2 ha_bary_coor;

// Lighting result: non-specular (ambient + diffuse) and specular parts.
struct Lighted_Colors { vec3 primary, specular; };

Lighted_Colors generic_lighting_specular
(vec4 vertex_e, vec4 color, vec3 normal_er, uint shadow_vec, uint iidx);

// Convenience wrapper: only the non-specular part of the lighting.
vec3 generic_lighting
(vec4 v_e, vec4 color, vec3 n_er, uint shadow_vec, uint iidx)
{
  const Lighted_Colors lit =
    generic_lighting_specular( v_e, color, n_er, shadow_vec, iidx );
  return lit.primary;
}

void
main()
{
  /// Closest Hit Shader
  //
  //  Computes the color seen along the incoming ray at its closest
  //  intersection and returns it in rp_color.rgb.
  //
  //  Designed to work for three different groupings (topology):
  //  triangle strips, separate triangles, and indexed triangles.
  //  Optionally applies lighting (when the instance supplies normals),
  //  with shadows found by tracing a ray toward each enabled light.
  //  Optionally applies textures.
  //  Optionally reflects rays (a negative color alpha marks a mirror).
  //
  /// Inputs Used
  //
  //  gl_InstanceCustomIndexNV
  //    Index of the geometry instance that ray intersected.
  //  gl_PrimitiveID
  //    Index of primitive within geometry instance.
  //  ha_bary_coor
  //    Barycentric weights of the triangle's second and third vertices.
  //  rp_color.a
  //    Recursion depth of the caster (0 for rays from ray generation).
  //
  /// Outputs
  //
  //  rp_color.rgb  Color of the hit point.
  //  rp_color.a    This invocation's depth (caster's depth + 1). The ray
  //                generation shader treats a nonzero value as a hit.

  const uint ii = gl_InstanceCustomIndexNV;
  RT_Uni_Per_Instance upi = uni_per_instance_a[ii];
  const bool do_lighting = bool( upi.in_usage & USAGE_BVEC_NORMAL );

  // Recursion bookkeeping: the payload alpha carries the depth.
  const float this_depth = rp_color.a + 1;
  // Room for one more level (shadow rays).
  const bool can_cast = this_depth < ray_recursion_max;
  // Two more levels, presumably so a reflected hit can still cast its own
  // shadow rays.
  const bool can_refl = com.opt_mirror > 0 && this_depth < ray_recursion_max-1;
  rp_color.a = this_depth;

  // Indices of the hit triangle's vertices, in the same order as the
  // barycentric weights b0, b1, b2 below.
  uint i0, i1, i2;
  if ( upi.grouping == RT_Geometry_Triangle_Strip )
    {
      // Even primitive p uses (p, p+1, p+2); odd p uses (p+2, p+1, p).
      // NOTE(review): must match the vertex order used when the strip was
      // converted for the acceleration structure -- confirm on the host.
      int parity = ( gl_PrimitiveID & 1 ) == 1 ? 1 : -1;
      i1 = gl_PrimitiveID + 1;
      i0 = i1 + parity;
      i2 = i1 - parity;
    }
  else if ( upi.grouping == RT_Geometry_Triangles )
    {
      i0 = 3 * gl_PrimitiveID;
      i1 = i0 + 1;
      i2 = i0 + 2;
    }
  else
    {
      // Indexed
      i0 = idx[ii].idx[3 * gl_PrimitiveID];
      i1 = idx[ii].idx[3 * gl_PrimitiveID + 1];
      i2 = idx[ii].idx[3 * gl_PrimitiveID + 2];
    }

  // MIX(a): interpolate this instance's per-vertex array a at the hit.
  const float b0 = 1.0 - ha_bary_coor.x - ha_bary_coor.y;
  const float b1 = ha_bary_coor.x;
  const float b2 = ha_bary_coor.y;
# define MIX(a) ( b0 * a[ii].a[i0] + b1 * a[ii].a[i1] + b2 * a[ii].a[i2] )

  // Hit point in global space.
  vec4 vertex_g =
    vec4(gl_WorldRayOriginNV + gl_WorldRayDirectionNV * gl_HitTNV,1);

  const bool per_vtx_color = bool( upi.in_usage & USAGE_BVEC_COLOR );
  vec4 color_front, color_back;
  if ( per_vtx_color )
    {
      // Refers to the per-vertex color array (the local "color" is
      // declared below).
      color_back = color_front = MIX(color);
    }
  else
    {
      color_front = upi.color_front;
      color_back = upi.color_back;
    }

  // Use the NV hit-kind constant, matching the GL_NV_ray_tracing
  // extension this file requires.
  const bool front_facing = gl_HitKindNV == gl_HitKindFrontFacingTriangleNV;
  const vec4 color = front_facing ? color_front : color_back;
  vec3 color_tex = color.rgb;
  const vec3 normal_o = MIX(normal).xyz;

  // Normals transform by the inverse transpose of the object-to-world
  // matrix. In GLSL, n * M == transpose(M) * n, so multiply by the
  // world-to-object matrix on the right. (mat3(gl_ObjectToWorldNV) * n
  // is only correct without non-uniform scale or shear.)
  const vec3 normal_g = normal_o * mat3(gl_WorldToObjectNV);

  Lighted_Colors lc;
  // A negative alpha marks a reflective surface; -alpha is the blend weight.
  const bool reflective = color.a < 0;
  const bool do_reflection = can_refl && reflective;

  if ( do_lighting )
    {
      // Bit i of shadow_vec is set when light i is occluded.
      uint shadow_vec = 0;
      if ( bool(com.opt_shadows) && can_cast )
        for ( int i = 0; i<cgl_MaxLights; i++ )
          {
            uint lv = 1 << i;
            if ( ( lv & com.light_on_vec ) == 0 ) continue;
            vec4 light_g = ut.object_from_eye * ul.cgl_LightSource[i].position;
            vec3 vtx_to_light_g = light_g.xyz - vertex_g.xyz;
            float dist_min = 0.001; // How big is a pixel?
            // The direction is unnormalized (t = 1 reaches the light), so
            // express dist_min as a fraction of the vertex-light distance.
            float tmin = dist_min / length(vtx_to_light_g);
            float tmax = 1.0;
            // Any hit leaves this true (closest-hit is skipped). The miss
            // shader at missIndex 1, presumably the geometry-to-light miss
            // shader, clears it.
            rp_shadowed = true;
            traceNV
              ( topLevelAS,
                gl_RayFlagsTerminateOnFirstHitNV
                | gl_RayFlagsOpaqueNV | gl_RayFlagsSkipClosestHitShaderNV,
                0xfe,    // Don't include lights in intersection test.
                0, 0, 1, // sbtRecordOffset, sbtRecordStride, missIndex
                vertex_g.xyz,
                tmin, vtx_to_light_g, tmax, 2 /*payload location*/);

            shadow_vec |= rp_shadowed ? lv : 0;
          }
      vec3 normal_e = mat3(ut.eye_from_object) * normal_g;
      vec4 vertex_e = ut.eye_from_object * vertex_g;
      lc = generic_lighting_specular(vertex_e,color,normal_e,shadow_vec,ii);
    }
  else
    {
      lc.primary = color.rgb;
      lc.specular = vec3(0);
    }

  bool use_texture =
    bool( upi.in_usage & USAGE_BVEC_TCOOR )
    && bool( upi.in_usage & USAGE_BVEC_SAMPLER );

  if ( use_texture )
    {
      vec2 texCoord = MIX(tcoor);

      // Choose a level of detail from the ratio of texel distance to
      // screen-pixel distance. The ratio is taken along whichever of the
      // edges 0-1 and 1-2 is longer on screen.
      // NOTE(review): the vertices are projected with the eye's
      // clip_from_object and gl_ObjectToWorldNV is ignored. This assumes
      // eye rays and identity instance transforms -- confirm.
      vec2 win_dim_px = gl_LaunchSizeNV.xy;
      vec2 tex_dim_px = textureSize(textureSamplers[ii],0);

      // Pixel Coordinates of Each Vertex's Texel in Base Image
      vec2 t0_t = tex_dim_px * tcoor[ii].tcoor[i0];
      vec2 t1_t = tex_dim_px * tcoor[ii].tcoor[i1];
      vec2 t2_t = tex_dim_px * tcoor[ii].tcoor[i2];

      float dist_tex_01_t = distance(t0_t,t1_t);
      float dist_tex_12_t = distance(t1_t,t2_t);

      vec2 p0_c = homogenize( ut.clip_from_object * pos[ii].pos[i0] ).xy;
      vec2 p1_c = homogenize( ut.clip_from_object * pos[ii].pos[i1] ).xy;
      vec2 p2_c = homogenize( ut.clip_from_object * pos[ii].pos[i2] ).xy;
      vec2 p0_p = 0.5 * win_dim_px * p0_c;
      vec2 p1_p = 0.5 * win_dim_px * p1_c;
      vec2 p2_p = 0.5 * win_dim_px * p2_c;
      float dist_p_01_p = distance(p0_p,p1_p);
      float dist_p_12_p = distance(p1_p,p2_p);
      float scale = dist_p_01_p > dist_p_12_p
        ? dist_tex_01_t / dist_p_01_p : dist_tex_12_t / dist_p_12_p;
      float lod = log2( scale );
      vec3 texel = textureLod(textureSamplers[ii], texCoord, lod).xyz;
      lc.primary *= texel;
      color_tex *= texel;
    }

  vec3 c = lc.primary + lc.specular;

  if ( do_reflection )
    {
      vec3 ray_gn = normalize(gl_WorldRayDirectionNV);
      const vec3 normal_gn = normalize( normal_g );
      // Mirror the incoming direction about the (unit) surface normal.
      vec3 ray_refl = reflect( ray_gn, normal_gn );
      float tmin = 0.001, tmax = 1000;
      float blend_factor = -color.a;

      // The reflected ray's closest-hit shader writes its color here. If
      // the ray misses (presumably the empty miss shader at index 0), the
      // color stays black.
      rp_color.rgb = vec3(0);
      traceNV
        ( topLevelAS, gl_RayFlagsOpaqueNV,
          0xff,
          0, 0, 0, // sbtRecordOffset, sbtRecordStride, missIndex
          vertex_g.xyz,
          tmin, ray_refl, tmax, 0 /*payload location*/);
      rp_color.rgb = mix( c, color_tex * rp_color.rgb, blend_factor );
      // Restore this depth; a nested hit overwrote it with its own.
      rp_color.a = this_depth;
    }
  else
    {
      rp_color.rgb = c;
    }

# undef MIX
}


Lighted_Colors
generic_lighting_specular
(vec4 vertex_e, vec4 color, vec3 normal_er, uint shade_vec, uint ii)
{
  // Light the eye-space point vertex_e of instance ii. Its material color
  // is color and its (possibly unnormalized) eye-space normal is
  // normal_er. The normal is flipped, if needed, to face the viewer.
  //
  // Light i contributes only when its bit is set in com.light_on_vec and
  // clear in shade_vec (the caller's shadowed-light mask).
  //
  // Returns .primary = global ambient plus per-light ambient and diffuse,
  // and .specular = per-light Blinn-Phong specular (only when
  // LIGHTING_SPECULAR is defined and the instance's shininess is nonzero).
  //
  RT_Uni_Per_Instance inst = uni_per_instance_a[ii];

#ifdef LIGHTING_SPECULAR
  const bool want_specular = inst.shininess_front != 0;
#else
  const bool want_specular = false;
#endif

  // Unit vector from the point toward the viewer (opposite the ray).
  const vec3 to_viewer =
    -normalize( mat3(ut.eye_from_object) * gl_WorldRayDirectionNV );

  // Orient the normal toward the viewer, then make it unit length.
  const vec3 n =
    normalize( dot(to_viewer, normal_er) < 0 ? -normal_er : normal_er );

  vec3 sum_primary = color.rgb * ul.cgl_LightModel.ambient.rgb;
  vec3 sum_specular = vec3(0);

  for ( int li = 0; li < cgl_MaxLights; ++li )
    {
      const uint bit = 1 << li;
      const bool light_off = ( bit & com.light_on_vec ) == 0;
      if ( light_off || ( bit & shade_vec ) != 0 ) continue;

      const vec3 to_light = ul.cgl_LightSource[li].position.xyz - vertex_e.xyz;
      const float d = length(to_light);
      const vec3 l = to_light * ( 1.0f / d );
      const float n_dot_l = max( 0, dot(n, l) );

      // Denominator of the constant/linear/quadratic distance attenuation.
      const float atten_denom =
        ul.cgl_LightSource[li].constantAttenuation
        + ul.cgl_LightSource[li].linearAttenuation * d
        + ul.cgl_LightSource[li].quadraticAttenuation * ( d * d );

      sum_primary +=
        color.rgb
        * ( ul.cgl_LightSource[li].ambient.rgb
            + n_dot_l * ul.cgl_LightSource[li].diffuse.rgb )
        / atten_denom;

      if ( want_specular && n_dot_l != 0 )
        {
          const vec3 halfway = normalize( l + to_viewer );
          sum_specular +=
            pow( max( 0.0, dot(n, halfway) ), inst.shininess_front )
            * color.rgb * ul.cgl_LightSource[li].specular.rgb / atten_denom;
        }
    }

  return Lighted_Colors( sum_primary, sum_specular );
}

#endif