#version 460

#extension GL_NV_ray_tracing : require
#extension GL_GOOGLE_include_directive : enable
#include <light.h>

#include <rt-shdr-common.h>

#ifdef _RT_RAYGEN_

layout ( binding = BIND_UNI_FB, set = 0, rgba8) uniform image2D fb_image;
layout ( location = 0 ) rayPayloadNV vec4 rp_color;

// Ray-generation stage.
//
// Builds one ray per launch index, starting at the eye and passing through
// the center of the corresponding pixel, traces it with payload 0
// (rp_color), and writes the returned rgb to fb_image. Nothing is written
// when the payload alpha is still zero after the trace.
void
main()
{
  // Pixel center as a fraction of the launch size, with y flipped, then
  // remapped from [0,1] to [-1,1].
  vec2 uv = ( vec2(gl_LaunchIDNV.xy) + vec2(0.5) ) / vec2(gl_LaunchSizeNV.xy);
  uv.y = 1 - uv.y;
  const vec2 pixel_clip = uv * 2.0 - 1.0;

  // Ray origin: the eye-space origin carried through ut.object_from_eye.
  const vec3 ray_origin = ( ut.object_from_eye * vec4(0,0,0,1) ).xyz;

  // Ray direction: the pixel is taken back to eye space, normalized, and
  // then transformed as a direction (w = 0).
  const vec4 pixel_eye = ut.eye_from_clip * vec4(pixel_clip, 1, 1);
  const vec3 ray_dir =
    ( ut.object_from_eye * vec4(normalize(pixel_eye.xyz), 0) ).xyz;

  const uint ray_flags = gl_RayFlagsOpaqueNV;
  const uint cull_mask = 0xff;
  const float t_near = 0.001;
  const float t_far = 10000.0;

  // Alpha starts at zero so that an untouched payload can be detected below.
  rp_color = vec4(0);

  traceNV
   ( topLevelAS, ray_flags, cull_mask,
     0 /* intersect/hit sbtRecordOffset*/,
     0 /* intersect/hit sbtRecordStride*/,
     0 /*missIndex*/,
     ray_origin, t_near, ray_dir, t_far,
     0 /*payload location */);

  if ( rp_color.a != 0 )
    imageStore(fb_image, ivec2(gl_LaunchIDNV.xy), vec4(rp_color.rgb, 0.0));
}

#endif

#ifdef _RT_MISS_EYE_GEO_

layout ( location = 0 ) rayPayloadInNV vec4 rp_color;

// Miss stage for rays that use payload location 0.
// Does nothing, so the payload keeps whatever value the caller stored in it
// before tracing.
void
main()
{
}

#endif

#ifdef _RT_MISS_GEO_LIGHT_

layout ( location = 2 ) rayPayloadInNV bool rp_shadowed;

// Miss stage for rays that use payload location 2.
// Clears the shadow flag in the payload.
void
main()
{
  rp_shadowed = false;
}

#endif

#ifdef _RT_CLOSEST_HIT_

#extension GL_EXT_nonuniform_qualifier : enable

// Payload of the ray that produced this hit. main() below writes the shaded
// color to rgb and an incremented hit count to a.
layout ( location = 0 ) rayPayloadInNV vec4 rp_color;
// Payload of the shadow rays traced from main() below. It is set to true
// before each trace and read back afterwards.
layout ( location = 2 ) rayPayloadNV bool rp_shadowed;

// Hit attribute. main() below uses x and y as the interpolation weights of
// the second and third vertex, and 1 - x - y as the weight of the first.
hitAttributeNV vec2 ha_bary_coor;

// Pair of colors returned by generic_lighting_specular.
struct Lighted_Colors { vec3 primary, specular; };

// Forward declaration; the definition follows main().
Lighted_Colors generic_lighting_specular
(vec4 vertex_e, vec4 color, vec3 normal_er, uint shadow_vec);

// Convenience wrapper around generic_lighting_specular that forwards all
// arguments unchanged and returns only the primary member of the result.
vec3 generic_lighting
(vec4 v_e, vec4 color, vec3 n_er, uint shadow_vec)
{
  Lighted_Colors lit = generic_lighting_specular(v_e, color, n_er, shadow_vec);
  return lit.primary;
}

