Quadtree Hover

Description

In this scatterplot, we're building off of Basic Hover Scatterplot and mirroring the approach done in Delaunay Hover, except we're going to try a different data structure: the magnificent quadtree. Think of it like a binary tree but in two dimensions. The actual code for representing our data as a quadtree happens in the handy d3-quadtree package, which even provides a find(x, y, radius) function. Thanks D3, you the best.

We will have to pay an upfront cost to build the quadtree, but we will have faster lookup times to get the points near the mouse. As with the other approaches, this may be unnecessary – stick with the Closest Point Hover brute-force method until you need more performance. Why would you use this over Delaunay? It's not entirely clear – I know there have been issues in the past with d3-voronoi failing when it is given collinear points (which often happens when several points share the exact same position) that you can avoid by using quadtrees, but I'm not sure whether the Delaunay triangulation also has a problem with them. Anyway, try it out and see which you prefer!

I will note that while find() only returns the closest point within the radius, you can use visit() to traverse the quadtree if you need more flexibility (e.g. you want all the points within the radius). I wrote a post about using this approach for brushing years ago. Nowadays I'd still stick with a brute-force search until things got too slow. Be sure to check out the related post on KDBush Hover that is very similar to the d3-quadtree approach, but by default returns all points in a given radius.

Code

1import * as React from 'react' // v17.0.2
2import { extent } from 'd3-array' // v^2.12.1
3import { csvParse } from 'd3-dsv' // v^2.0.0
4import { format } from 'd3-format' // v^2.0.0
5import { lab, rgb } from 'd3-color' // v^2.0.0
6import { scaleLinear, scaleSequential, scaleSqrt } from 'd3-scale' // v^3.2.4
7import { interpolateCool } from 'd3-scale-chromatic' // v^2.0.0
8import { pointer } from 'd3-selection' // v^2.0.0
9import { quadtree } from 'd3-quadtree' // v^2.0.0
10
11const Scatterplot = ({}) => {
12 const data = useMovieData()
13
14 const width = 650
15 const height = 400
16 const margin = { top: 10, right: 100, bottom: 30, left: 50 }
17 const innerWidth = width - margin.left - margin.right
18 const innerHeight = height - margin.top - margin.bottom
19
20 // read from pre-defined metric/dimension ("fields") bundles
21 const xField = fields.revenue
22 const yField = fields.vote_average
23 const rField = fields.vote_count
24 const colorField = fields.vote_average
25 const labelField = fields.original_title
26
27 // optionally pull out values into local variables
28 const { accessor: xAccessor, title: xTitle, formatter: xFormatter } = xField
29 const { accessor: yAccessor, title: yTitle, formatter: yFormatter } = yField
30 const { accessor: rAccessor } = rField
31 const { accessor: colorAccessor } = colorField
32
33 // memoize creating our scales so we can optimize re-renders with React.memo
34 // (e.g. <Points> only re-renders when its props change)
35 const { xScale, yScale, rScale, colorScale } = React.useMemo(() => {
36 if (!data) return {}
37 const xExtent = extent(data, xAccessor)
38 const yExtent = extent(data, yAccessor)
39 const rExtent = extent(data, rAccessor)
40 const colorExtent = extent(data, colorAccessor)
41 // const colorDomain = Array.from(new Set(data.map(colorAccessor))).sort()
42
43 // const radius = 4
44 const xScale = scaleLinear().domain(xExtent).range([0, innerWidth])
45 const yScale = scaleLinear().domain(yExtent).range([innerHeight, 0])
46 const rScale = scaleSqrt().domain(rExtent).range([2, 16])
47 const colorScale = scaleSequential(interpolateCool).domain(colorExtent)
48 // const colorScale = scaleOrdinal().domain(colorDomain).range(tableau20)
49
50 return {
51 xScale,
52 yScale,
53 rScale,
54 colorScale,
55 }
56 }, [
57 colorAccessor,
58 data,
59 innerHeight,
60 innerWidth,
61 rAccessor,
62 xAccessor,
63 yAccessor,
64 ])
65
66 // interaction setup
67 const interactionRef = React.useRef(null)
68 const hoverPoint = useClosestHoverPointQuadtree({
69 interactionRef,
70 data,
71 xScale,
72 xAccessor,
73 yScale,
74 yAccessor,
75 radius: 60,
76 })
77
78 if (!data) return <div style={{ width, height }} />
79
80 return (
81 <div style={{ width }} className="relative">
82 <svg width={width} height={height}>
83 <g transform={`translate(${margin.left} ${margin.top})`}>
84 <XAxis
85 xScale={xScale}
86 formatter={xFormatter}
87 title={xTitle}
88 innerHeight={innerHeight}
89 gridLineHeight={innerHeight}
90 />
91 <YAxis
92 gridLineWidth={innerWidth}
93 yScale={yScale}
94 formatter={yFormatter}
95 title={yTitle}
96 />
97 <Points
98 data={data}
99 xScale={xScale}
100 xAccessor={xAccessor}
101 yScale={yScale}
102 yAccessor={yAccessor}
103 rScale={rScale}
104 rAccessor={rAccessor}
105 colorScale={colorScale}
106 colorAccessor={colorAccessor}
107 />
108 <HoverPoint
109 labelField={labelField}
110 xScale={xScale}
111 xField={xField}
112 yScale={yScale}
113 yField={yField}
114 rScale={rScale}
115 rField={rField}
116 colorScale={colorScale}
117 colorField={colorField}
118 hoverPoint={hoverPoint}
119 />
120
121 <rect
122 /* this node absorbs all mouse events */
123 ref={interactionRef}
124 width={innerWidth}
125 height={innerHeight}
126 x={0}
127 y={0}
128 fill="tomato"
129 fillOpacity={0}
130 />
131 </g>
132 </svg>
133 </div>
134 )
135}
136
137export default Scatterplot
138
/**
 * Custom hook to get the closest point to the mouse based on a quadtree.
 *
 * Point positions are stored in pixel space (after applying the scales), so
 * `radius` is a max distance in pixels from the mouse. You must provide a ref
 * to a DOM node that can be used to capture the mouse, typically a <rect> or
 * <g> that covers the entire visualization.
 *
 * Updates are throttled to one per animation frame. Any pending frame is
 * cancelled when the mouse leaves and when the effect is cleaned up. Without
 * that, a queued update could re-show a point after the mouse has left, or
 * set state after unmount.
 *
 * @param {object} options
 * @param {{ current: Element | null }} options.interactionRef - node to listen on
 * @param {Array|undefined} options.data - items to index; no listeners while nullish
 * @param {Function} options.xScale - maps xAccessor(d) to pixels
 * @param {Function} options.xAccessor - reads the x value from a datum
 * @param {Function} options.yScale - maps yAccessor(d) to pixels
 * @param {Function} options.yAccessor - reads the y value from a datum
 * @param {number} options.radius - max search distance in pixels
 * @returns {*} the closest datum within `radius`, or undefined if none
 */
