#version 460

#extension GL_NV_ray_tracing : require
#extension GL_GOOGLE_include_directive : enable
#include <light.h>

#include "rt-shdr-common.h"

#ifdef _RT_INTERSECTION_

#extension GL_EXT_nonuniform_qualifier : enable

void
main()
{
  // Intersection shader for sphere primitives.
  //
  // Sphere gl_PrimitiveID of instance gl_InstanceID has its center in
  // pos.xyz and its radius in pos.w (object space).  Solves
  //   |o + t d - c|^2 = r^2
  // for the object-space ray (o, d) and reports the smaller root t0 if
  // t0 >= gl_RayTminNV (hit kind front-facing), else the larger root t1
  // if t1 >= gl_RayTminNV (hit kind back-facing).
  //
  // NOTE(review): user hit kinds for reportIntersection are documented as
  // 0-127; the triangle hit-kind constants used here are 0xFE/0xFF.  The
  // closest-hit shader compares against the same constant, so changing
  // this would have to be done in both shaders together.

  const uint ii = gl_InstanceID;
  const int sphere_idx = gl_PrimitiveID;
  // gl_InstanceID need not be uniform across invocations, so the index
  // into the (presumably descriptor-) array is marked non-uniform.
  vec4 pos_rad = pos[nonuniformEXT(ii)].pos[sphere_idx];
  float rad = pos_rad.w;

  vec3 c_o = pos_rad.xyz; // Sphere center.
  vec3 vrc = c_o - gl_ObjectRayOriginNV;
  float vrc_sq = dot(vrc,vrc);
  float dot_r_rc = dot(gl_ObjectRayDirectionNV,vrc);
  float vr_sq = dot(gl_ObjectRayDirectionNV,gl_ObjectRayDirectionNV);

  // Squared distance from the sphere center to the ray's line.  It is a
  // squared length, so it must be compared with the squared radius.
  float dist_sq_min = vrc_sq - dot_r_rc * dot_r_rc / vr_sq;
  if ( dist_sq_min > rad * rad ) return;

  // Quadratic qa t^2 + qb t + qc = 0 from expanding |t d - vrc|^2 = r^2.
  float qa = vr_sq;
  float qb = -2 * dot_r_rc;
  float qc = vrc_sq - rad * rad;
  float radical = qb * qb - 4 * qa * qc;
  if ( radical < 0 ) return;
  float rsq = sqrt(radical);

  float t0 = (-qb - rsq) / ( 2 * qa );  // Entry point (smaller root).
  float t1 = (-qb + rsq) / ( 2 * qa );  // Exit point (larger root).
  bool outside = t0 >= gl_RayTminNV;
  float t = outside ? t0 : t1;

  // Both roots lie before the ray's valid interval.
  if ( t < gl_RayTminNV ) return;

  uint hit_kind = outside
    ? gl_HitKindFrontFacingTriangleEXT : gl_HitKindBackFacingTriangleEXT;

  reportIntersectionNV( t, hit_kind );
}

#endif

#ifdef _RT_CLOSEST_HIT_

#extension GL_EXT_nonuniform_qualifier : enable

// Payload of the incoming ray; main() below writes the shaded color here.
layout(location = 0) rayPayloadInNV vec3 rp_color;
// Payload of the shadow ray traced by main() (payload location 2).  main()
// sets it to true before traceNV and treats "still true" as shadowed;
// presumably miss shader 1 clears it — TODO confirm.
layout(location = 2) rayPayloadNV bool rp_shadowed;

// Forward declaration; defined later in this file.  Returns the lit color
// of an eye-space vertex.
vec4 generic_lighting(vec4 vertex_e, vec4 color, vec3 normal_er);

vec4
homogenize(vec4 uh)
{
  // Divide all four components (w included) by w, e.g. to take a
  // clip-space position produced by ut.mvp to normalized device
  // coordinates.  The result has w == 1.
  const float w = uh.w;
  return uh / vec4(w);
}

