// main.wgsl — WebGPU-based path tracer
// Source path: src/shaders/main.wgsl (21.15 KB)
#include random.wgsl
#include brdf.wgsl
#include sky.wgsl
#include any_hit.wgsl
#include utils.wgsl

// ---- Group 0: image output, scene geometry, camera, accumulation history ----
// Final image; written once per pixel by `main` via textureStore.
@group(0) @binding(0) var output_buffer : texture_storage_2d<rgba32float, write>;
// All scene triangles; indexed through `tri_lut` (BVH leaves) and `emissiveTriangleIndices`.
@group(0) @binding(1) var<storage, read> objects: Objects;
// Per-frame camera, sun and sampling parameters.
@group(0) @binding(2) var<uniform> uniforms: UniformLayout;
// Flattened BVH; `trace` starts traversal at nodes[0].
@group(0) @binding(5) var<storage, read> node_tree: BVH;
// BVH leaf primitive list: f32-encoded indices into objects.triangles.
@group(0) @binding(6) var<storage, read> tri_lut: ObjectIndices;
// Running per-pixel mean radiance, row-major (y * width + x). Despite the name it is both
// read and written by `main`. array<vec3f> elements have a 16-byte stride in WGSL, so the
// host buffer needs 16 bytes per pixel.
@group(0) @binding(7) var<storage, read_write> input_buffer:array<vec3f>;

// ---- Group 1: materials and textures ----
// Material table indexed by HitInfo.material_idx / Triangle.material_idx.
@group(1) @binding(0) var<storage, read> materials:array<Material>;
// All material textures as layers of one array; see sample_material_texture.
@group(1) @binding(1) var textures: texture_2d_array<f32>;
// Sampler used for every material texture lookup.
@group(1) @binding(2) var t_sampler: sampler;
// Per-layer texture size (only .xy is read, by sample_material_texture); at most 128 layers.
@group(1) @binding(4) var<uniform> textureSizes: array<vec4<f32>, 128>;
// Blue-noise texture. Not read in this file; presumably consumed by animated_blue_noise
// (random.wgsl) — TODO confirm.
@group(1) @binding(6) var blueNoiseTexture : texture_storage_2d<rgba8unorm, read>;
// f32-encoded indices into objects.triangles of the triangles sampled by emissive NEE;
// the first uniforms.emissive_triangle_count entries are used.
@group(1) @binding(7) var<storage, read> emissiveTriangleIndices : array<f32>;

// @group(0) @binding(3) var skybox: texture_2d<f32>;
// @group(0) @binding(4) var skybox_sampler: sampler;
// @group(1) @binding(5) var skyboxCDF: texture_storage_2d<rg32float, read>;
// @group(1) @binding(3) var<storage, read> areaLights:array<AreaLight>;

// One mesh triangle as stored in `objects.triangles`.
// Field order and types define the storage-buffer layout the host must match (WGSL aligns
// vec3<f32> and vec4f members to 16 bytes), so do not reorder fields.
struct Triangle {
  // Vertex positions.
  corner_a: vec3<f32>,
  corner_b: vec3<f32>,
  corner_c: vec3<f32>,
  // Per-vertex shading normals, barycentrically interpolated in hit_triangle.
  normal_a: vec3<f32>,
  normal_b: vec3<f32>,
  normal_c: vec3<f32>,
  // Index into `materials`, stored as f32 and converted with i32().
  material_idx: f32,
  // Per-vertex texture coordinates.
  uv_a: vec2<f32>,
  uv_b: vec2<f32>,
  uv_c: vec2<f32>,
  // Per-vertex tangents: xyz is the tangent direction, w multiplies cross(N, T) to form
  // the bitangent (handedness sign) in hit_triangle.
  tangent_a: vec4f,
  tangent_b: vec4f,
  tangent_c: vec4f,
}

// struct AreaLight {
//   center: vec3<f32>,   
//   u: vec3<f32>,         
//   v: vec3<f32>,    
//   normal: vec3<f32>,
//   emission: vec3<f32>, 
// };

// A ray: origin + t * direction. Directions produced in this file are normalized.
// Field order is kept as-is in case included files construct Ray positionally.
struct Ray {
  direction: vec3<f32>,
  origin: vec3<f32>,
}