function useClosestHoverPointQuadtree({
  interactionRef,
  data,
  xScale,
  xAccessor,
  yScale,
  yAccessor,
  radius,
}) {
  // capture our hover point or undefined if none
  const [hoverPoint, setHoverPoint] = React.useState(undefined)

  // we can throttle our updates by using requestAnimationFrame (raf)
  const rafRef = React.useRef(null)

  // precompute the quadtree in pixel coordinates
  const quadtreeInstance = React.useMemo(() => {
    if (data == null) return null
    return quadtree()
      .x((d) => xScale(xAccessor(d)))
      .y((d) => yScale(yAccessor(d)))
      .addAll(data)
  }, [data, xScale, xAccessor, yScale, yAccessor])

  React.useEffect(() => {
    const interactionRect = interactionRef.current
    if (interactionRect == null || quadtreeInstance == null) return

    const cancelPendingUpdate = () => {
      if (rafRef.current != null) {
        cancelAnimationFrame(rafRef.current)
        rafRef.current = null
      }
    }

    const handleMouseMove = (evt) => {
      // here we use d3-selection's pointer. You could also try react-use useMouse.
      // It must be read synchronously: evt.currentTarget is gone inside the raf.
      const [mouseX, mouseY] = pointer(evt)

      // if we already had a pending update, cancel it in favour of this one
      cancelPendingUpdate()

      rafRef.current = requestAnimationFrame(() => {
        rafRef.current = null
        // find closest point via handy quadtree (undefined if none in radius)
        setHoverPoint(quadtreeInstance.find(mouseX, mouseY, radius))
      })
    }

    // when the mouse leaves, drop any queued update so it cannot resurrect
    // the hover point, then clear it
    const handleMouseLeave = () => {
      cancelPendingUpdate()
      setHoverPoint(undefined)
    }

    interactionRect.addEventListener('mousemove', handleMouseMove)
    interactionRect.addEventListener('mouseleave', handleMouseLeave)

    // cleanup our listeners and any frame still queued against this quadtree
    return () => {
      interactionRect.removeEventListener('mousemove', handleMouseMove)
      interactionRect.removeEventListener('mouseleave', handleMouseLeave)
      cancelPendingUpdate()
    }
  }, [interactionRef, quadtreeInstance, radius])

  return hoverPoint
}
216
217/** draws our hover marks: a crosshair + point + basic tooltip */
218const HoverPoint = ({
219 hoverPoint,
220 xScale,
221 xField,
222 yField,
223 yScale,
224 rScale,
225 rField,
226 labelField,
227 color = 'cyan',
228}) => {
229 if (!hoverPoint) return null
230
231 const d = hoverPoint
232 const x = xScale(xField.accessor(d))
233 const y = yScale(yField.accessor(d))
234 const r = rScale?.(rField.accessor(d))
235 const darkerColor = darker(color)
236
237 const [xPixelMin, xPixelMax] = xScale.range()
238 const [yPixelMin, yPixelMax] = yScale.range()
239
240 return (
241 <g className="pointer-events-none">
242 <g data-testid="xCrosshair">
243 <line
244 x1={xPixelMin}
245 x2={xPixelMax}
246 y1={y}
247 y2={y}
248 stroke="#fff"
249 strokeWidth={4}
250 />
251 <line
252 x1={xPixelMin}
253 x2={xPixelMax}
254 y1={y}
255 y2={y}
256 stroke={darkerColor}
257 strokeWidth={1}
258 />
259 </g>
260 <g data-testid="yCrosshair">
261 <line
262 y1={yPixelMin}
263 y2={yPixelMax}
264 x1={x}
265 x2={x}
266 stroke="#fff"
267 strokeWidth={4}
268 />
269 <line
270 y1={yPixelMin}
271 y2={yPixelMax}
272 x1={x}
273 x2={x}
274 stroke={darkerColor}
275 strokeWidth={1}
276 />
277 </g>
278 <circle cx={x} cy={y} r={r} fill={color} stroke="#fff" strokeWidth={4} />
279 <circle
280 cx={x}
281 cy={y}
282 r={r}
283 fill={color}
284 stroke={darkerColor}
285 strokeWidth={2}
286 />
287 <g transform={`translate(${x + 8} ${y + 4})`}>
288 <OutlinedSvgText
289 stroke="#fff"
290 strokeWidth={5}
291 className="text-sm font-bold"
292 dy="0.8em"
293 >
294 {labelField.accessor(d)}
295 </OutlinedSvgText>
296 <OutlinedSvgText
297 stroke="#fff"
298 strokeWidth={5}
299 className="text-xs"
300 dy="0.8em"
301 y={16}
302 >
303 {`${xField.title}: ${xField.formatter(xField.accessor(d))}`}
304 </OutlinedSvgText>
305 <OutlinedSvgText
306 stroke="#fff"
307 strokeWidth={5}
308 className="text-xs"
309 dy="0.8em"
310 y={30}
311 >
312 {`${yField.title}: ${yField.formatter(yField.accessor(d))}`}
313 </OutlinedSvgText>
314 </g>
315 </g>
316 )
317}
318
319/**
320 * A memoized component that renders all our points, but only re-renders
321 * when its props change.
322 */
323const Points = React.memo(
324 ({
325 data,
326 xScale,
327 xAccessor,
328 yAccessor,
329 yScale,
330 rScale,
331 rAccessor,
332 radius = 8,
333 colorScale,
334 colorAccessor,
335 defaultColor = 'tomato',
336 onHover,
337 }) => {
338 return (
339 <g data-testid="Points">
340 {data.map((d, i) => {
341 // const x = (width * (d.revenue - minRevenue)) / (maxRevenue - minRevenue)
342 const x = xScale(xAccessor(d))
343 const y = yScale(yAccessor(d))
344 const r = rScale?.(rAccessor(d)) ?? radius
345 const color = colorScale?.(colorAccessor(d)) ?? defaultColor
346 const darkerColor = darker(color)
347
348 return (
349 <circle
350 key={d.id ?? i}
351 cx={x}
352 cy={y}
353 r={r}
354 fill={color}
355 stroke={darkerColor}
356 strokeWidth={1}
357 strokeOpacity={1}
358 fillOpacity={1}
359 onClick={() => console.log(d)}
360 onMouseEnter={onHover ? () => onHover(d) : null}
361 onMouseLeave={onHover ? () => onHover(undefined) : null}
362 />
363 )
364 })}
365 </g>
366 )
367 }
368)
369
/**
 * Create a darker variant of a color by scaling its CIELAB lightness.
 *
 * Lab is used instead of RGB because scaling RGB channels doesn't track how
 * light a color looks to people, while the Lab L channel is designed to.
 *
 * @param {string} color - a color string that d3-color's lab() can parse
 * @param {number} [factor=0.85] - multiplier applied to the lightness (L)
 * @returns {string} the darkened color, serialized by d3-color's toString()
 */
