webgpu-pt

monte carlo path tracer
Contents

bvh.ts

11 kB
  1import { Vec3 } from "gl-matrix";
  2
  3const MAX_BOUND = 999999;
  4
  5interface Triangle {
  6  cornerA: Vec3;
  7  cornerB: Vec3;
  8  cornerC: Vec3;
  9  centroid: Vec3;
 10}
 11
 12interface Bin {
 13  instanceCount: number;
 14  bounds: AABB;
 15}
 16
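// Documentation only (not used by the code): the logical layout of one BVH node,
// 32 bytes in total. The BVH class below stores these fields in parallel typed
// arrays (nodesMin, nodesMax, nodesLeft, nodesInstanceCount) instead of objects.
// interface BVHNode {
//   min: Vec3;             // 12 bytes: AABB minimum corner
//   max: Vec3;             // 12 bytes: AABB maximum corner
//   left: number;          //  4 bytes: leaf -> first slot in triIdx; internal -> left child index (right = left + 1)
//   instanceCount: number; //  4 bytes: triangle count of a leaf; 0 marks an internal node
// }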
 24
 25class AABB {
 26  bmin: Vec3;
 27  bmax: Vec3;
 28  constructor() {
 29    this.bmin = Vec3.fromValues(MAX_BOUND, MAX_BOUND, MAX_BOUND);
 30    this.bmax = Vec3.fromValues(-MAX_BOUND, -MAX_BOUND, -MAX_BOUND);
 31  }
 32
 33  grow(p: Vec3) {
 34    Vec3.min(this.bmin, this.bmin, p);
 35    Vec3.max(this.bmax, this.bmax, p);
 36  }
 37
 38  growAABB(aabb: AABB) {
 39    // Only grow if the other AABB is valid.
 40    if (aabb.bmin[0] !== MAX_BOUND) {
 41      this.grow(aabb.bmin);
 42      this.grow(aabb.bmax);
 43    }
 44  }
 45
 46  area() {
 47    const e = Vec3.create();
 48    Vec3.subtract(e, this.bmax, this.bmin);
 49    // standard surface area measure (omitting the factor 2 is acceptable since SAH is relative)
 50    return e[0] * e[1] + e[1] * e[2] + e[2] * e[0];
 51  }
 52}
 53
 54class BVH {
 55  nodesMin: Float32Array; // min x,y,z for each node
 56  nodesMax: Float32Array; // max x,y,z for each node
 57  nodesLeft: Uint32Array;
 58  nodesInstanceCount: Uint32Array;
 59  nodesUsed: number;
 60  triangles: Triangle[];
 61  triIdx: Uint32Array;
 62
 63  constructor(triangles: Triangle[]) {
 64    this.triangles = triangles;
 65    this.triIdx = new Uint32Array(triangles.length);
 66    this.nodesUsed = 0;
 67    const maxNodes = 2 * triangles.length - 1;
 68    this.nodesMin = new Float32Array(maxNodes * 3);
 69    this.nodesMax = new Float32Array(maxNodes * 3);
 70    this.nodesLeft = new Uint32Array(maxNodes);
 71    this.nodesInstanceCount = new Uint32Array(maxNodes);
 72  }
 73
 74  construct() {
 75    for (let i = 0; i < this.triangles.length; i++) {
 76      this.triIdx[i] = i;
 77    }
 78    this.nodesInstanceCount[0] = this.triangles.length;
 79    this.nodesLeft[0] = 0; // root node
 80    this.nodesUsed = 1;
 81    this.bounding(0);
 82    this.subdivide(0);
 83  }
 84
 85  bounding(nodeIdx: number) {
 86    const off = nodeIdx * 3;
 87    // initialize the node's AABB.
 88    this.nodesMin[off + 0] = MAX_BOUND;
 89    this.nodesMin[off + 1] = MAX_BOUND;
 90    this.nodesMin[off + 2] = MAX_BOUND;
 91    this.nodesMax[off + 0] = -MAX_BOUND;
 92    this.nodesMax[off + 1] = -MAX_BOUND;
 93    this.nodesMax[off + 2] = -MAX_BOUND;
 94
 95    const count = this.nodesInstanceCount[nodeIdx];
 96    const start = this.nodesLeft[nodeIdx];
 97
 98    // temp vectors
 99    const minVec = Vec3.create();
100    const maxVec = Vec3.create();
101
102    for (let i = 0; i < count; i++) {
103      const tri = this.triangles[this.triIdx[start + i]];
104
105      // walk through each tri and update the bounds
106      Vec3.min(minVec, [this.nodesMin[off], this.nodesMin[off + 1], this.nodesMin[off + 2]], tri.cornerA);
107      Vec3.max(maxVec, [this.nodesMax[off], this.nodesMax[off + 1], this.nodesMax[off + 2]], tri.cornerA);
108      this.nodesMin[off + 0] = minVec[0];
109      this.nodesMin[off + 1] = minVec[1];
110      this.nodesMin[off + 2] = minVec[2];
111      this.nodesMax[off + 0] = maxVec[0];
112      this.nodesMax[off + 1] = maxVec[1];
113      this.nodesMax[off + 2] = maxVec[2];
114
115      Vec3.min(minVec, [this.nodesMin[off], this.nodesMin[off + 1], this.nodesMin[off + 2]], tri.cornerB);
116      Vec3.max(maxVec, [this.nodesMax[off], this.nodesMax[off + 1], this.nodesMax[off + 2]], tri.cornerB);
117      this.nodesMin[off + 0] = minVec[0];
118      this.nodesMin[off + 1] = minVec[1];
119      this.nodesMin[off + 2] = minVec[2];
120      this.nodesMax[off + 0] = maxVec[0];
121      this.nodesMax[off + 1] = maxVec[1];
122      this.nodesMax[off + 2] = maxVec[2];
123
124      Vec3.min(minVec, [this.nodesMin[off], this.nodesMin[off + 1], this.nodesMin[off + 2]], tri.cornerC);
125      Vec3.max(maxVec, [this.nodesMax[off], this.nodesMax[off + 1], this.nodesMax[off + 2]], tri.cornerC);
126      this.nodesMin[off + 0] = minVec[0];
127      this.nodesMin[off + 1] = minVec[1];
128      this.nodesMin[off + 2] = minVec[2];
129      this.nodesMax[off + 0] = maxVec[0];
130      this.nodesMax[off + 1] = maxVec[1];
131      this.nodesMax[off + 2] = maxVec[2];
132    }
133  }
134
135  subdivide(nodeIdx: number) {
136    // not enough primitives
137    if (this.nodesInstanceCount[nodeIdx] <= 2) return;
138
139    let [split, axis, cost] = this.findBestPlane(nodeIdx);
140
141    // eval the parent node’s extent.
142    const off = nodeIdx * 3;
143    const extent = Vec3.create();
144    Vec3.subtract(
145      extent,
146      [this.nodesMax[off + 0], this.nodesMax[off + 1], this.nodesMax[off + 2]],
147      [this.nodesMin[off + 0], this.nodesMin[off + 1], this.nodesMin[off + 2]]
148    );
149    const parentArea = extent[0] * extent[1] + extent[1] * extent[2] + extent[2] * extent[0];
150    const parentCost = this.nodesInstanceCount[nodeIdx] * parentArea;
151
152    // fallback to median split if SAH cost is not better
153    if (cost >= parentCost) {
154      let longestAxis = 0;
155      if (extent[1] > extent[0]) longestAxis = 1;
156      if (extent[2] > extent[longestAxis]) longestAxis = 2;
157      const start = this.nodesLeft[nodeIdx];
158      const count = this.nodesInstanceCount[nodeIdx];
159      const centroids: number[] = [];
160      for (let i = 0; i < count; i++) {
161        centroids.push(this.triangles[this.triIdx[start + i]].centroid[longestAxis]);
162      }
163      centroids.sort((a, b) => a - b);
164      split = centroids[Math.floor(count / 2)];
165      axis = longestAxis;
166    }
167
168    // partition primitives based on the chosen split
169    let i = this.nodesLeft[nodeIdx];
170    let j = i + this.nodesInstanceCount[nodeIdx] - 1;
171    while (i <= j) {
172      const tri = this.triangles[this.triIdx[i]];
173      if (tri.centroid[axis] < split) {
174        i++;
175      } else {
176        const tmp = this.triIdx[i];
177        this.triIdx[i] = this.triIdx[j];
178        this.triIdx[j] = tmp;
179        j--;
180      }
181    }
182    const leftCount = i - this.nodesLeft[nodeIdx];
183    if (leftCount === 0 || leftCount === this.nodesInstanceCount[nodeIdx]) return;
184
185    // construct child nodes.
186    const leftIdx = this.nodesUsed++;
187    const rightIdx = this.nodesUsed++;
188
189    this.nodesLeft[leftIdx] = this.nodesLeft[nodeIdx];
190    this.nodesInstanceCount[leftIdx] = leftCount;
191    this.nodesLeft[rightIdx] = i;
192    this.nodesInstanceCount[rightIdx] = this.nodesInstanceCount[nodeIdx] - leftCount;
193
194    // internal node
195    this.nodesLeft[nodeIdx] = leftIdx;
196    this.nodesInstanceCount[nodeIdx] = 0;
197
198    // keep going 
199    this.bounding(leftIdx);
200    this.bounding(rightIdx);
201    this.subdivide(leftIdx);
202    this.subdivide(rightIdx);
203  }
204
205  findBestPlane(nodeIdx: number): [number, number, number] {
206    let bestAxis = -1;
207    let bestSplit = 0;
208    let bestCost = Infinity;
209
210    const count = this.nodesInstanceCount[nodeIdx];
211    const start = this.nodesLeft[nodeIdx];
212
213    // eval centroid bounds
214    const centroidMin = [Infinity, Infinity, Infinity];
215    const centroidMax = [-Infinity, -Infinity, -Infinity];
216    for (let i = 0; i < count; i++) {
217      const tri = this.triangles[this.triIdx[start + i]];
218      for (let axis = 0; axis < 3; axis++) {
219        centroidMin[axis] = Math.min(centroidMin[axis], tri.centroid[axis]);
220        centroidMax[axis] = Math.max(centroidMax[axis], tri.centroid[axis]);
221      }
222    }
223
224    // fallback if this centroid has a degenerate centroid distributions
225    const EPSILON = 1e-5;
226    let degenerate = false;
227    for (let axis = 0; axis < 3; axis++) {
228      if (Math.abs(centroidMax[axis] - centroidMin[axis]) < EPSILON) {
229        degenerate = true;
230        break;
231      }
232    }
233    if (degenerate) {
234      // use median split along the longest axis of actual bounds
235      let longestAxis = 0;
236      const actualMin = [Infinity, Infinity, Infinity];
237      const actualMax = [-Infinity, -Infinity, -Infinity];
238      for (let i = 0; i < count; i++) {
239        const tri = this.triangles[this.triIdx[start + i]];
240        for (let axis = 0; axis < 3; axis++) {
241          actualMin[axis] = Math.min(actualMin[axis], tri.cornerA[axis], tri.cornerB[axis], tri.cornerC[axis]);
242          actualMax[axis] = Math.max(actualMax[axis], tri.cornerA[axis], tri.cornerB[axis], tri.cornerC[axis]);
243        }
244      }
245      const extent = [actualMax[0] - actualMin[0], actualMax[1] - actualMin[1], actualMax[2] - actualMin[2]];
246      if (extent[1] > extent[0]) longestAxis = 1;
247      if (extent[2] > extent[longestAxis]) longestAxis = 2;
248      const centroids: number[] = [];
249      for (let i = 0; i < count; i++) {
250        centroids.push(this.triangles[this.triIdx[start + i]].centroid[longestAxis]);
251      }
252      centroids.sort((a, b) => a - b);
253      bestSplit = centroids[Math.floor(count / 2)];
254      bestAxis = longestAxis;
255      return [bestSplit, bestAxis, bestCost];
256    }
257
258    // use adaptive binning (bin count based on number of primitives, clamped between 8 and 32)
259    const BINS = Math.max(8, Math.min(32, count));
260    const bins: Bin[] = Array.from({ length: BINS }, () => ({
261      instanceCount: 0,
262      bounds: new AABB(),
263    }));
264
265    // for each axis, evaluate candidate splits.
266    for (let axis = 0; axis < 3; axis++) {
267      const axisMin = centroidMin[axis];
268      const axisMax = centroidMax[axis];
269      if (axisMax === axisMin) continue; // skip degenerate axis
270
271      const axisScale = BINS / (axisMax - axisMin);
272      // reset bins for this axis.
273      for (let i = 0; i < BINS; i++) {
274        bins[i].instanceCount = 0;
275        bins[i].bounds = new AABB();
276      }
277
278      // distribute primitives into bins.
279      for (let i = 0; i < count; i++) {
280        const tri = this.triangles[this.triIdx[start + i]];
281        let binIDX = Math.floor((tri.centroid[axis] - axisMin) * axisScale);
282        if (binIDX >= BINS) binIDX = BINS - 1;
283        bins[binIDX].instanceCount++;
284        bins[binIDX].bounds.grow(tri.cornerA);
285        bins[binIDX].bounds.grow(tri.cornerB);
286        bins[binIDX].bounds.grow(tri.cornerC);
287      }
288
289
290      // compute cumulative sums from left and right
291      const leftCount = new Array(BINS - 1).fill(0);
292      const leftArea = new Array(BINS - 1).fill(0);
293      const rightCount = new Array(BINS - 1).fill(0);
294      const rightArea = new Array(BINS - 1).fill(0);
295
296      let leftBox = new AABB();
297      for (let i = 0, sum = 0; i < BINS - 1; i++) {
298        sum += bins[i].instanceCount;
299        leftCount[i] = sum;
300        leftBox.growAABB(bins[i].bounds);
301        leftArea[i] = leftBox.area();
302      }
303      let rightBox = new AABB();
304      for (let i = BINS - 1, sum = 0; i > 0; i--) {
305        sum += bins[i].instanceCount;
306        rightCount[i - 1] = sum;
307        rightBox.growAABB(bins[i].bounds);
308        rightArea[i - 1] = rightBox.area();
309      }
310
311      // highly axis aligned polygons like lucy and sponza fail
312      // duct taped solution to handle degen cases
313      const binWidth = (axisMax - axisMin) / BINS;
314      // eval candidate splits.
315      for (let i = 0; i < BINS - 1; i++) {
316        // small epsilon jitter to help break ties.
317        const jitter = 1e-6;
318        const candidatePos = axisMin + binWidth * (i + 1) + jitter;
319        const cost = leftCount[i] * leftArea[i] + rightCount[i] * rightArea[i];
320        if (cost < bestCost) {
321          bestCost = cost;
322          bestAxis = axis;
323          bestSplit = candidatePos;
324        }
325      }
326    }
327    return [bestSplit, bestAxis, bestCost];
328  }
329}
330
// The BVH builder is this module's default export.
export { BVH as default };