// Result of a closest-hit query (see trace / hit_triangle).
struct HitInfo {
  // Ray parameter of the hit.
  dist: f32,
  // False when nothing was hit; the other fields are then meaningless.
  hit: bool,
  // World-space hit position (hit_triangle nudges it with a shadow-terminator offset).
  position: vec3<f32>,
  // Shading normal (interpolated, optionally normal-mapped), oriented against the ray.
  normal: vec3<f32>, 
  // Index into `materials`.
  material_idx: i32,
  // Geometric (face) normal of the hit triangle.
  geo_normal: vec3f,
  // Copy of the triangle that was hit.
  tri: Triangle,
  // Interpolated texture coordinates.
  uv: vec2f,
  // Tangent-frame vectors at the hit.
  tangent: vec3<f32>,
  bitangent: vec3<f32>,
}

// Per-frame parameters in the uniform buffer at @group(0) @binding(2).
// Layout is shared with the host: each vec3 is followed by an f32 that packs into its
// 16-byte slot, so do not reorder fields.
struct UniformLayout {
  // Camera origin in world space.
  position: vec3<f32>,
  // Number of frames already accumulated; used as the running-mean weight in `main`.
  frame_idx: f32,
  view: mat4x4<f32>,    
  // Camera-to-world; columns 0/1/2 give right, up and backward (-forward) axes.
  inverse_view: mat4x4<f32>,
  // Only [0][0] and [1][1] are read, to recover aspect ratio and vertical FOV.
  projection: mat4x4<f32>,
  sun_direction: vec3<f32>,
  // Half-angle of the sun cone, in radians (solid angle = 2*PI*(1 - cos(angle))).
  sun_angular_size: f32,
  // Not read in this file.
  sun_radiance: vec3<f32>,
  // Path samples per pixel per frame.
  sample_count: f32,
  // Maximum number of bounces per path.
  max_depth: f32,
  // Lens radius: scales the unit-disk lens sample in world units.
  aperture: f32,
  // Distance along the pinhole ray to the in-focus point.
  focus_distance: f32,
  // Number of valid entries in emissiveTriangleIndices.
  emissive_triangle_count: f32,
  // 0.0 selects a pinhole camera; any other value enables the thin-lens model.
  thin_lens: f32,
}

// Flattened BVH node (host-shared layout; do not reorder).
// primitive_count == 0: interior node whose two children are nodes[left_child] and
//   nodes[left_child + 1].
// primitive_count > 0: leaf covering tri_lut.primitive_indices[left_child ..
//   left_child + primitive_count).
struct Node {
    min_corner: vec3<f32>,
    left_child: f32,
    max_corner: vec3<f32>,
    primitive_count: f32,
}

// Wrapper giving the runtime-sized node array a struct type for the storage binding.
struct BVH {
    nodes: array<Node>,
}

// BVH leaf primitive list: triangle indices into objects.triangles, encoded as f32.
struct ObjectIndices {
    primitive_indices: array<f32>,
}

// All scene triangles.
struct Objects {
  triangles: array<Triangle>,
}

// Surface material (host-shared layout; do not reorder).
// Texture index fields hold a layer of `textures`; a value of -1 means "no texture"
// (code tests `> -1.0`).
struct Material {
    // Base color; multiplied by the base-color texture in parse_textures.
    albedo: vec4<f32>,  
    // Multiplied by the metallic-roughness texture's blue channel.
    metallic: f32,
    // alpha_mode / alpha_cutoff / double_sided are not read in this file
    // (presumably used by any_hit.wgsl — TODO confirm).
    alpha_mode: f32, 
    alpha_cutoff: f32, 
    double_sided: f32,
    // Emitted radiance; any positive component makes the surface a light.
    emission: vec3<f32>,
    // Multiplied by the metallic-roughness texture's green channel.
    roughness: f32,
    base_color_texture: f32,
    // Tangent-space normal map, applied in hit_triangle.
    normal_texture: f32,
    metallic_roughness_texture: f32,
    emissive_texture: f32,
}



// Determinant threshold below which hit_triangle treats a ray as parallel to the triangle.
const EPSILON :f32 = 0.00001f;
const PI :f32 = 3.1415927f;
// ray tracing gems part 1 chapter 6
// Constants for offset_ray: FLOAT_SCALE is the fixed offset used near the origin,
// INT_SCALE converts the normal into an integer bit-pattern offset, and ORIGIN is the
// per-component magnitude below which the fixed offset is used instead.
const FLOAT_SCALE = 1.0 / 65536.0;
const INT_SCALE = 256.0;
const ORIGIN = 1.0 / 32.0;