function darker(color, factor = 0.85) {
  const shaded = lab(color)
  shaded.l = shaded.l * factor
  return shaded.toString()
}
386
387/** fancier way of getting a nice svg text stroke */
388const OutlinedSvgText = ({ stroke, strokeWidth, children, ...other }) => {
389 return (
390 <>
391 <text stroke={stroke} strokeWidth={strokeWidth} {...other}>
392 {children}
393 </text>
394 <text {...other}>{children}</text>
395 </>
396 )
397}
398
/**
 * Decide how many axis ticks to request for a given amount of space.
 *
 * @param {number} pixelsAvailable - axis length in pixels; the sign is ignored,
 *   so inverted ranges such as [innerHeight, 0] work too
 * @param {number} [pixelsPerTick=70] - desired pixel spacing per tick
 * @returns {number} how many whole ticks fit (rounded down)
 */
function numTicksForPixels(pixelsAvailable, pixelsPerTick = 70) {
  const axisLength = Math.abs(pixelsAvailable)
  const fractionalTicks = axisLength / pixelsPerTick
  return Math.floor(fractionalTicks)
}
403
404/** Y-axis with title and grid lines */
405const YAxis = ({ yScale, title, formatter, gridLineWidth }) => {
406 const [yMin, yMax] = yScale.range()
407 const ticks = yScale.ticks(numTicksForPixels(yMax - yMin, 50))
408
409 return (
410 <g data-testid="YAxis">
411 <OutlinedSvgText
412 stroke="#fff"
413 strokeWidth={2.5}
414 dx={4}
415 dy="0.8em"
416 fill="var(--gray-600)"
417 className="font-semibold text-2xs"
418 >
419 {title}
420 </OutlinedSvgText>
421
422 <line x1={0} x2={0} y1={yMin} y2={yMax} stroke="var(--gray-400)" />
423 {ticks.map((tick) => {
424 const y = yScale(tick)
425 return (
426 <g key={tick} transform={`translate(0 ${y})`}>
427 <text
428 dy="0.34em"
429 textAnchor="end"
430 dx={-12}
431 fill="currentColor"
432 className="text-gray-400 text-2xs"
433 >
434 {formatter(tick)}
435 </text>
436 <line
437 x1={0}
438 x2={-8}
439 stroke="var(--gray-300)"
440 data-testid="tickmark"
441 />
442 {gridLineWidth ? (
443 <line
444 x1={0}
445 x2={gridLineWidth}
446 stroke="var(--gray-200)"
447 strokeOpacity={0.8}
448 data-testid="gridline"
449 />
450 ) : null}
451 </g>
452 )
453 })}
454 </g>
455 )
456}
457
458/** X-axis with title and grid lines */
459const XAxis = ({ xScale, title, formatter, innerHeight, gridLineHeight }) => {
460 const [xMin, xMax] = xScale.range()
461 const ticks = xScale.ticks(numTicksForPixels(xMax - xMin))
462
463 return (
464 <g data-testid="XAxis" transform={`translate(0 ${innerHeight})`}>
465 <text
466 x={xMax}
467 textAnchor="end"
468 dy={-4}
469 fill="var(--gray-600)"
470 className="font-semibold text-2xs text-shadow-white-stroke"
471 >
472 {title}
473 </text>
474
475 <line x1={xMin} x2={xMax} y1={0} y2={0} stroke="var(--gray-400)" />
476 {ticks.map((tick) => {
477 const x = xScale(tick)
478 return (
479 <g key={tick} transform={`translate(${x} 0)`}>
480 <text
481 y={10}
482 dy="0.8em"
483 textAnchor="middle"
484 fill="currentColor"
485 className="text-gray-400 text-2xs"
486 >
487 {formatter(tick)}
488 </text>
489 <line
490 y1={0}
491 y2={8}
492 stroke="var(--gray-300)"
493 data-testid="tickmark"
494 />
495 {gridLineHeight ? (
496 <line
497 y1={0}
498 y2={-gridLineHeight}
499 stroke="var(--gray-200)"
500 strokeOpacity={0.8}
501 data-testid="gridline"
502 />
503 ) : null}
504 </g>
505 )
506 })}
507 </g>
508 )
509}
510
/**
 * Fetch the TMDB movies CSV and parse it into an array of plain objects.
 *
 * Only movies with revenue > 0 are kept, and the result is cut to the first
 * 30 rows to keep the demo small.
 *
 * The request is aborted if the component unmounts before it finishes, so
 * no state update happens after unmount. Network, HTTP, and parse errors are
 * logged instead of becoming unhandled promise rejections.
 *
 * @returns {object[]|undefined} parsed movies; undefined until loaded
 *   (and stays undefined if loading fails)
 */