void main()
{
  // Closest-hit shader for sphere primitives.
  //
  // Shades the point where the incoming ray hit sphere gl_PrimitiveID of
  // instance gl_InstanceID and writes the result to payload rp_color:
  //   - Base color: the per-sphere color array when USAGE_BVEC_COLOR is
  //     set, otherwise the instance's front or back color, chosen by the
  //     hit kind reported by the intersection shader.
  //   - Lighting by generic_lighting, evaluated in eye space.
  //   - Optional texture modulation (USAGE_BVEC_SAMPLER), sampled with an
  //     explicitly estimated level of detail.
  //   - A shadow ray toward light 0; if rp_shadowed is still true after
  //     the trace, the color is scaled by 0.3.

  const uint ii = gl_InstanceID;
  const int sphere_idx = gl_PrimitiveID;
  RT_Uni_Per_Instance upi = uni_per_instane_a[ii];
  const bool do_lighting = true;

  // Per-sphere attribute of this instance, e.g. MIX(pos).  gl_InstanceID
  // need not be uniform across invocations, so the index into the
  // (presumably descriptor-) array is wrapped in nonuniformEXT.
# define MIX(a) ( a[nonuniformEXT(ii)].a[sphere_idx] )

  // Hit point in global (world) and object coordinates.
  vec4 vertex_g =
    vec4(gl_WorldRayOriginNV + gl_WorldRayDirectionNV * gl_HitTNV,1);
  vec4 vertex_o =
    vec4(gl_ObjectRayOriginNV + gl_ObjectRayDirectionNV * gl_HitTNV,1);
  vec4 pos_rad = MIX(pos);
  vec3 c_o = pos_rad.xyz; // Sphere center.

  const bool per_vtx_color = bool( upi.in_usage & USAGE_BVEC_COLOR );
  vec4 color_front, color_back;
  if ( per_vtx_color )
    {
      color_back = color_front = MIX(color);
    }
  else
    {
      color_front = upi.color_front;
      color_back = upi.color_back;
    }

  const bool front_facing = gl_HitKindNV == gl_HitKindFrontFacingTriangleEXT;
  const vec4 color = front_facing ? color_front : color_back;
  vec3 c;

  vec4 vertex_e = ut.mv * vertex_g;
  // Unnormalized outward sphere normal in object space.
  vec3 normal_o = vertex_o.xyz - c_o.xyz;

  if ( do_lighting )
    {
      // Normals transform by the inverse transpose of the linear part of
      // the object-to-world matrix; (v * M) computes transpose(M) * v.
      // Using gl_ObjectToWorldNV directly would skew normals under
      // non-uniform scale.  Length is irrelevant: generic_lighting
      // normalizes the normal.
      vec3 normal_g = normal_o * mat3(gl_WorldToObjectNV);
      // NOTE(review): assumes ut.mv has no non-uniform scale; otherwise
      // this transform also needs the inverse transpose.
      vec3 normal_e = mat3(ut.mv) * normal_g;
      c = generic_lighting(vertex_e,color,normal_e).xyz;
    }
  else
    {
      c = color.rgb;
    }

  bool use_texture = !opt_tryout1
    && bool( upi.in_usage & USAGE_BVEC_SAMPLER );
  bool use_rot = bool( upi.in_usage & USAGE_BVEC_ROT );

  if ( use_texture )
    {
      vec3 c_g = vec3( mat4(gl_ObjectToWorldNV) * vec4(c_o,1) );
      vec3 normal_g = vertex_g.xyz - c_g;

      // Texture coordinates from the (optionally rotated) object-space
      // normal: u from the azimuth theta, v from the polar angle eta.
      vec3 sur_l = normal_o;
      if ( !opt_tryout2 && use_rot ) sur_l = transpose(mat3(MIX(rot))) * sur_l;
      sur_l = normalize(sur_l);

      float pi = 3.14159265359;
      float tpi = 2 * pi;
      float theta = atan(sur_l.x,sur_l.z);
      // Clamp: after normalize() |y| can exceed 1 by rounding, which would
      // make acos return NaN.
      float eta = acos(clamp(sur_l.y,-1.0,1.0));
      vec2 tcoord = vec2( ( 1.5f * pi + theta ) / tpi, eta / pi );

      // Level of detail (approximate): rotate the global-space normal by a
      // small angle dang about the y and z axes, project the displaced
      // surface points to the screen, and compare the pixel displacement
      // with the number of texels the same angle spans in u and v.
      float dang = 0.01;
      float sin_dang = dang;  // Small-angle approximation.
      float cos_dang = sqrt(1-sin_dang*sin_dang);
      mat3 rot_th =
        mat3( vec3( cos_dang,   0, sin_dang ),
              vec3( 0,          1, 0           ),
              vec3( -sin_dang,  0, cos_dang ) );
      mat3 rot_et =
        mat3( vec3( cos_dang,  sin_dang, 0 ),
              vec3( -sin_dang, cos_dang, 0 ),
              vec3( 0,            0      , 1 ) );

      vec2 win_dim_px = gl_LaunchSizeNV.xy;
      vec2 tex_dim_px = textureSize(textureSamplers[nonuniformEXT(ii)],0);
      vec3 norm_g1 = rot_th * normal_g;
      vec3 norm_g2 = rot_et * normal_g;
      vec4 vtx_g1 = vec4( c_g + norm_g1, 1 );
      vec4 vtx_g2 = vec4( c_g + norm_g2, 1 );
      vec2 vtx_c = homogenize( ut.mvp * vertex_g ).xy;
      vec2 vtx_c1 = homogenize( ut.mvp * vtx_g1 ).xy;
      vec2 vtx_c2 = homogenize( ut.mvp * vtx_g2 ).xy;
      vec2 vtx_p = 0.5f * win_dim_px * vtx_c;
      vec2 vtx_p1 = 0.5f * win_dim_px * vtx_c1;
      vec2 vtx_p2 = 0.5f * win_dim_px * vtx_c2;
      float dist_1_p = distance(vtx_p,vtx_p1);
      float dist_2_p = distance(vtx_p,vtx_p2);
      float dist_tex_1_t = tex_dim_px.x * dang / tpi;
      float dist_tex_2_t = tex_dim_px.y * dang / pi;
      // Texels per pixel along the direction with the larger screen
      // displacement.
      float scale = dist_1_p > dist_2_p
        ? dist_tex_1_t / dist_1_p : dist_tex_2_t / dist_2_p;
      float lod = log2(scale);
      c *= textureLod(textureSamplers[nonuniformEXT(ii)], tcoord, lod).xyz;
    }

  // Shadow ray from the hit point toward light 0.  The light position is
  // stored in eye space; ut.mvi maps it to global space (presumably the
  // inverse of ut.mv).  The direction is left unnormalized, so t = 1 is
  // the light itself; tmin offsets the start by about dist_min global
  // units to avoid re-hitting this surface.
  rp_shadowed = true;
  vec4 light_g = ut.mvi * ul.cgl_LightSource[0].position;
  vec3 vtx_to_light_g = light_g.xyz - vertex_g.xyz;
  float dist_min = 0.001; // How big is a pixel?
  float tmin = dist_min / length(vtx_to_light_g);
  float tmax = 1.0;

  traceNV
   ( topLevelAS,
     gl_RayFlagsTerminateOnFirstHitNV
     | gl_RayFlagsOpaqueNV | gl_RayFlagsSkipClosestHitShaderNV,
     0xfe, // Don't include lights in intersection test.
     0 /* sbtRecordOffset */,
     0 /* sbtRecordStride */,
     1 /* missIndex */,
     vertex_g.xyz,
     tmin, vtx_to_light_g, tmax, 2 /*payload location*/);

  rp_color = c;
  if ( rp_shadowed ) rp_color *= 0.3f;

# undef MIX
}