// Nudges a surface point `p` along normal `n` so rays spawned from it do not re-hit the
// same surface ("A Fast and Robust Method for Avoiding Self-Intersection", Ray Tracing
// Gems I, ch. 6).
// Components with |p| >= ORIGIN are offset in the integer domain of float_to_int /
// int_to_float (presumably bit reinterpretation, see utils.wgsl), so the offset grows with
// |p|. Components near the origin get a fixed FLOAT_SCALE * n offset instead.
fn offset_ray(p: vec3<f32>, n: vec3<f32>) -> vec3<f32> {
    // Integer step per axis, sign-flipped for negative coordinates so that adding it
    // always moves the coordinate in the direction of n.
    let step_i = vec3<i32>(INT_SCALE * n);
    let oriented_step = select(step_i, -step_i, p < vec3<f32>(0.0));

    let far_point = vec3<f32>(
        int_to_float(float_to_int(p.x) + oriented_step.x),
        int_to_float(float_to_int(p.y) + oriented_step.y),
        int_to_float(float_to_int(p.z) + oriented_step.z)
    );
    let near_point = p + FLOAT_SCALE * n;

    // Per component: the integer-domain offset away from the origin, the fixed one near it.
    return select(near_point, far_point, abs(p) >= vec3<f32>(ORIGIN));
}

// Samples material texture `texture_index` at `uv`.
// All textures share one texture_2d_array whose layers are sized to the largest texture;
// each texture occupies the sub-rectangle starting at (0, 0) with extent
// textureSizes[texture_index].xy (presumably in texels — must use the same units as
// textureDimensions). `uv` is remapped into that sub-rectangle.
fn sample_material_texture(uv: vec2<f32>, texture_index: u32) -> vec4<f32> {
    let tex_size = textureSizes[texture_index].xy;
    let max_tex_size = vec2<f32>(textureDimensions(textures).xy);
    // Upper corner of the valid region, as a fraction of the whole layer.
    let tex_uv_max = tex_size / max_tex_size;
    // Wrap first (REPEAT addressing, the glTF default). Without this, tiling UVs outside
    // [0, 1] are scaled past tex_uv_max and sample the layer's padding instead of the texture.
    // NOTE(review): bilinear filtering can still blend in padding texels at the edge of the
    // sub-rectangle.
    let mapped_uv = fract(uv) * tex_uv_max;
    // NOTE(review): this samples mip level 1, not the base level — confirm the blur is intended.
    return textureSampleLevel(textures, t_sampler, mapped_uv, texture_index, 1.0);
}


// Returns `curr_material` with its textures applied at the hit's UV (result.uv).
// Each texture scales the corresponding material factor, as in glTF 2.0's
// metallic-roughness model:
//   albedo    *= base color texture (rgba)
//   roughness *= metallic-roughness texture .g
//   metallic  *= metallic-roughness texture .b
//   emission  *= emissive texture (rgb)
// A texture index of -1 means the texture is absent and the factor is used unchanged.
fn parse_textures(curr_material: Material, result: HitInfo) -> Material {
    var material = curr_material;
    if material.base_color_texture > -1.0 {
        material.albedo *= sample_material_texture(result.uv, u32(material.base_color_texture));
    }
    if material.metallic_roughness_texture > -1.0 {
        let metallic_roughness = sample_material_texture(result.uv, u32(material.metallic_roughness_texture));
        material.roughness *= metallic_roughness.g;
        material.metallic *= metallic_roughness.b;
    }
    if material.emissive_texture > -1.0 {
        // Scale the emission factor rather than replacing it, like the other textures
        // (glTF: emissiveFactor * emissiveTexture). The emissive light sampling in shade_hit
        // reads the untextured factor, so a zero factor must also mean "no emission" here.
        material.emission *= sample_material_texture(result.uv, u32(material.emissive_texture)).rgb;
    }
    return material;
}


// Maps a 2D sample `u` in [0,1)^2 to a point in the unit disk using polar coordinates.
// u.x controls the radius (sqrt keeps the point density uniform over the disk's area) and
// u.y controls the angle.
fn point_in_unit_disk(u: vec2f) -> vec2f {
    let radius = sqrt(u.x);
    let angle = (2f * PI) * u.y;
    return vec2f(cos(angle), sin(angle)) * radius;
}