const useMovieData = () => {
  const [data, setData] = React.useState(undefined)

  React.useEffect(() => {
    const controller = new AbortController()

    // alternative, larger dataset: '/datasets/tmdb_5000_movies.csv'
    fetch('/datasets/tmdb_1000_movies_small.csv', { signal: controller.signal })
      .then((response) => {
        // fetch only rejects on network failure; surface HTTP errors too
        if (!response.ok) {
          throw new Error(`Failed to load movie data: HTTP ${response.status}`)
        }
        return response.text()
      })
      .then((csvString) => {
        if (controller.signal.aborted) return

        const movies = csvParse(csvString, (row) => {
          // genres is a JSON-encoded array; parse it once per row
          const genres = JSON.parse(row.genres)
          return {
            budget: +row.budget,
            vote_average: +row.vote_average,
            vote_count: +row.vote_count,
            genres,
            primary_genre: genres[0]?.name,
            revenue: +row.revenue,
            original_title: row.original_title,
          }
        })
          .filter((d) => d.revenue > 0)
          .slice(0, 30)

        console.log('[data]', movies)

        setData(movies)
      })
      .catch((error) => {
        // an abort on unmount is expected, not an error
        if (error.name === 'AbortError') return
        console.error('[useMovieData]', error)
      })

    return () => controller.abort()
  }, [])

  return data
}
542
// d3-format specifier: "$" currency prefix, "~" trims insignificant trailing
// zeros, "s" uses SI-prefix notation. Built once rather than on every call,
// since the formatter runs for every tick label and tooltip.
const siMoneyFormatter = format('$~s')

