// src/bvh.ts
// BVH construction for a WebGPU-based path tracer.

import { Vec3 } from "gl-matrix";

// Sentinel magnitude for "empty" boxes: box minimums start at +MAX_BOUND and
// maximums at -MAX_BOUND, and AABB.growAABB treats a box whose bmin x still
// equals MAX_BOUND as empty.
// NOTE(review): this presumably assumes all scene coordinates lie strictly
// within ±MAX_BOUND — confirm against the scene loader.
const MAX_BOUND = 999999;

/**
 * A single triangle as consumed by the BVH builder.
 */
interface Triangle {
  // The three corners; all of them are used when growing node and bin bounds.
  cornerA: Vec3;
  cornerB: Vec3;
  cornerC: Vec3;
  // Representative point used to bin and partition the triangle.
  // Presumably the mean of the three corners, precomputed by the caller — TODO confirm.
  centroid: Vec3;
}

/**
 * One bucket of the binned surface-area-heuristic evaluation in
 * BVH.findBestPlane.
 */
interface Bin {
  // Number of triangles whose centroid fell into this bin.
  instanceCount: number;
  // Box enclosing all corners of the triangles assigned to this bin.
  bounds: AABB;
}

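// Conceptual per-node layout (32 bytes), kept for documentation only:
// the BVH class stores these fields in separate parallel arrays.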
// interface BVHNode { 
//   min: Vec3; // 12 
//   max: Vec3; // 12 
//   left: number; // 4 
//   instanceCount: number; // 4 
// }

/**
 * Axis-aligned bounding box described by its minimum and maximum corner.
 *
 * A fresh box is "empty": its minimum sits at +MAX_BOUND and its maximum at
 * -MAX_BOUND on every axis, so the first point grown into it defines both.
 */
class AABB {
  bmin: Vec3;
  bmax: Vec3;

  constructor() {
    const hi = MAX_BOUND;
    const lo = -MAX_BOUND;
    this.bmin = Vec3.fromValues(hi, hi, hi);
    this.bmax = Vec3.fromValues(lo, lo, lo);
  }

  /** Expands the box, in place, so that it contains the point `p`. */
  grow(p: Vec3) {
    Vec3.min(this.bmin, this.bmin, p);
    Vec3.max(this.bmax, this.bmax, p);
  }

  /**
   * Expands the box so that it contains `aabb`.
   *
   * A box whose minimum x still equals the MAX_BOUND sentinel is treated as
   * empty and ignored.
   */
  growAABB(aabb: AABB) {
    const otherIsEmpty = aabb.bmin[0] === MAX_BOUND;
    if (otherIsEmpty) return;

    this.grow(aabb.bmin);
    this.grow(aabb.bmax);
  }

  /**
   * Returns half of the box's surface area (xy + yz + zx of its extent).
   * The factor 2 is dropped because the value is only used for relative
   * SAH cost comparisons.
   */
  area() {
    const extent = Vec3.create();
    Vec3.subtract(extent, this.bmax, this.bmin);

    const xy = extent[0] * extent[1];
    const yz = extent[1] * extent[2];
    const zx = extent[2] * extent[0];
    return xy + yz + zx;
  }
}

/**
 * Bounding volume hierarchy over a list of triangles, built top-down with a
 * binned surface-area heuristic (SAH) and a median-split fallback.
 *
 * Nodes are stored as parallel arrays indexed by node id (node 0 is the root):
 *  - `nodesMin` / `nodesMax`: three floats per node (x, y, z).
 *  - `nodesInstanceCount`: number of triangles in a leaf; 0 marks an internal node.
 *  - `nodesLeft`: for a leaf, the first slot of its range in `triIdx`;
 *    for an internal node, the id of its left child (the right child is
 *    always `left + 1`).
 *
 * `triIdx` is a permutation of triangle indices; every leaf owns a contiguous
 * range of it. The `triangles` array itself is never reordered.
 *
 * Usage: `const bvh = new BVH(tris); bvh.construct();`
 */
class BVH {
  nodesMin: Float32Array; // min x,y,z for each node
  nodesMax: Float32Array; // max x,y,z for each node
  nodesLeft: Uint32Array; // leaf: first slot in triIdx; internal: left child id
  nodesInstanceCount: Uint32Array; // leaf: triangle count; internal: 0
  nodesUsed: number; // number of node slots handed out so far
  triangles: Triangle[];
  triIdx: Uint32Array; // triangle indices, grouped per leaf

  /**
   * Allocates storage for the hierarchy; call `construct()` to build it.
   *
   * @param triangles Triangles to index. An empty list is allowed and yields
   *   a hierarchy with storage for the root node only.
   */
  constructor(triangles: Triangle[]) {
    this.triangles = triangles;
    this.triIdx = new Uint32Array(triangles.length);
    this.nodesUsed = 0;
    // Every leaf holds at least one triangle, so a tree over n triangles has
    // at most 2n - 1 nodes. Always reserve the root slot: for n = 0 the
    // formula gives -1, which would make the typed-array constructors throw.
    const maxNodes = Math.max(1, 2 * triangles.length - 1);
    this.nodesMin = new Float32Array(maxNodes * 3);
    this.nodesMax = new Float32Array(maxNodes * 3);
    this.nodesLeft = new Uint32Array(maxNodes);
    this.nodesInstanceCount = new Uint32Array(maxNodes);
  }

  /**
   * Builds the hierarchy from scratch: resets `triIdx` to the identity
   * permutation, puts every triangle into the root and splits recursively.
   *
   * NOTE(review): for an empty triangle list the root keeps a count of 0 and
   * inverted (±MAX_BOUND) bounds — confirm the traversal code rejects it by
   * its bounds rather than treating it as an internal node.
   */
  construct() {
    const triCount = this.triangles.length;
    for (let i = 0; i < triCount; i++) {
      this.triIdx[i] = i;
    }
    this.nodesInstanceCount[0] = triCount;
    this.nodesLeft[0] = 0; // root node
    this.nodesUsed = 1;
    this.bounding(0);
    this.subdivide(0);
  }

  /**
   * Recomputes the bounds of a leaf node from the corners of its triangles
   * and stores them in `nodesMin` / `nodesMax`.
   *
   * Must be called while the node is still a leaf, i.e. while `nodesLeft`
   * holds its start slot in `triIdx`. A node without triangles ends up with
   * the inverted ±MAX_BOUND box.
   *
   * @param nodeIdx Id of the node to update.
   */
  bounding(nodeIdx: number) {
    const off = nodeIdx * 3;
    const count = this.nodesInstanceCount[nodeIdx];
    const start = this.nodesLeft[nodeIdx];

    // Accumulate in locals and write back once; rounding to float32 happens
    // on the final store, which yields the same values as rounding per step.
    let minX = MAX_BOUND;
    let minY = MAX_BOUND;
    let minZ = MAX_BOUND;
    let maxX = -MAX_BOUND;
    let maxY = -MAX_BOUND;
    let maxZ = -MAX_BOUND;

    for (let i = 0; i < count; i++) {
      const tri = this.triangles[this.triIdx[start + i]];
      const a = tri.cornerA;
      const b = tri.cornerB;
      const c = tri.cornerC;

      minX = Math.min(minX, a[0], b[0], c[0]);
      minY = Math.min(minY, a[1], b[1], c[1]);
      minZ = Math.min(minZ, a[2], b[2], c[2]);
      maxX = Math.max(maxX, a[0], b[0], c[0]);
      maxY = Math.max(maxY, a[1], b[1], c[1]);
      maxZ = Math.max(maxZ, a[2], b[2], c[2]);
    }

    this.nodesMin[off + 0] = minX;
    this.nodesMin[off + 1] = minY;
    this.nodesMin[off + 2] = minZ;
    this.nodesMax[off + 0] = maxX;
    this.nodesMax[off + 1] = maxY;
    this.nodesMax[off + 2] = maxZ;
  }

  /**
   * Recursively splits a leaf node into two children.
   *
   * The split plane comes from `findBestPlane`; if its SAH cost is not lower
   * than the cost of keeping the node as a leaf, a median split along the
   * node's longest axis is used instead. Nodes with two or fewer triangles
   * are never split, and a node whose partition leaves one side empty stays
   * a leaf as well.
   *
   * @param nodeIdx Id of the leaf node to split.
   */
  subdivide(nodeIdx: number) {
    const count = this.nodesInstanceCount[nodeIdx];
    // not enough primitives
    if (count <= 2) return;
    const start = this.nodesLeft[nodeIdx];

    const [sahSplit, sahAxis, cost] = this.findBestPlane(nodeIdx);
    let split = sahSplit;
    let axis = sahAxis;

    // Cost of leaving the node unsplit: triangle count times (half) surface area.
    const off = nodeIdx * 3;
    const extent = Vec3.create();
    Vec3.subtract(
      extent,
      [this.nodesMax[off + 0], this.nodesMax[off + 1], this.nodesMax[off + 2]],
      [this.nodesMin[off + 0], this.nodesMin[off + 1], this.nodesMin[off + 2]]
    );
    const parentArea = extent[0] * extent[1] + extent[1] * extent[2] + extent[2] * extent[0];
    const parentCost = count * parentArea;

    // Fall back to a median split if SAH found nothing better. This is also
    // the path taken when findBestPlane reports an infinite cost.
    if (cost >= parentCost) {
      axis = BVH.longestAxis(extent[0], extent[1], extent[2]);
      split = this.medianCentroid(start, count, axis);
    }

    // Partition the node's range of triIdx in place: centroids strictly below
    // the split go left, everything else is swapped to the back.
    let i = start;
    let j = i + count - 1;
    while (i <= j) {
      const tri = this.triangles[this.triIdx[i]];
      if (tri.centroid[axis] < split) {
        i++;
      } else {
        const tmp = this.triIdx[i];
        this.triIdx[i] = this.triIdx[j];
        this.triIdx[j] = tmp;
        j--;
      }
    }
    const leftCount = i - start;
    // One side is empty: splitting would not make progress, keep the leaf.
    if (leftCount === 0 || leftCount === count) return;

    // Children always occupy two consecutive node slots.
    const leftIdx = this.nodesUsed++;
    const rightIdx = this.nodesUsed++;

    this.nodesLeft[leftIdx] = start;
    this.nodesInstanceCount[leftIdx] = leftCount;
    this.nodesLeft[rightIdx] = i;
    this.nodesInstanceCount[rightIdx] = count - leftCount;

    // Turn this node into an internal node.
    this.nodesLeft[nodeIdx] = leftIdx;
    this.nodesInstanceCount[nodeIdx] = 0;

    // keep going
    this.bounding(leftIdx);
    this.bounding(rightIdx);
    this.subdivide(leftIdx);
    this.subdivide(rightIdx);
  }

  /**
   * Searches for the cheapest axis-aligned split plane of a leaf node using
   * binned SAH over the triangle centroids.
   *
   * If the centroids span less than 1e-5 along any axis, binning is skipped
   * and a median split along the longest axis of the triangles' actual bounds
   * is returned with a cost of `Infinity`.
   *
   * @param nodeIdx Id of the leaf node to analyse.
   * @returns `[split, axis, cost]`: plane position, axis (0..2, or -1 if no
   *   candidate was accepted) and the SAH cost of that plane.
   */
  findBestPlane(nodeIdx: number): [number, number, number] {
    let bestAxis = -1;
    let bestSplit = 0;
    let bestCost = Infinity;

    const count = this.nodesInstanceCount[nodeIdx];
    const start = this.nodesLeft[nodeIdx];

    // Bounds of the centroids (not of the triangles).
    const centroidMin = [Infinity, Infinity, Infinity];
    const centroidMax = [-Infinity, -Infinity, -Infinity];
    for (let i = 0; i < count; i++) {
      const tri = this.triangles[this.triIdx[start + i]];
      for (let axis = 0; axis < 3; axis++) {
        centroidMin[axis] = Math.min(centroidMin[axis], tri.centroid[axis]);
        centroidMax[axis] = Math.max(centroidMax[axis], tri.centroid[axis]);
      }
    }

    // Fallback for degenerate centroid distributions (flat along some axis).
    const EPSILON = 1e-5;
    let degenerate = false;
    for (let axis = 0; axis < 3; axis++) {
      if (Math.abs(centroidMax[axis] - centroidMin[axis]) < EPSILON) {
        degenerate = true;
        break;
      }
    }
    if (degenerate) {
      // Median split along the longest axis of the actual triangle bounds.
      const actualMin = [Infinity, Infinity, Infinity];
      const actualMax = [-Infinity, -Infinity, -Infinity];
      for (let i = 0; i < count; i++) {
        const tri = this.triangles[this.triIdx[start + i]];
        for (let axis = 0; axis < 3; axis++) {
          actualMin[axis] = Math.min(actualMin[axis], tri.cornerA[axis], tri.cornerB[axis], tri.cornerC[axis]);
          actualMax[axis] = Math.max(actualMax[axis], tri.cornerA[axis], tri.cornerB[axis], tri.cornerC[axis]);
        }
      }
      bestAxis = BVH.longestAxis(
        actualMax[0] - actualMin[0],
        actualMax[1] - actualMin[1],
        actualMax[2] - actualMin[2]
      );
      bestSplit = this.medianCentroid(start, count, bestAxis);
      // bestCost is still Infinity here, which tells the caller that no SAH
      // evaluation took place.
      return [bestSplit, bestAxis, bestCost];
    }

    // Adaptive bin count: one bin per primitive, clamped to [8, 32].
    const BINS = Math.max(8, Math.min(32, count));
    // Small epsilon added to every candidate position to help break ties.
    const jitter = 1e-6;

    // Prefix/suffix sums per candidate plane; entry i describes the plane
    // between bin i and bin i + 1. Every entry is rewritten for each axis.
    const leftCount = new Float64Array(BINS - 1);
    const leftArea = new Float64Array(BINS - 1);
    const rightCount = new Float64Array(BINS - 1);
    const rightArea = new Float64Array(BINS - 1);

    for (let axis = 0; axis < 3; axis++) {
      const axisMin = centroidMin[axis];
      const axisMax = centroidMax[axis];
      // Guards the division below; the epsilon test above already rejects
      // finite zero-width axes, so this only triggers for non-finite input.
      if (axisMax === axisMin) continue;

      const axisScale = BINS / (axisMax - axisMin);

      // Fresh, empty bins for this axis.
      const bins: Bin[] = Array.from({ length: BINS }, () => ({
        instanceCount: 0,
        bounds: new AABB(),
      }));

      // Distribute primitives into bins by centroid position.
      for (let i = 0; i < count; i++) {
        const tri = this.triangles[this.triIdx[start + i]];
        let binIDX = Math.floor((tri.centroid[axis] - axisMin) * axisScale);
        // The centroid at axisMax maps to BINS; fold it into the last bin.
        if (binIDX >= BINS) binIDX = BINS - 1;
        const bin = bins[binIDX];
        bin.instanceCount++;
        bin.bounds.grow(tri.cornerA);
        bin.bounds.grow(tri.cornerB);
        bin.bounds.grow(tri.cornerC);
      }

      // Sweep from the left: count and area of everything left of each plane.
      const leftBox = new AABB();
      for (let i = 0, sum = 0; i < BINS - 1; i++) {
        sum += bins[i].instanceCount;
        leftCount[i] = sum;
        leftBox.growAABB(bins[i].bounds);
        leftArea[i] = leftBox.area();
      }
      // Sweep from the right: the same for everything right of each plane.
      const rightBox = new AABB();
      for (let i = BINS - 1, sum = 0; i > 0; i--) {
        sum += bins[i].instanceCount;
        rightCount[i - 1] = sum;
        rightBox.growAABB(bins[i].bounds);
        rightArea[i - 1] = rightBox.area();
      }

      // Evaluate the BINS - 1 candidate planes on this axis.
      const binWidth = (axisMax - axisMin) / BINS;
      for (let i = 0; i < BINS - 1; i++) {
        const candidatePos = axisMin + binWidth * (i + 1) + jitter;
        const cost = leftCount[i] * leftArea[i] + rightCount[i] * rightArea[i];
        if (cost < bestCost) {
          bestCost = cost;
          bestAxis = axis;
          bestSplit = candidatePos;
        }
      }
    }
    return [bestSplit, bestAxis, bestCost];
  }

  /** Returns the index (0, 1 or 2) of the largest extent; ties pick the lower axis. */
  private static longestAxis(ex: number, ey: number, ez: number): number {
    let axis = 0;
    let longest = ex;
    if (ey > longest) {
      axis = 1;
      longest = ey;
    }
    if (ez > longest) {
      axis = 2;
    }
    return axis;
  }

  /**
   * Returns the upper median of the centroid coordinates along `axis` for the
   * `count` triangles referenced from `triIdx[start]` onwards.
   */
  private medianCentroid(start: number, count: number, axis: number): number {
    const centroids: number[] = [];
    for (let i = 0; i < count; i++) {
      centroids.push(this.triangles[this.triIdx[start + i]].centroid[axis]);
    }
    centroids.sort((a, b) => a - b);
    return centroids[Math.floor(count / 2)];
  }
}

export default BVH;