// Builds the primary ray through normalized device coordinate `ndc` (each axis in [-1, 1]).
// The camera basis comes from the columns of uniforms.inverse_view, and the FOV and aspect
// ratio are recovered from the projection matrix diagonal.
// If uniforms.thin_lens == 0.0 the ray starts at the camera position (pinhole). Otherwise the
// origin is jittered over a lens of radius uniforms.aperture using `noise` (a 2D sample in
// [0,1)^2), and the ray is aimed at the point uniforms.focus_distance along the pinhole
// direction, giving depth of field.
fn generate_pinhole_camera_ray(ndc: vec2<f32>, noise: vec2f) -> Ray {
    // projection[1][1] = 1 / tan(fov_y / 2); projection[0][0] = projection[1][1] / aspect.
    let tan_half_fov_y = 1.0 / uniforms.projection[1][1];
    let aspect = uniforms.projection[1][1] / uniforms.projection[0][0];

    let cam_right = uniforms.inverse_view[0].xyz;
    let cam_up = uniforms.inverse_view[1].xyz;
    let cam_forward = -uniforms.inverse_view[2].xyz;
    let eye = uniforms.position;

    let plane_x = ndc.x * aspect * tan_half_fov_y;
    let plane_y = ndc.y * tan_half_fov_y;
    let through_pixel = normalize(plane_x * cam_right + plane_y * cam_up + cam_forward);

    var ray: Ray;
    if (uniforms.thin_lens == 0.0) {
        ray.origin = eye;
        ray.direction = through_pixel;
        return ray;
    }

    // Thin lens: sample the aperture in the camera's right/up plane and aim every lens
    // sample at the same in-focus point.
    let focal_target = eye + through_pixel * uniforms.focus_distance;
    let lens = point_in_unit_disk(noise) * uniforms.aperture;
    ray.origin = eye + (lens.x * cam_right + lens.y * cam_up);
    ray.direction = normalize(focal_target - ray.origin);
    return ray;
}


// Compute entry point: one invocation per output pixel, in 16x16 workgroups.
// Builds a blue-noise-jittered primary ray, averages uniforms.sample_count path samples,
// folds the result into the running per-pixel mean kept in input_buffer (weighted by
// uniforms.frame_idx), and writes that mean to output_buffer.
@compute @workgroup_size(16, 16)
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>,
    @builtin(local_invocation_id) LocalInvocationID: vec3<u32>,
    @builtin(workgroup_id) GroupIndex: vec3<u32>) {
    // https://www.w3.org/TR/webgpu/#coordinate-systems
    let output_dimension: vec2<i32> = vec2<i32>(textureDimensions(output_buffer));
    let pixel_position: vec2<i32> = vec2<i32>(i32(GlobalInvocationID.x), i32(GlobalInvocationID.y));

    // When the image size is not a multiple of 16, the dispatch covers extra invocations.
    // Their pixel_idx would alias a real pixel on the next row and race with it on
    // input_buffer, so they must exit. textureStore alone would have discarded them.
    if (pixel_position.x >= output_dimension.x || pixel_position.y >= output_dimension.y) {
        return;
    }
    let pixel_idx: i32 = pixel_position.y * output_dimension.x + pixel_position.x;

    let pixel_center: vec2<f32> = vec2<f32>(pixel_position) + vec2f(0.5);
    let uv: vec2<f32> = pixel_center / vec2f(output_dimension);
    let ndc: vec2<f32> = uv * 2.0 - vec2f(1.0);

    let noise = animated_blue_noise(pixel_position, u32(uniforms.frame_idx), u32(64));
    var rnd_state = u32(0);
    init_random(&rnd_state, u32(uniforms.frame_idx));
    init_random(&rnd_state, u32(pixel_position.x));
    init_random(&rnd_state, u32(pixel_position.y));

    // Blue-noise sub-pixel jitter (instead of uniform_float). One pixel spans 2 / width in
    // NDC, so with jitter_scale = 1 the offset covers +-0.25 px around the pixel center.
    let jitter_scale: f32 = 1;
    let jitter_x: f32 = (noise.x - 0.5) / f32(output_dimension.x) * jitter_scale;
    let jitter_y: f32 = (noise.y - 0.5) / f32(output_dimension.y) * jitter_scale;
    let ray = generate_pinhole_camera_ray(vec2f(ndc.x + jitter_x, ndc.y + jitter_y), noise);

    // At least one sample, so the average below never divides by zero.
    let samples_per_pixel: i32 = max(i32(uniforms.sample_count), 1);
    var accumulated_color: vec3<f32> = vec3<f32>(0.0);
    for (var i: i32 = 0; i < samples_per_pixel; i++) {
        // Give every sample its own RNG stream. shade_hit takes the seed by value, so passing
        // rnd_state unchanged made all samples trace the identical path. This assumes
        // init_random mixes its argument into the existing state, as the calls above rely on.
        var sample_seed = rnd_state;
        init_random(&sample_seed, u32(i));
        var pixel_color: vec3<f32> = shade_hit(ray, sample_seed, noise);
        // Lazy NaN catching. NOTE(review): WGSL lets implementations assume NaNs are
        // absent, so a compiler may fold this self-comparison away.
        pixel_color = select(pixel_color, vec3f(0.0), pixel_color != pixel_color);
        accumulated_color += pixel_color;
    }
    accumulated_color = accumulated_color / f32(samples_per_pixel);

    // Running mean: frame_idx previous frames are already averaged into input_buffer.
    let prev_color: vec3<f32> = input_buffer[pixel_idx];
    let final_output: vec3f = (prev_color * uniforms.frame_idx + accumulated_color) / (uniforms.frame_idx + 1.0);
    input_buffer[pixel_idx] = final_output;
    textureStore(output_buffer, pixel_position, vec4f(final_output, 1.0));
}