/**
 * Very lazy large-number money formatter ($1.5M, $1.65B, etc).
 *
 * d3's SI notation writes billions with "G" (giga); a trailing "G" is
 * replaced with the more familiar "B".
 *
 * @param {number|null|undefined} value - amount to format
 * @returns {string|null|undefined} formatted string, or `value` itself if nullish
 */
const bigMoneyFormat = (value) => {
  if (value == null) return value
  return siMoneyFormatter(value).replace(/G$/, 'B')
}
549
/**
 * Metrics (numeric) + dimensions (non-numeric) = fields.
 *
 * Each field bundles how to read a value from a movie row (`accessor`), a
 * human-readable `title`, and a `formatter` for displaying values.
 */
const fields = {
  revenue: {
    accessor: (d) => d.revenue,
    title: 'Revenue',
    formatter: bigMoneyFormat,
  },
  budget: {
    accessor: (d) => d.budget,
    title: 'Budget',
    formatter: bigMoneyFormat,
  },
  vote_average: {
    accessor: (d) => d.vote_average,
    title: 'Vote Average out of 10',
    formatter: format('.1f'),
  },
  vote_count: {
    accessor: (d) => d.vote_count,
    title: 'Vote Count',
    // counts are whole numbers: group thousands, no decimal places
    // (previously copy-pasted '.1f', which rendered e.g. "1234.0")
    formatter: format(',.0f'),
  },
  primary_genre: {
    accessor: (d) => d.primary_genre,
    title: 'Primary Genre',
    formatter: (d) => d,
  },
  original_title: {
    accessor: (d) => d.original_title,
    title: 'Original Title',
    formatter: (d) => d,
  },
}
583