// Divide every component of uh by its w component.
vec4 homogenize(vec4 uh)
{
  const float w = uh.w;
  return uh / w;
}

// Closest-hit stage: shade the triangle hit by the current ray.
//
// Reads the per-instance record selected by gl_InstanceCustomIndexNV,
// interpolates vertex attributes with the hit barycentrics, optionally
// traces one shadow ray per enabled light, applies lighting and texture,
// and, for surfaces whose color alpha is negative, traces a mirror ray and
// blends its result in.
//
// On return rp_color.rgb holds the shaded color and rp_color.a holds the
// incoming alpha plus one.
//
// NOTE(review): the descriptor arrays (idx, pos, color, normal, tcoor,
// textureSamplers, uni_per_instane_a) are indexed by ii without
// nonuniformEXT even though GL_EXT_nonuniform_qualifier is enabled --
// confirm whether the index is guaranteed to be dynamically uniform.
void
main()
{
  const uint ii = gl_InstanceCustomIndexNV;
  RT_Uni_Per_Instance upi = uni_per_instane_a[ii];
  const bool do_lighting = bool( upi.in_usage & USAGE_BVEC_NORMAL );

  // The payload alpha is incremented at every hit; it is compared against
  // com.opt_mirror below to bound the number of mirror bounces.
  const float this_depth = rp_color.a + 1;
  rp_color.a = this_depth;

  // Find the three vertex indices of the hit primitive.
  uint i0, i1, i2;
  if ( upi.grouping == RT_Geometry_Triangle_Strip )
    {
      // The outer two indices are swapped on odd-numbered primitives.
      int parity = ( gl_PrimitiveID & 1 ) == 1 ? 1 : -1;
      i1 = gl_PrimitiveID + 1;
      i0 = i1 + parity;
      i2 = i1 - parity;
    }
  else if ( upi.grouping == RT_Geometry_Triangles )
    {
      i0 = 3 * gl_PrimitiveID;
      i1 = i0 + 1;
      i2 = i0 + 2;
    }
  else
    {
      // Indexed
      i0 = idx[ii].idx[3 * gl_PrimitiveID];
      i1 = idx[ii].idx[3 * gl_PrimitiveID + 1];
      i2 = idx[ii].idx[3 * gl_PrimitiveID + 2];
    }

  // Barycentric weights of the three vertices.
  const float b0 = 1.0 - ha_bary_coor.x - ha_bary_coor.y;
  const float b1 = ha_bary_coor.x;
  const float b2 = ha_bary_coor.y;
  // Interpolate per-vertex attribute a of this instance at the hit point.
# define MIX(a) ( b0 * a[ii].a[i0] + b1 * a[ii].a[i1] + b2 * a[ii].a[i2] )

#if 0
  vec4 c0 = my_color_list[3*gl_PrimitiveID];
  vec4 c1 = my_color_list[3*gl_PrimitiveID+1];
  vec4 c2 = my_color_list[3*gl_PrimitiveID+2];
  vec4 blended_color = (1.0-bcoor.x-bcoor.y) * c0 +  bcoor.x * c1 + bcoor.y * c2;
#endif

  // Hit position, reconstructed from the ray and the hit distance.
  vec4 vertex_g =
    vec4(gl_WorldRayOriginNV + gl_WorldRayDirectionNV * gl_HitTNV,1);

  // Use the interpolated vertex color for both faces when the instance
  // provides one, otherwise the per-instance front/back colors.
  const bool per_vtx_color = bool( upi.in_usage & USAGE_BVEC_COLOR );
  vec4 color_front, color_back;
  if ( per_vtx_color )
    {
      color_back = color_front = MIX(color);
    }
  else
    {
      color_front = upi.color_front;
      color_back = upi.color_back;
    }

  const bool front_facing = gl_HitKindNV == gl_HitKindFrontFacingTriangleEXT;
  const vec4 color = front_facing ? color_front : color_back;
  vec3 color_tex = color.rgb;
  const vec3 normal_o = MIX(normal).xyz;

  // Normals must be transformed by the inverse transpose of the
  // object-to-world matrix. A row vector times the world-to-object matrix
  // is exactly that product, so the normal stays perpendicular to the
  // surface under non-uniform scale. Its length is not preserved; every
  // use below normalizes it.
  const vec3 normal_g = normal_o * mat3(gl_WorldToObjectNV);

  Lighted_Colors lc;

  // A negative color alpha marks a mirror surface; its magnitude is the
  // blend factor used below.
  const bool reflective = color.a < 0;
  const bool do_reflection = reflective && this_depth <= com.opt_mirror;

  if ( do_lighting )
    {
      // Bit i is set when light i is blocked at the hit point.
      uint shadow_vec = 0;
      if ( bool(com.opt_shadows) )
        for ( int i = 0; i<cgl_MaxLights; i++ )
          {
            uint lv = 1 << i;
            if ( ( lv & com.light_on_vec ) == 0 ) continue;
            vec4 light_g = ut.object_from_eye * ul.cgl_LightSource[i].position;
            vec3 vtx_to_light_g = light_g.xyz - vertex_g.xyz;
            // The ray direction is not normalized, so t = 1 is at the
            // light. tmin is dist_min expressed in those units.
            float dist_min = 0.001; // How big is a pixel?
            float tmin = dist_min / length(vtx_to_light_g);
            float tmax = 1.0;
            // Assume blocked; the miss shader at missIndex 1 is expected
            // to clear the flag when nothing is hit.
            rp_shadowed = true;
            traceNV
              ( topLevelAS,
                gl_RayFlagsTerminateOnFirstHitNV
                | gl_RayFlagsOpaqueNV | gl_RayFlagsSkipClosestHitShaderNV,
                0xfe,    // Don't include lights in intersection test.
                0, 0, 1, // sbtRecordOffset, sbtRecordStride, missIndex
                vertex_g.xyz,
                tmin, vtx_to_light_g, tmax, 2 /*payload location*/);

            shadow_vec |= rp_shadowed ? lv : 0;
          }
      vec3 normal_e = mat3(ut.eye_from_object) * normal_g;
      vec4 vertex_e = ut.eye_from_object * vertex_g;
      lc = generic_lighting_specular(vertex_e,color,normal_e,shadow_vec);
    }
  else
    {
      lc.primary = color.rgb;
      lc.specular = vec3(0);
    }

  bool use_texture =
    bool( upi.in_usage & USAGE_BVEC_TCOOR )
    && bool( upi.in_usage & USAGE_BVEC_SAMPLER );

  if ( use_texture )
    {
      vec2 texCoord = MIX(tcoor);

      vec2 win_dim_px = gl_LaunchSizeNV.xy;
      vec2 tex_dim_px = textureSize(textureSamplers[ii],0);

      // Choose a level of detail from the ratio of texel distance to pixel
      // distance along whichever of the edges 0-1 and 1-2 is longer on
      // screen.
      vec2 t0_t = tex_dim_px * tcoor[ii].tcoor[i0];
      vec2 t1_t = tex_dim_px * tcoor[ii].tcoor[i1];
      vec2 t2_t = tex_dim_px * tcoor[ii].tcoor[i2];

      float dist_tex_01_t = distance(t0_t,t1_t);
      float dist_tex_12_t = distance(t1_t,t2_t);
      // NOTE(review): the vertex positions are projected without applying
      // gl_ObjectToWorldNV, unlike the normal above -- confirm that this is
      // intended for instances with a non-identity transform.
      vec2 p0_c = homogenize( ut.clip_from_object * pos[ii].pos[i0] ).xy;
      vec2 p1_c = homogenize( ut.clip_from_object * pos[ii].pos[i1] ).xy;
      vec2 p2_c = homogenize( ut.clip_from_object * pos[ii].pos[i2] ).xy;
      vec2 p0_p = 0.5 * win_dim_px * p0_c;
      vec2 p1_p = 0.5 * win_dim_px * p1_c;
      vec2 p2_p = 0.5 * win_dim_px * p2_c;
      float dist_p_01_p = distance(p0_p,p1_p);
      float dist_p_12_p = distance(p1_p,p2_p);
      float scale = dist_p_01_p > dist_p_12_p
        ? dist_tex_01_t / dist_p_01_p : dist_tex_12_t / dist_p_12_p;
      float lod = log2( scale );
      vec3 texel = textureLod(textureSamplers[ii], texCoord, lod).xyz;
      lc.primary *= texel;
      color_tex *= texel;
    }

  vec3 c = lc.primary + lc.specular;

  if ( do_reflection )
    {
      // Mirror the incoming direction about the surface normal.
      vec3 ray_gn = normalize(gl_WorldRayDirectionNV);
      const vec3 normal_gn = normalize( normal_g );
      float phase = dot( normal_gn, ray_gn );
      vec3 ray_refl = ray_gn - 2 * phase * normal_gn;
      float tmin = 0.001, tmax = 1000;
      float blend_factor = -color.a;

      // The mirror ray reuses payload 0. rgb is cleared so that a miss,
      // which leaves the payload untouched, contributes black.
      rp_color.rgb = vec3(0);
      traceNV
        ( topLevelAS, gl_RayFlagsOpaqueNV,
          0xff,
          0, 0, 0, // sbtRecordOffset, sbtRecordStride, missIndex
          vertex_g.xyz,
          tmin, ray_refl, tmax, 0 /*payload location*/);
      rp_color.rgb = mix( c, color_tex * rp_color.rgb, blend_factor );
      // Deeper hits have incremented alpha; restore this hit's depth.
      rp_color.a = this_depth;
    }
  else
    {
      rp_color.rgb = c;
    }
}