// Finds the closest triangle hit along `ray` by stack-based BVH traversal from node 0.
// Only hits with distance in [0.001, 999.0] are reported (999.0 is the initial nearest
// distance; hit_aabb reports misses as 9999.0, so missed boxes are always culled).
// Returns a HitInfo with hit == false when nothing is hit.
fn trace(ray: Ray) -> HitInfo {
    var render_state: HitInfo;
    render_state.hit = false;
    var nearest_hit: f32 = 999.0;

    // Traversal state: current node plus a fixed-size stack of deferred far children.
    var node: Node = node_tree.nodes[0];
    var stack: array<Node, 32>;
    var stack_location: i32 = 0;

    while true {
        let primitive_count: u32 = u32(node.primitive_count);
        let contents: u32 = u32(node.left_child);

        if primitive_count == 0u {
            // Interior node: its children are stored adjacently at `contents`.
            var child1: Node = node_tree.nodes[contents];
            var child2: Node = node_tree.nodes[contents + 1u];
            var distance1: f32 = hit_aabb(ray, child1);
            var distance2: f32 = hit_aabb(ray, child2);

            // Visit the nearer child first.
            if distance1 > distance2 {
                let temp_dist = distance1;
                distance1 = distance2;
                distance2 = temp_dist;

                let temp_child = child1;
                child1 = child2;
                child2 = temp_child;
            }

            if distance1 > nearest_hit {
                // Neither child can contain a closer hit.
                if stack_location == 0 {
                    break;
                }
                stack_location -= 1;
                node = stack[stack_location];
            } else {
                node = child1;
                // Defer the far child only if its box can still hold a closer hit. The old test
                // used distance1, which is always true here, so missed boxes got pushed too.
                // NOTE(review): a full stack drops the child (geometry may be missed); the old
                // code wrote out of bounds in that case.
                if distance2 < nearest_hit && stack_location < 32 {
                    stack[stack_location] = child2;
                    stack_location += 1;
                }
            }
        } else {
            // Leaf: test its triangles and shrink the search interval on every hit.
            for (var i: u32 = 0u; i < primitive_count; i++) {
                let new_render_state: HitInfo = hit_triangle(
                    ray,
                    objects.triangles[u32(tri_lut.primitive_indices[i + contents])],
                    0.001,
                    nearest_hit,
                    render_state,
                );
                if new_render_state.hit {
                    nearest_hit = new_render_state.dist;
                    render_state = new_render_state;
                }
            }
            if stack_location == 0 {
                break;
            }
            stack_location -= 1;
            node = stack[stack_location];
        }
    }
    return render_state;
}

// Solid-angle pdf with which shade_hit's emissive next-event estimation picks a direction
// `dir` (assumed normalized) reaching a point on `tri` at distance `dist`. That estimator
// chooses one of `light_count` triangles uniformly and then a uniform point on it, so
// pdf = dist^2 / (light_count * area * cos_light). Returns 0 for points seen from behind
// (relative to the triangle's winding normal) and for degenerate triangles, which the
// estimator never samples.
fn triangle_light_pdf(tri: Triangle, dir: vec3f, dist: f32, light_count: f32) -> f32 {
    let face = cross(tri.corner_b - tri.corner_a, tri.corner_c - tri.corner_a);
    let double_area = length(face);
    if (double_area <= 0.0) {
        return 0.0;
    }
    let cos_light = dot(face / double_area, -dir);
    if (cos_light <= 0.0) {
        return 0.0;
    }
    return (dist * dist) / (light_count * 0.5 * double_area * cos_light);
}