#endif


#if 1
vec4
generic_lighting(vec4 vertex_e, vec4 color, vec3 normal_er)
{
  // Compute the lit color of an eye-space surface point.
  //
  //   vertex_e:  Surface point in eye space (the eye is at the origin).
  //   color:     Material color; only .rgb is used.
  //   normal_er: Surface normal in eye space.  It need not be unit length
  //              and may point away from the eye; it is flipped toward
  //              the eye and normalized before use.
  //
  // Returns rgb = light-model ambient + for each light (only light 0 is
  // evaluated) attenuated ambient + diffuse, plus an attenuated
  // Blinn-Phong specular term (exponent 16) when LIGHTING_SPECULAR is
  // defined.  Alpha is always 1.

#ifdef LIGHTING_SPECULAR
  const bool want_specular = true;
#else
  const bool want_specular = false;
#endif

  // Light-model (global) ambient term.
  vec3 nonspec_sum = color.rgb * ul.cgl_LightModel.ambient.rgb;
  vec3 specular_sum = vec3(0);

  // Orient the normal toward the viewer, then normalize it.
  const vec3 to_eye = -vertex_e.xyz;
  const vec3 normal_e =
    normalize( dot(to_eye,normal_er) < 0 ? -normal_er : normal_er );

  const int n_lights = 1;
  for ( int li = 0; li < n_lights; ++li )
    {
      const vec3 to_light =
        ul.cgl_LightSource[li].position.xyz - vertex_e.xyz;
      const float light_dist = length(to_light);
      const vec3 to_light_n = to_light * ( 1.0f / light_dist );

      // Lambertian factor, clamped at zero for surfaces facing away.
      const float lambert = max(0, dot(normal_e, to_light_n));

      // Denominator of the distance attenuation:
      //   constant + linear * d + quadratic * d^2.
      const float atten_denom =
        ul.cgl_LightSource[li].constantAttenuation
        + ul.cgl_LightSource[li].linearAttenuation * light_dist
        + ul.cgl_LightSource[li].quadraticAttenuation
          * ( light_dist * light_dist );

      nonspec_sum +=
        color.rgb
        * ( ul.cgl_LightSource[li].ambient.rgb
            + lambert * ul.cgl_LightSource[li].diffuse.rgb )
        / atten_denom;

      if ( want_specular )
        {
          // Half vector between the light and view directions.
          const vec3 half_v =
            normalize( to_light_n - normalize(vertex_e.xyz) );
          specular_sum +=
            pow(max(0.0,dot(normal_e,half_v)),16)
            * color.rgb
            * ul.cgl_LightSource[li].specular.rgb / atten_denom;
        }
    }

  return vec4( nonspec_sum + specular_sum, 1 );
}
#endif