KDBush Hover

Description

Mapbox and internet hero Vladimir Agafonkin (aka @mourner) has provided a package called KDBush that uses a k-d tree (a generalization of the binary search tree to k dimensions, from the same family of spatial indexes as the quadtree) to help us find things in 2D space more efficiently. In this scatterplot, we'll follow the approach taken in the Quadtree Hover example, but we'll use KDBush instead of d3-quadtree.

The main difference in the code is that KDBush provides a within(x, y, radius) method that returns the indices of all the points within the radius – not just the closest one, as we saw with find() in d3-quadtree. This means we'll have to do a linear scan of the points within the radius to find which one is closest to the mouse, but hey, that's faster than scanning every single point, right? Maybe so, but the cost of building the KDBush index may not make it worth it. I'd suggest sticking with the basic Closest Point Hover approach until your application requires better lookup performance.

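Vladimir also provides a similar package called RBush, an R-tree that indexes rectangles (bounding boxes) and supports inserting and removing items, whereas KDBush is a static index of points. Check it out and see which trade-offs work best for your particular use case.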

Code

1import * as React from 'react' // v17.0.2
2import { extent, least, min } from 'd3-array' // v^2.12.1
3import { csvParse } from 'd3-dsv' // v^2.0.0
4import { format } from 'd3-format' // v^2.0.0
5import { lab, rgb } from 'd3-color' // v^2.0.0
6import { scaleLinear, scaleSequential, scaleSqrt } from 'd3-scale' // v^3.2.4
7import { interpolateBuPu } from 'd3-scale-chromatic' // v^2.0.0
8import { pointer } from 'd3-selection' // v^2.0.0
9import KDBush from 'kdbush' // an alternative to d3-quadtree!
10
11const Scatterplot = ({}) => {
12 const data = useMovieData()
13
14 const width = 800
15 const height = 400
16 const margin = { top: 10, right: 100, bottom: 30, left: 50 }
17 const innerWidth = width - margin.left - margin.right
18 const innerHeight = height - margin.top - margin.bottom
19
20 // read from pre-defined metric/dimension ("fields") bundles
21 const xField = fields.revenue
22 const yField = fields.vote_average
23 const rField = fields.vote_count
24 const colorField = fields.vote_average
25 const labelField = fields.original_title
26
27 // optionally pull out values into local variables
28 const { accessor: xAccessor, title: xTitle, formatter: xFormatter } = xField
29 const { accessor: yAccessor, title: yTitle, formatter: yFormatter } = yField
30 const { accessor: rAccessor } = rField
31 const { accessor: colorAccessor } = colorField
32
33 // memoize creating our scales so we can optimize re-renders with React.memo
34 // (e.g. <Points> only re-renders when its props change)
35 const { xScale, yScale, rScale, colorScale } = React.useMemo(() => {
36 if (!data) return {}
37 const xExtent = extent(data, xAccessor)
38 const yExtent = extent(data, yAccessor)
39 const rExtent = extent(data, rAccessor)
40 const colorExtent = extent(data, colorAccessor)
41 // const colorDomain = Array.from(new Set(data.map(colorAccessor))).sort()
42
43 // const radius = 4
44 const xScale = scaleLinear().domain(xExtent).range([0, innerWidth])
45 const yScale = scaleLinear().domain(yExtent).range([innerHeight, 0])
46 const rScale = scaleSqrt().domain(rExtent).range([2, 16])
47 const colorScale = scaleSequential(interpolateBuPu).domain(colorExtent)
48 // const colorScale = scaleOrdinal().domain(colorDomain).range(tableau20)
49
50 return {
51 xScale,
52 yScale,
53 rScale,
54 colorScale,
55 }
56 }, [
57 colorAccessor,
58 data,
59 innerHeight,
60 innerWidth,
61 rAccessor,
62 xAccessor,
63 yAccessor,
64 ])
65
66 // interaction setup
67 const interactionRef = React.useRef(null)
68 const hoverPoint = useClosestHoverPointKDBush({
69 interactionRef,
70 data,
71 xScale,
72 xAccessor,
73 yScale,
74 yAccessor,
75 radius: 60,
76 })
77
78 if (!data) return <div style={{ width, height }} />
79
80 return (
81 <div style={{ width }} className="relative">
82 <svg width={width} height={height}>
83 <g transform={`translate(${margin.left} ${margin.top})`}>
84 <XAxis
85 xScale={xScale}
86 formatter={xFormatter}
87 title={xTitle}
88 innerHeight={innerHeight}
89 gridLineHeight={innerHeight}
90 />
91 <YAxis
92 gridLineWidth={innerWidth}
93 yScale={yScale}
94 formatter={yFormatter}
95 title={yTitle}
96 />
97 <Points
98 data={data}
99 xScale={xScale}
100 xAccessor={xAccessor}
101 yScale={yScale}
102 yAccessor={yAccessor}
103 rScale={rScale}
104 rAccessor={rAccessor}
105 colorScale={colorScale}
106 colorAccessor={colorAccessor}
107 />
108 <HoverPoint
109 labelField={labelField}
110 xScale={xScale}
111 xField={xField}
112 yScale={yScale}
113 yField={yField}
114 rScale={rScale}
115 rField={rField}
116 colorScale={colorScale}
117 colorField={colorField}
118 hoverPoint={hoverPoint}
119 />
120
121 <rect
122 /* this node absorbs all mouse events */
123 ref={interactionRef}
124 width={innerWidth}
125 height={innerHeight}
126 x={0}
127 y={0}
128 fill="tomato"
129 fillOpacity={0}
130 />
131 </g>
132 </svg>
133 </div>
134 )
135}
136
137export default Scatterplot
138
/**
 * Custom hook to get the closest point to the mouse based on
 * a kdbush spatial index. Supports a max distance from
 * the mouse via the radius prop. You must provide a ref to a
 * DOM node that can be used to capture the mouse, typically a
 * <rect> or <g> that covers the entire visualization.
 *
 * Updates are throttled to at most one per animation frame. Any pending frame
 * is cancelled when the mouse leaves or the effect is torn down, so a stale
 * frame can never re-set the hover point after it was cleared.
 *
 * @param {object} options
 * @param {{ current: Element | null }} options.interactionRef - ref to the node receiving mouse events
 * @param {Array | undefined} options.data - the plotted rows
 * @param {Function} options.xScale - maps xAccessor(d) to pixels
 * @param {Function} options.xAccessor - reads the x value of a row
 * @param {Function} options.yScale - maps yAccessor(d) to pixels
 * @param {Function} options.yAccessor - reads the y value of a row
 * @param {number} options.radius - max pixel distance between mouse and point
 * @returns {object | undefined} the closest row within `radius`, or undefined
 */
function useClosestHoverPointKDBush({
  interactionRef,
  data,
  xScale,
  xAccessor,
  yScale,
  yAccessor,
  radius,
}) {
  // capture our hover point or undefined if none
  const [hoverPoint, setHoverPoint] = React.useState(undefined)

  // id of the pending requestAnimationFrame (raf) callback, if any
  const rafRef = React.useRef(null)

  // precompute the kd-tree in *pixel* space so `within` takes a pixel radius.
  // NOTE: this is the kdbush v3 constructor signature (points, getX, getY);
  // kdbush v4+ instead uses new KDBush(numItems) + index.add(x, y) + index.finish().
  const kdbush = React.useMemo(() => {
    if (data == null) return null

    return new KDBush(
      data,
      (d) => xScale(xAccessor(d)),
      (d) => yScale(yAccessor(d))
    )
  }, [data, xScale, xAccessor, yScale, yAccessor])

  React.useEffect(() => {
    const interactionRect = interactionRef.current
    if (interactionRect == null || kdbush == null) return

    const cancelPendingUpdate = () => {
      if (rafRef.current != null) {
        cancelAnimationFrame(rafRef.current)
        rafRef.current = null
      }
    }

    const handleMouseMove = (evt) => {
      // here we use d3-selection's pointer. You could also try react-use useMouse.
      // It must be read now: the event object is not used inside the raf callback.
      const [mouseX, mouseY] = pointer(evt)

      // if we already had a pending update, cancel it in favour of this one
      cancelPendingUpdate()

      rafRef.current = requestAnimationFrame(() => {
        rafRef.current = null

        // find point indices within radius via kdbush
        const indicesWithinRadius = kdbush.within(mouseX, mouseY, radius)

        let newHoverPoint
        if (indicesWithinRadius.length) {
          // within() returns every index in range (not sorted by distance),
          // so pick the closest one ourselves
          const closestIndex = least(indicesWithinRadius, (i) => {
            const x = xScale(xAccessor(data[i]))
            const y = yScale(yAccessor(data[i]))
            return Math.hypot(x - mouseX, y - mouseY)
          })
          newHoverPoint = data[closestIndex]
        }
        setHoverPoint(newHoverPoint)
      })
    }

    // when the mouse leaves, drop any queued update first so it cannot
    // resurrect the hover point we are about to clear
    const handleMouseLeave = () => {
      cancelPendingUpdate()
      setHoverPoint(undefined)
    }

    interactionRect.addEventListener('mousemove', handleMouseMove)
    interactionRect.addEventListener('mouseleave', handleMouseLeave)

    // cleanup our listeners and any queued frame (avoids setState after unmount)
    return () => {
      interactionRect.removeEventListener('mousemove', handleMouseMove)
      interactionRect.removeEventListener('mouseleave', handleMouseLeave)
      cancelPendingUpdate()
    }
  }, [
    interactionRef,
    data,
    xScale,
    kdbush,
    yScale,
    radius,
    xAccessor,
    yAccessor,
  ])

  return hoverPoint
}
233
234/** draws our hover marks: a crosshair + point + basic tooltip */
235const HoverPoint = ({
236 hoverPoint,
237 xScale,
238 xField,
239 yField,
240 yScale,
241 rScale,
242 rField,
243 labelField,
244 color = 'cyan',
245}) => {
246 if (!hoverPoint) return null
247
248 const d = hoverPoint
249 const x = xScale(xField.accessor(d))
250 const y = yScale(yField.accessor(d))
251 const r = rScale?.(rField.accessor(d))
252 const darkerColor = darker(color)
253
254 const [xPixelMin, xPixelMax] = xScale.range()
255 const [yPixelMin, yPixelMax] = yScale.range()
256
257 return (
258 <g className="pointer-events-none">
259 <g data-testid="xCrosshair">
260 <line
261 x1={xPixelMin}
262 x2={xPixelMax}
263 y1={y}
264 y2={y}
265 stroke="#fff"
266 strokeWidth={4}
267 />
268 <line
269 x1={xPixelMin}
270 x2={xPixelMax}
271 y1={y}
272 y2={y}
273 stroke={darkerColor}
274 strokeWidth={1}
275 />
276 </g>
277 <g data-testid="yCrosshair">
278 <line
279 y1={yPixelMin}
280 y2={yPixelMax}
281 x1={x}
282 x2={x}
283 stroke="#fff"
284 strokeWidth={4}
285 />
286 <line
287 y1={yPixelMin}
288 y2={yPixelMax}
289 x1={x}
290 x2={x}
291 stroke={darkerColor}
292 strokeWidth={1}
293 />
294 </g>
295 <circle cx={x} cy={y} r={r} fill={color} stroke="#fff" strokeWidth={4} />
296 <circle
297 cx={x}
298 cy={y}
299 r={r}
300 fill={color}
301 stroke={darkerColor}
302 strokeWidth={2}
303 />
304 <g transform={`translate(${x + 8} ${y + 4})`}>
305 <OutlinedSvgText
306 stroke="#fff"
307 strokeWidth={5}
308 className="text-sm font-bold"
309 dy="0.8em"
310 >
311 {labelField.accessor(d)}
312 </OutlinedSvgText>
313 <OutlinedSvgText
314 stroke="#fff"
315 strokeWidth={5}
316 className="text-xs"
317 dy="0.8em"
318 y={16}
319 >
320 {`${xField.title}: ${xField.formatter(xField.accessor(d))}`}
321 </OutlinedSvgText>
322 <OutlinedSvgText
323 stroke="#fff"
324 strokeWidth={5}
325 className="text-xs"
326 dy="0.8em"
327 y={30}
328 >
329 {`${yField.title}: ${yField.formatter(yField.accessor(d))}`}
330 </OutlinedSvgText>
331 </g>
332 </g>
333 )
334}
335
336/**
337 * A memoized component that renders all our points, but only re-renders
338 * when its props change.
339 */
340const Points = React.memo(
341 ({
342 data,
343 xScale,
344 xAccessor,
345 yAccessor,
346 yScale,
347 rScale,
348 rAccessor,
349 radius = 8,
350 colorScale,
351 colorAccessor,
352 defaultColor = 'tomato',
353 onHover,
354 }) => {
355 return (
356 <g data-testid="Points">
357 {data.map((d, i) => {
358 // const x = (width * (d.revenue - minRevenue)) / (maxRevenue - minRevenue)
359 const x = xScale(xAccessor(d))
360 const y = yScale(yAccessor(d))
361 const r = rScale?.(rAccessor(d)) ?? radius
362 const color = colorScale?.(colorAccessor(d)) ?? defaultColor
363 const darkerColor = darker(color)
364
365 return (
366 <circle
367 key={d.id ?? i}
368 cx={x}
369 cy={y}
370 r={r}
371 fill={color}
372 stroke={darkerColor}
373 strokeWidth={1}
374 strokeOpacity={1}
375 fillOpacity={1}
376 onClick={() => console.log(d)}
377 onMouseEnter={onHover ? () => onHover(d) : null}
378 onMouseLeave={onHover ? () => onHover(undefined) : null}
379 />
380 )
381 })}
382 </g>
383 )
384 }
385)
386
/**
 * Dynamically create a darker shade of `color` by scaling its CIELAB
 * lightness. Adjusting RGB channels directly doesn't correspond to visual
 * perception, whereas Lab lightness is designed to.
 *
 * @param {string} color - a color string that d3-color's `lab()` can parse
 * @param {number} [factor=0.85] - multiplier applied to the L* (lightness) channel
 * @returns {string} the darkened color, serialized via d3-color's toString()
 */
function darker(color, factor = 0.85) {
  const shade = lab(color)
  shade.l = shade.l * factor
  return String(shade)
}
403
404/** fancier way of getting a nice svg text stroke */
405const OutlinedSvgText = ({ stroke, strokeWidth, children, ...other }) => {
406 return (
407 <>
408 <text stroke={stroke} strokeWidth={strokeWidth} {...other}>
409 {children}
410 </text>
411 <text {...other}>{children}</text>
412 </>
413 )
414}
415
/**
 * Decide how many axis ticks fit in a pixel span.
 *
 * @param {number} pixelsAvailable - length of the axis in pixels (sign ignored)
 * @param {number} [pixelsPerTick=70] - pixels to reserve for each tick
 * @returns {number} the number of whole ticks that fit (may be 0)
 */
function numTicksForPixels(pixelsAvailable, pixelsPerTick = 70) {
  const span = Math.abs(pixelsAvailable)
  const tickCount = Math.floor(span / pixelsPerTick)
  return tickCount
}
420
421/** Y-axis with title and grid lines */
422const YAxis = ({ yScale, title, formatter, gridLineWidth }) => {
423 const [yMin, yMax] = yScale.range()
424 const ticks = yScale.ticks(numTicksForPixels(yMax - yMin, 50))
425
426 return (
427 <g data-testid="YAxis">
428 <OutlinedSvgText
429 stroke="#fff"
430 strokeWidth={2.5}
431 dx={4}
432 dy="0.8em"
433 fill="var(--gray-600)"
434 className="font-semibold text-2xs"
435 >
436 {title}
437 </OutlinedSvgText>
438
439 <line x1={0} x2={0} y1={yMin} y2={yMax} stroke="var(--gray-400)" />
440 {ticks.map((tick) => {
441 const y = yScale(tick)
442 return (
443 <g key={tick} transform={`translate(0 ${y})`}>
444 <text
445 dy="0.34em"
446 textAnchor="end"
447 dx={-12}
448 fill="currentColor"
449 className="text-gray-400 text-2xs"
450 >
451 {formatter(tick)}
452 </text>
453 <line
454 x1={0}
455 x2={-8}
456 stroke="var(--gray-300)"
457 data-testid="tickmark"
458 />
459 {gridLineWidth ? (
460 <line
461 x1={0}
462 x2={gridLineWidth}
463 stroke="var(--gray-200)"
464 strokeOpacity={0.8}
465 data-testid="gridline"
466 />
467 ) : null}
468 </g>
469 )
470 })}
471 </g>
472 )
473}
474
475/** X-axis with title and grid lines */
476const XAxis = ({ xScale, title, formatter, innerHeight, gridLineHeight }) => {
477 const [xMin, xMax] = xScale.range()
478 const ticks = xScale.ticks(numTicksForPixels(xMax - xMin))
479
480 return (
481 <g data-testid="XAxis" transform={`translate(0 ${innerHeight})`}>
482 <text
483 x={xMax}
484 textAnchor="end"
485 dy={-4}
486 fill="var(--gray-600)"
487 className="font-semibold text-2xs text-shadow-white-stroke"
488 >
489 {title}
490 </text>
491
492 <line x1={xMin} x2={xMax} y1={0} y2={0} stroke="var(--gray-400)" />
493 {ticks.map((tick) => {
494 const x = xScale(tick)
495 return (
496 <g key={tick} transform={`translate(${x} 0)`}>
497 <text
498 y={10}
499 dy="0.8em"
500 textAnchor="middle"
501 fill="currentColor"
502 className="text-gray-400 text-2xs"
503 >
504 {formatter(tick)}
505 </text>
506 <line
507 y1={0}
508 y2={8}
509 stroke="var(--gray-300)"
510 data-testid="tickmark"
511 />
512 {gridLineHeight ? (
513 <line
514 y1={0}
515 y2={-gridLineHeight}
516 stroke="var(--gray-200)"
517 strokeOpacity={0.8}
518 data-testid="gridline"
519 />
520 ) : null}
521 </g>
522 )
523 })}
524 </g>
525 )
526}
527
/**
 * Fetch our movie data from CSV and translate it to typed JS objects.
 *
 * Returns `undefined` until the request resolves, and it stays `undefined`
 * if the request fails. Failures are logged with console.error instead of
 * surfacing as an unhandled promise rejection. Only rows with a positive
 * revenue are kept, and of those only the first 30.
 *
 * The request is aborted when the component unmounts, so state is never set
 * after unmount.
 */
const useMovieData = () => {
  const [data, setData] = React.useState(undefined)

  React.useEffect(() => {
    const controller = new AbortController()

    fetch('/datasets/tmdb_1000_movies_small.csv', { signal: controller.signal })
      // fetch('/datasets/tmdb_5000_movies.csv')
      .then((response) => {
        // fetch only rejects on network errors; surface HTTP errors too
        if (!response.ok) {
          throw new Error(`Failed to load movie data: HTTP ${response.status}`)
        }
        return response.text()
      })
      .then((csvString) => {
        const movies = csvParse(csvString, (row) => {
          // genres is a JSON-encoded array; parse it once and reuse it
          const genres = JSON.parse(row.genres)
          return {
            budget: +row.budget,
            vote_average: +row.vote_average,
            vote_count: +row.vote_count,
            genres,
            primary_genre: genres[0]?.name,
            revenue: +row.revenue,
            original_title: row.original_title,
          }
        })
          .filter((d) => d.revenue > 0)
          .slice(0, 30)

        console.log('[data]', movies)

        // unmounted between the body resolving and this callback running
        if (controller.signal.aborted) return
        setData(movies)
      })
      .catch((error) => {
        // an abort is expected on unmount; anything else is a real failure
        if (error.name === 'AbortError') return
        console.error('[useMovieData]', error)
      })

    return () => controller.abort()
  }, [])

  return data
}
559
// d3-format parses the specifier once here rather than on every call.
// '$' = currency prefix, '~' = trim insignificant zeros, 's' = SI prefix.
const siMoneyFormatter = format('$~s')

/**
 * Very lazy large number money formatter ($1.5M, $1.65B etc).
 * Uses SI prefixes, except that giga ("G") is rewritten to "B" for billions.
 *
 * @param {number | null | undefined} value
 * @returns {string | null | undefined} the formatted amount; null/undefined pass through
 */
const bigMoneyFormat = (value) => {
  if (value == null) return value
  return siMoneyFormatter(value).replace(/G$/, 'B')
}
566
/**
 * Metrics (numeric) + dimensions (non-numeric) = fields.
 *
 * Each entry bundles how to read a value from a movie row (`accessor`), a
 * human-readable label (`title`), and how to display a value (`formatter`).
 */
const fields = {
  revenue: {
    accessor: (d) => d.revenue,
    title: 'Revenue',
    formatter: bigMoneyFormat,
  },
  budget: {
    accessor: (d) => d.budget,
    title: 'Budget',
    formatter: bigMoneyFormat,
  },
  vote_average: {
    accessor: (d) => d.vote_average,
    title: 'Vote Average out of 10',
    formatter: format('.1f'),
  },
  vote_count: {
    accessor: (d) => d.vote_count,
    title: 'Vote Count',
    // counts are whole numbers: group thousands and show no decimal
    // (previously '.1f', which rendered e.g. "8432.0")
    formatter: format(',d'),
  },
  primary_genre: {
    accessor: (d) => d.primary_genre,
    title: 'Primary Genre',
    formatter: (value) => value,
  },
  original_title: {
    accessor: (d) => d.original_title,
    title: 'Original Title',
    formatter: (value) => value,
  },
}
600