// Traces one path starting at `ray` and returns its radiance estimate.
// Per bounce:
//   - on a miss, add sky radiance and stop;
//   - otherwise add surface emission (MIS-weighted against emissive NEE);
//   - sample the sun and one emissive triangle (next-event estimation);
//   - apply Russian roulette after bounce 2;
//   - sample the next direction from a 50/50 mix of a GGX specular lobe and a
//     cosine-weighted diffuse lobe.
// `seed` initialises this path's RNG stream (advanced with uniform_float).
// `noise` is not used here; it is kept for interface compatibility.
fn shade_hit(ray: Ray, seed: u32, noise: vec2f) -> vec3<f32> {
    var current_seed = seed;
    var radiance = vec3f(0.0);
    var throughput = vec3f(1.0);
    var result: HitInfo;

    var temp_ray = ray;
    let bounces: u32 = u32(uniforms.max_depth);

    // Solid-angle pdf of the mixture-BSDF sample that produced temp_ray. It is set at the
    // end of each bounce and only read from bounce 1 onwards.
    var pdf: f32 = 0.0;

    // The sun is sampled uniformly in a cone of half-angle sun_angular_size.
    let sun_solid_angle = 2.0 * PI * (1.0 - cos(uniforms.sun_angular_size));
    let sun_pdf = 1.0 / sun_solid_angle;
    let sky_pdf = 1.0 / PI;

    let light_count = uniforms.emissive_triangle_count;
    let has_lights = light_count >= 1.0;

    for (var bounce: u32 = 0u; bounce < bounces; bounce++) {
        result = trace(temp_ray);
        if (!result.hit) {
            // We hit the environment. The sun disk is skipped here, at least until the
            // temporal accumulation is less rudimentary; it only arrives via sun NEE below.
            // let to_sun = dot(temp_ray.direction, uniforms.sun_direction) > cos(uniforms.sun_angular_size);
            // let sun_radiance = sun_glow(temp_ray.direction, uniforms.sun_direction);
            // if (to_sun) {
            //  radianceOut += sun_radiance;
            // }
            // if (to_sun) {
            //  env_pdf_eval = 0.5 * sun_pdf;
            // }
            let view_zenith = abs(temp_ray.direction.z);
            let extinction = exp(-2.0 * pow(1.0 - view_zenith, 3.0));
            let sky_radiance = sky_glow(temp_ray.direction, uniforms.sun_direction) * extinction;
            if (bounce == 0u) {
                radiance += throughput * sky_radiance;
                break;
            }
            // BSDF-generated ray: its pdf from the previous bounce drives the MIS weight.
            // NOTE(review): no strategy samples the sky directly, so this weight is < 1 and
            // drops sky energy; confirm the intended env sampling before changing it.
            let env_pdf_eval = 0.5 * sky_pdf;
            let env_mis_weight = pdf / (pdf + env_pdf_eval);
            radiance += clamp_hdr(throughput * sky_radiance * env_mis_weight, 10.0);
            break;
        }

        let view_dir = -temp_ray.direction;
        var material: Material = parse_textures(materials[result.material_idx], result);

        // Emission reached by this path.
        if (any(material.emission > vec3f(0.0))) {
            var emission_weight = 1.0;
            // From bounce 1 on, the previous vertex's emissive NEE could also have produced
            // this contribution. Use the balance heuristic so the two estimators sum to one
            // instead of double counting. NEE only uses the untextured emission factor and
            // only front faces (triangle_light_pdf is 0 otherwise), so keep full weight when
            // it could not have sampled this emitter.
            let raw_emission = materials[result.material_idx].emission;
            if (bounce > 0u && has_lights && any(raw_emission > vec3f(0.0))) {
                let light_pdf_here = triangle_light_pdf(result.tri, temp_ray.direction, result.dist, light_count);
                emission_weight = pdf / (pdf + light_pdf_here);
            }
            radiance += throughput * material.emission * emission_weight;
        }

        // Sun NEE, MIS-weighted against the mixture BSDF.
        let sun_rand = vec2f(uniform_float(&current_seed), uniform_float(&current_seed));
        let env_dir = sample_sun_cone_dir(sun_rand);
        let env_color = sun_glow(env_dir, uniforms.sun_direction);
        let n_dot_env = dot(result.normal, env_dir);
        if (n_dot_env > 0.0 && !is_occluded(result.position, result.geo_normal, env_dir, 99999.9)) {
            let env_brdf = eval_brdf(result.normal, view_dir, env_dir, material);
            let diffuse_density = cosine_pdf(result.normal, env_dir);
            let specular_density = ggx_pdf(view_dir, result.normal, normalize(view_dir + env_dir), material.roughness);
            let bsdf_pdf = 0.5 * specular_density + 0.5 * diffuse_density;
            let weight = sun_pdf / (sun_pdf + bsdf_pdf);
            radiance += clamp_hdr(throughput * env_brdf * env_color * n_dot_env * weight / sun_pdf, 10.0);
        }

        // Emissive-triangle NEE: pick one listed triangle uniformly, then a uniform point on it.
        // Uses its own random numbers, so the chosen light does not constrain the point.
        // TODO: better light selection (e.g. power-based).
        if (has_lights) {
            let light_rand = vec3f(uniform_float(&current_seed), uniform_float(&current_seed), uniform_float(&current_seed));
            let light_index = min(u32(floor(light_rand.x * light_count)), u32(light_count) - 1u);
            let tri = objects.triangles[u32(emissiveTriangleIndices[light_index])];
            // The sqrt warp makes the barycentric sample uniform over the triangle's area.
            let su = sqrt(light_rand.y);
            let bary_a = 1.0 - su;
            let bary_b = su * (1.0 - light_rand.z);
            let bary_c = su * light_rand.z;
            let light_pos = bary_a * tri.corner_a + bary_b * tri.corner_b + bary_c * tri.corner_c;
            let to_light = light_pos - result.position;
            let dist2 = dot(to_light, to_light);
            if (dist2 > EPSILON) {
                let dist = sqrt(dist2);
                let light_dir = to_light / dist;
                let cos_surf = dot(result.normal, light_dir);
                let light_pdf = triangle_light_pdf(tri, light_dir, dist, light_count);
                // Shorten the shadow ray slightly so the light triangle itself
                // (at t ~= dist) is not reported as the occluder.
                if (cos_surf > 0.0 && light_pdf > 0.0 && !is_occluded(result.position, result.geo_normal, light_dir, dist * 0.999)) {
                    // NOTE(review): uses the untextured emission factor of the light's material.
                    let light_emission = materials[i32(tri.material_idx)].emission;
                    let light_brdf = eval_brdf(result.normal, view_dir, light_dir, material);
                    let diffuse_pdf = cosine_pdf(result.normal, light_dir);
                    let specular_pdf = ggx_pdf(view_dir, result.normal, normalize(view_dir + light_dir), material.roughness);
                    let bsdf_pdf = 0.5 * diffuse_pdf + 0.5 * specular_pdf;
                    let weight = light_pdf / (light_pdf + bsdf_pdf);
                    // Same estimator form as the sun term: f * Le * cos / pdf (solid angle).
                    let contrib = throughput * light_brdf * light_emission * cos_surf * weight / light_pdf;
                    radiance += clamp_hdr(contrib, 10.0);
                }
            }
        }

        // Russian roulette.
        if (bounce > 2u) {
            let rr_probability = min(0.9, luminance(throughput));
            if (rr_probability < uniform_float(&current_seed)) {
                break;
            }
            throughput /= rr_probability;
        }

        // Sample the next direction from the 50/50 specular/diffuse mixture.
        let lobe_rand = uniform_float(&current_seed);
        let bsdf_rand = vec2f(uniform_float(&current_seed), uniform_float(&current_seed));
        var new_dir: vec3<f32>;
        if (lobe_rand < 0.5) {
            new_dir = ggx_specular_sample(view_dir, result.normal, bsdf_rand, material.roughness);
        } else {
            new_dir = cosine_hemisphere_sample(result.normal, bsdf_rand);
        }
        let n_dot_l = dot(result.normal, new_dir);
        if (n_dot_l <= 0.0) { break; }
        let specular_density = ggx_pdf(view_dir, result.normal, normalize(view_dir + new_dir), material.roughness);
        let diffuse_density = cosine_pdf(result.normal, normalize(new_dir));
        pdf = 0.5 * specular_density + 0.5 * diffuse_density;

        let indirect_brdf = eval_brdf(result.normal, view_dir, new_dir, material);
        throughput *= (indirect_brdf * n_dot_l) / pdf;

        temp_ray.origin = offset_ray(result.position, result.geo_normal);
        temp_ray.direction = new_dir;
    }

    return radiance;
}