// Compute the lighted color of a surface point.
//
//   vertex_e   Position of the point; light positions are subtracted from
//              it directly, so both are in the same coordinate space.
//   color      Surface color; only rgb is used.
//   normal_er  Surface normal, not necessarily unit length or facing the
//              viewer.
//   shade_vec  Bit i set means light i is skipped entirely.
//
// Returns the non-specular sum in .primary and the specular sum in
// .specular. The specular term is always zero unless LIGHTING_SPECULAR is
// defined and the instance's shininess_front is nonzero.
Lighted_Colors
generic_lighting_specular
(vec4 vertex_e, vec4 color, vec3 normal_er, uint shade_vec)
{
  RT_Uni_Per_Instance inst = uni_per_instane_a[gl_InstanceCustomIndexNV];

#ifdef LIGHTING_SPECULAR
  const bool want_specular = inst.shininess_front != 0;
#else
  const bool want_specular = false;
#endif

  // Unit vector pointing back along the ray, after ut.eye_from_object.
  const vec3 to_eye_n =
    -normalize(mat3(ut.eye_from_object) * gl_WorldRayDirectionNV);

  // Flip the normal when it points away from the ray origin, then make it
  // unit length.
  vec3 facing = normal_er;
  if ( dot(to_eye_n, normal_er) < 0 ) facing = -normal_er;
  const vec3 normal_e = normalize(facing);

  // Start from the light-model ambient term.
  Lighted_Colors result;
  result.primary = color.rgb * ul.cgl_LightModel.ambient.rgb;
  result.specular = vec3(0);

  for ( int li = 0; li < cgl_MaxLights; li++ )
    {
      // Only lights that are on and not flagged in shade_vec contribute.
      const uint light_bit = 1u << li;
      const bool light_on = ( light_bit & com.light_on_vec ) != 0;
      const bool light_blocked = ( light_bit & shade_vec ) != 0;
      if ( !light_on || light_blocked ) continue;

      const vec3 to_light = ul.cgl_LightSource[li].position.xyz - vertex_e.xyz;
      const float dist = length(to_light);
      const float dist_inv = 1.0f / dist;
      const vec3 to_light_n = to_light * dist_inv;

      const float phase_light = max(0, dot(normal_e, to_light_n));

      // Denominator of the attenuation: constant + linear + quadratic.
      const float atten_inv =
        ul.cgl_LightSource[li].constantAttenuation +
        ul.cgl_LightSource[li].linearAttenuation * dist +
        ul.cgl_LightSource[li].quadraticAttenuation * ( dist * dist );

      const vec3 light_sum =
        ul.cgl_LightSource[li].ambient.rgb
        + phase_light * ul.cgl_LightSource[li].diffuse.rgb;
      result.primary += color.rgb * light_sum / atten_inv;

      if ( want_specular && phase_light != 0 )
        {
          // Half vector between the light and eye directions.
          const vec3 half_n = normalize( to_light_n + to_eye_n );
          result.specular +=
            pow(max(0.0,dot(normal_e,half_n)),inst.shininess_front)
            * color.rgb * ul.cgl_LightSource[li].specular.rgb / atten_inv;
        }
    }

  return result;
}

#endif