// Möller–Trumbore ray/triangle intersection, accepting hits with distance in
// [dist_min, dist_max].
// On a hit it fills:
//   - position, nudged by a shadow-terminator offset;
//   - the interpolated shading normal, optionally perturbed by the material's normal map;
//   - the geometric normal;
//   - UV, tangent frame and material index.
// Both normals are flipped to face the incoming ray. Otherwise returns hit.hit == false.
// `prevRay` is unused.
fn hit_triangle(ray: Ray, tri: Triangle, dist_min: f32, dist_max: f32, prevRay: HitInfo) -> HitInfo {
    var hit: HitInfo;
    hit.hit = false;

    let edge1 = tri.corner_b - tri.corner_a;
    let edge2 = tri.corner_c - tri.corner_a;

    let pvec = cross(ray.direction, edge2);
    let determinant = dot(edge1, pvec);

    // Reject nearly parallel rays.
    if abs(determinant) < EPSILON {
        return hit;
    }

    let inv_det = 1.0 / determinant;
    let tvec = ray.origin - tri.corner_a;

    // Barycentric coordinate u (weight of corner_b).
    let u = dot(tvec, pvec) * inv_det;
    if (u < 0.0 || u > 1.0) {
        return hit;
    }

    // Barycentric coordinate v (weight of corner_c).
    let qvec = cross(tvec, edge1);
    let v = dot(ray.direction, qvec) * inv_det;
    if (v < 0.0 || (u + v) > 1.0) {
        return hit;
    }

    // Ray parameter (distance along a normalized direction).
    let dist = dot(edge2, qvec) * inv_det;
    if (dist < dist_min || dist > dist_max) {
        return hit;
    }

    // Valid hit. w is the barycentric weight of corner_a.
    let w = 1.0 - u - v;
    hit.hit = true;
    hit.dist = dist;
    hit.position = ray.origin + ray.direction * dist;
    hit.tri = tri;
    hit.material_idx = i32(tri.material_idx);

    let geo_normal = normalize(cross(edge1, edge2));
    let shading_normal = normalize(w * tri.normal_a + u * tri.normal_b + v * tri.normal_c);
    // Keep the interpolated tangent un-normalized: normalizing the whole vec4 also shrank
    // w, which made the bitangent below shorter than unit length.
    let tangent = w * tri.tangent_a + u * tri.tangent_b + v * tri.tangent_c;

    // Shadow terminator fix: project the hit point onto the tangent planes defined by the
    // vertex normals (only where it lies below them) and blend barycentrically, lifting the
    // point off the flat triangle.
    let tmpu = hit.position - tri.corner_a;
    let tmpv = hit.position - tri.corner_b;
    let tmpw = hit.position - tri.corner_c;

    let dotu = min(0.0, dot(tmpu, tri.normal_a));
    let dotv = min(0.0, dot(tmpv, tri.normal_b));
    let dotw = min(0.0, dot(tmpw, tri.normal_c));

    let pu = tmpu - dotu * tri.normal_a;
    let pv = tmpv - dotv * tri.normal_b;
    let pw = tmpw - dotw * tri.normal_c;

    let warped_offset = w * pu + u * pv + v * pw;
    hit.position = hit.position + warped_offset;

    // TBN frame. Only the sign of the interpolated w is used, so B stays unit length.
    let T = normalize(tangent.xyz);
    let N = normalize(shading_normal);
    let handedness = select(-1.0, 1.0, tangent.w >= 0.0);
    let B = normalize(cross(N, T)) * handedness;

    hit.tangent = cross(B, N);
    hit.bitangent = B;
    hit.normal = shading_normal;
    // The face normal must be stored: callers pass it to offset_ray / is_occluded, and an
    // unset (zero) geo_normal disables the self-intersection offset.
    hit.geo_normal = geo_normal;
    hit.uv = w * tri.uv_a + u * tri.uv_b + v * tri.uv_c;

    // If a normal map is present, perturb the shading normal.
    let material = materials[i32(tri.material_idx)];
    if (material.normal_texture > -1.0) {
        let normal_map = sample_material_texture(hit.uv, u32(material.normal_texture));
        var normalized_map = normalize(normal_map * 2.0 - 1.0);
        // NOTE(review): flips green (DirectX-style maps); glTF normal maps are +Y up — confirm.
        normalized_map.y = -normalized_map.y;
        hit.normal = normalize(
            normalized_map.x * T +
            normalized_map.y * B +
            normalized_map.z * N
        );
    }

    // Back-face hit: orient both normals against the incoming ray.
    if (dot(ray.direction, geo_normal) > 0.0) {
        hit.geo_normal = -hit.geo_normal;
        hit.normal = -hit.normal;
    }
    return hit;
}

// Slab test of `ray` against `node`'s bounding box.
// Returns the entry distance along the ray, which is negative when the origin is inside
// the box, or 9999.0 on a miss. `trace` relies on the miss value exceeding its initial
// nearest-hit distance (999.0).
fn hit_aabb(ray: Ray, node: Node) -> f32 {
    // Axis-parallel directions divide by zero here and rely on the GPU producing +-inf.
    // NOTE(review): WGSL allows implementations to assume infinities are absent.
    let reciprocal = vec3f(1.0) / ray.direction;
    let t_near = (node.min_corner - ray.origin) * reciprocal;
    let t_far = (node.max_corner - ray.origin) * reciprocal;
    let t_min = min(t_near, t_far);
    let t_max = max(t_near, t_far);

    let t_entry = max(max(t_min.x, t_min.y), t_min.z);
    let t_exit = min(min(t_max.x, t_max.y), t_max.z);

    // Miss when the slabs do not overlap or the whole box is behind the origin.
    if t_entry > t_exit || t_exit < 0.0 {
        return 9999.0;
    }
    return t_entry;
}