KDBush Hover
Description
Mapbox and internet hero Vladimir Agafonkin (aka @mourner) has provided a package called KDBush that uses a kd-tree (a more generalized version of a binary or quadtree) to help us find things in 2D space more efficiently. In this scatterplot, we'll follow the approach taken in the Quadtree Hover example, but we'll use KDBush instead of d3-quadtree.
The main difference in the code is that KDBush provides a `within(x, y, radius)` method that returns the indices of all of the points within the radius – not just the closest one as we saw with `find()` in d3-quadtree. This means we'll have to do a linear scan of the points within the radius to find which one is closest to the mouse, but hey, that's faster than scanning every single point, right? Maybe so, but the cost of initializing a KDBush index may not make it worth it. I'd suggest sticking with the basic Closest Point Hover approach until your application requires better lookup performance.
Vladimir also provides a similar package called RBush with a different set of trade-offs. Check it out and see what works best for your particular use case.
Code
1import * as React from 'react' // v17.0.22import { extent, least, min } from 'd3-array' // v^2.12.13import { csvParse } from 'd3-dsv' // v^2.0.04import { format } from 'd3-format' // v^2.0.05import { lab, rgb } from 'd3-color' // v^2.0.06import { scaleLinear, scaleSequential, scaleSqrt } from 'd3-scale' // v^3.2.47import { interpolateBuPu } from 'd3-scale-chromatic' // v^2.0.08import { pointer } from 'd3-selection' // v^2.0.09import KDBush from 'kdbush' // an alternative to d3-quadtree!1011const Scatterplot = ({}) => {12 const data = useMovieData()1314 const width = 80015 const height = 40016 const margin = { top: 10, right: 100, bottom: 30, left: 50 }17 const innerWidth = width - margin.left - margin.right18 const innerHeight = height - margin.top - margin.bottom1920 // read from pre-defined metric/dimension ("fields") bundles21 const xField = fields.revenue22 const yField = fields.vote_average23 const rField = fields.vote_count24 const colorField = fields.vote_average25 const labelField = fields.original_title2627 // optionally pull out values into local variables28 const { accessor: xAccessor, title: xTitle, formatter: xFormatter } = xField29 const { accessor: yAccessor, title: yTitle, formatter: yFormatter } = yField30 const { accessor: rAccessor } = rField31 const { accessor: colorAccessor } = colorField3233 // memoize creating our scales so we can optimize re-renders with React.memo34 // (e.g. 
<Points> only re-renders when its props change)35 const { xScale, yScale, rScale, colorScale } = React.useMemo(() => {36 if (!data) return {}37 const xExtent = extent(data, xAccessor)38 const yExtent = extent(data, yAccessor)39 const rExtent = extent(data, rAccessor)40 const colorExtent = extent(data, colorAccessor)41 // const colorDomain = Array.from(new Set(data.map(colorAccessor))).sort()4243 // const radius = 444 const xScale = scaleLinear().domain(xExtent).range([0, innerWidth])45 const yScale = scaleLinear().domain(yExtent).range([innerHeight, 0])46 const rScale = scaleSqrt().domain(rExtent).range([2, 16])47 const colorScale = scaleSequential(interpolateBuPu).domain(colorExtent)48 // const colorScale = scaleOrdinal().domain(colorDomain).range(tableau20)4950 return {51 xScale,52 yScale,53 rScale,54 colorScale,55 }56 }, [57 colorAccessor,58 data,59 innerHeight,60 innerWidth,61 rAccessor,62 xAccessor,63 yAccessor,64 ])6566 // interaction setup67 const interactionRef = React.useRef(null)68 const hoverPoint = useClosestHoverPointKDBush({69 interactionRef,70 data,71 xScale,72 xAccessor,73 yScale,74 yAccessor,75 radius: 60,76 })7778 if (!data) return <div style={{ width, height }} />7980 return (81 <div style={{ width }} className="relative">82 <svg width={width} height={height}>83 <g transform={`translate(${margin.left} ${margin.top})`}>84 <XAxis85 xScale={xScale}86 formatter={xFormatter}87 title={xTitle}88 innerHeight={innerHeight}89 gridLineHeight={innerHeight}90 />91 <YAxis92 gridLineWidth={innerWidth}93 yScale={yScale}94 formatter={yFormatter}95 title={yTitle}96 />97 <Points98 data={data}99 xScale={xScale}100 xAccessor={xAccessor}101 yScale={yScale}102 yAccessor={yAccessor}103 rScale={rScale}104 rAccessor={rAccessor}105 colorScale={colorScale}106 colorAccessor={colorAccessor}107 />108 <HoverPoint109 labelField={labelField}110 xScale={xScale}111 xField={xField}112 yScale={yScale}113 yField={yField}114 rScale={rScale}115 rField={rField}116 
colorScale={colorScale}117 colorField={colorField}118 hoverPoint={hoverPoint}119 />120121 <rect122 /* this node absorbs all mouse events */123 ref={interactionRef}124 width={innerWidth}125 height={innerHeight}126 x={0}127 y={0}128 fill="tomato"129 fillOpacity={0}130 />131 </g>132 </svg>133 </div>134 )135}136137export default Scatterplot138139/**140 * Custom hook to get the closest point to the mouse based on141 * a kdbush. Supports a max distance from142 * the mouse via the radius prop. You must provide an ref to a143 * DOM node that can be used to capture the mouse, typically a144 * <rect> or <g> that covers the entire visualization.145 */146function useClosestHoverPointKDBush({147 interactionRef,148 data,149 xScale,150 xAccessor,151 yScale,152 yAccessor,153 radius,154}) {155 // capture our hover point or undefined if none156 const [hoverPoint, setHoverPoint] = React.useState(undefined)157158 // we can throttle our updates by using requestAnimationFrame (raf)159 const rafRef = React.useRef(null)160161 // precompute the quadtree162 const kdbush = React.useMemo(() => {163 if (data == null) return null164165 const kdbush = new KDBush(166 data,167 (d) => xScale(xAccessor(d)),168 (d) => yScale(yAccessor(d))169 )170171 return kdbush172 }, [data, xScale, xAccessor, yScale, yAccessor])173174 React.useEffect(() => {175 const interactionRect = interactionRef.current176 if (interactionRect == null) return177178 const handleMouseMove = (evt) => {179 // here we use d3-selection's pointer. 
You could also try react-use useMouse.180 const [mouseX, mouseY] = pointer(evt)181182 // if we already had a pending update, cancel it in favour of this one183 if (rafRef.current) {184 cancelAnimationFrame(rafRef.current)185 }186187 rafRef.current = requestAnimationFrame(() => {188 // find point indices within radius via kdbush189 const indicesWithinRadius = kdbush.within(mouseX, mouseY, radius)190191 let newHoverPoint192 if (indicesWithinRadius.length) {193 // search all the points we got back to find the194 // closest to the mouse (there may be multiple results195 // sorted in array order, which isn't helpful for us)196 const closestIndex = least(indicesWithinRadius, (i) => {197 const x = xScale(xAccessor(data[i]))198 const y = yScale(yAccessor(data[i]))199 return Math.hypot(x - mouseX, y - mouseY)200 })201 newHoverPoint = data[closestIndex]202 } else {203 newHoverPoint = undefined204 }205 setHoverPoint(newHoverPoint)206 })207 }208 interactionRect.addEventListener('mousemove', handleMouseMove)209210 // make sure we handle when the mouse leaves the interaction area to remove211 // our active hover point212 const handleMouseLeave = () => setHoverPoint(undefined)213 interactionRect.addEventListener('mouseleave', handleMouseLeave)214215 // cleanup our listeners216 return () => {217 interactionRect.removeEventListener('mousemove', handleMouseMove)218 interactionRect.removeEventListener('mouseleave', handleMouseLeave)219 }220 }, [221 interactionRef,222 data,223 xScale,224 kdbush,225 yScale,226 radius,227 xAccessor,228 yAccessor,229 ])230231 return hoverPoint232}233234/** draws our hover marks: a crosshair + point + basic tooltip */235const HoverPoint = ({236 hoverPoint,237 xScale,238 xField,239 yField,240 yScale,241 rScale,242 rField,243 labelField,244 color = 'cyan',245}) => {246 if (!hoverPoint) return null247248 const d = hoverPoint249 const x = xScale(xField.accessor(d))250 const y = yScale(yField.accessor(d))251 const r = rScale?.(rField.accessor(d))252 const 
darkerColor = darker(color)253254 const [xPixelMin, xPixelMax] = xScale.range()255 const [yPixelMin, yPixelMax] = yScale.range()256257 return (258 <g className="pointer-events-none">259 <g data-testid="xCrosshair">260 <line261 x1={xPixelMin}262 x2={xPixelMax}263 y1={y}264 y2={y}265 stroke="#fff"266 strokeWidth={4}267 />268 <line269 x1={xPixelMin}270 x2={xPixelMax}271 y1={y}272 y2={y}273 stroke={darkerColor}274 strokeWidth={1}275 />276 </g>277 <g data-testid="yCrosshair">278 <line279 y1={yPixelMin}280 y2={yPixelMax}281 x1={x}282 x2={x}283 stroke="#fff"284 strokeWidth={4}285 />286 <line287 y1={yPixelMin}288 y2={yPixelMax}289 x1={x}290 x2={x}291 stroke={darkerColor}292 strokeWidth={1}293 />294 </g>295 <circle cx={x} cy={y} r={r} fill={color} stroke="#fff" strokeWidth={4} />296 <circle297 cx={x}298 cy={y}299 r={r}300 fill={color}301 stroke={darkerColor}302 strokeWidth={2}303 />304 <g transform={`translate(${x + 8} ${y + 4})`}>305 <OutlinedSvgText306 stroke="#fff"307 strokeWidth={5}308 className="text-sm font-bold"309 dy="0.8em"310 >311 {labelField.accessor(d)}312 </OutlinedSvgText>313 <OutlinedSvgText314 stroke="#fff"315 strokeWidth={5}316 className="text-xs"317 dy="0.8em"318 y={16}319 >320 {`${xField.title}: ${xField.formatter(xField.accessor(d))}`}321 </OutlinedSvgText>322 <OutlinedSvgText323 stroke="#fff"324 strokeWidth={5}325 className="text-xs"326 dy="0.8em"327 y={30}328 >329 {`${yField.title}: ${yField.formatter(yField.accessor(d))}`}330 </OutlinedSvgText>331 </g>332 </g>333 )334}335336/**337 * A memoized component that renders all our points, but only re-renders338 * when its props change.339 */340const Points = React.memo(341 ({342 data,343 xScale,344 xAccessor,345 yAccessor,346 yScale,347 rScale,348 rAccessor,349 radius = 8,350 colorScale,351 colorAccessor,352 defaultColor = 'tomato',353 onHover,354 }) => {355 return (356 <g data-testid="Points">357 {data.map((d, i) => {358 // const x = (width * (d.revenue - minRevenue)) / (maxRevenue - minRevenue)359 const x 
= xScale(xAccessor(d))360 const y = yScale(yAccessor(d))361 const r = rScale?.(rAccessor(d)) ?? radius362 const color = colorScale?.(colorAccessor(d)) ?? defaultColor363 const darkerColor = darker(color)364365 return (366 <circle367 key={d.id ?? i}368 cx={x}369 cy={y}370 r={r}371 fill={color}372 stroke={darkerColor}373 strokeWidth={1}374 strokeOpacity={1}375 fillOpacity={1}376 onClick={() => console.log(d)}377 onMouseEnter={onHover ? () => onHover(d) : null}378 onMouseLeave={onHover ? () => onHover(undefined) : null}379 />380 )381 })}382 </g>383 )384 }385)386387/** dynamically create a darker color */388function darker(color, factor = 0.85) {389 const labColor = lab(color)390 labColor.l *= factor391392 // rgb doesn't correspond to visual perception, but is393 // easy for computers394 // const rgbColor = rgb(color)395 // rgbColor.r *= 0.8396 // rgbColor.g *= 0.8397 // rgbColor.b *= 0.8398399 // rgb(100, 50, 50);400 // rgb(75, 25, 25); // is this half has light perceptually?401 return labColor.toString()402}403404/** fancier way of getting a nice svg text stroke */405const OutlinedSvgText = ({ stroke, strokeWidth, children, ...other }) => {406 return (407 <>408 <text stroke={stroke} strokeWidth={strokeWidth} {...other}>409 {children}410 </text>411 <text {...other}>{children}</text>412 </>413 )414}415416/** determine number of ticks based on space available */417function numTicksForPixels(pixelsAvailable, pixelsPerTick = 70) {418 return Math.floor(Math.abs(pixelsAvailable) / pixelsPerTick)419}420421/** Y-axis with title and grid lines */422const YAxis = ({ yScale, title, formatter, gridLineWidth }) => {423 const [yMin, yMax] = yScale.range()424 const ticks = yScale.ticks(numTicksForPixels(yMax - yMin, 50))425426 return (427 <g data-testid="YAxis">428 <OutlinedSvgText429 stroke="#fff"430 strokeWidth={2.5}431 dx={4}432 dy="0.8em"433 fill="var(--gray-600)"434 className="font-semibold text-2xs"435 >436 {title}437 </OutlinedSvgText>438439 <line x1={0} x2={0} y1={yMin} 
y2={yMax} stroke="var(--gray-400)" />440 {ticks.map((tick) => {441 const y = yScale(tick)442 return (443 <g key={tick} transform={`translate(0 ${y})`}>444 <text445 dy="0.34em"446 textAnchor="end"447 dx={-12}448 fill="currentColor"449 className="text-gray-400 text-2xs"450 >451 {formatter(tick)}452 </text>453 <line454 x1={0}455 x2={-8}456 stroke="var(--gray-300)"457 data-testid="tickmark"458 />459 {gridLineWidth ? (460 <line461 x1={0}462 x2={gridLineWidth}463 stroke="var(--gray-200)"464 strokeOpacity={0.8}465 data-testid="gridline"466 />467 ) : null}468 </g>469 )470 })}471 </g>472 )473}474475/** X-axis with title and grid lines */476const XAxis = ({ xScale, title, formatter, innerHeight, gridLineHeight }) => {477 const [xMin, xMax] = xScale.range()478 const ticks = xScale.ticks(numTicksForPixels(xMax - xMin))479480 return (481 <g data-testid="XAxis" transform={`translate(0 ${innerHeight})`}>482 <text483 x={xMax}484 textAnchor="end"485 dy={-4}486 fill="var(--gray-600)"487 className="font-semibold text-2xs text-shadow-white-stroke"488 >489 {title}490 </text>491492 <line x1={xMin} x2={xMax} y1={0} y2={0} stroke="var(--gray-400)" />493 {ticks.map((tick) => {494 const x = xScale(tick)495 return (496 <g key={tick} transform={`translate(${x} 0)`}>497 <text498 y={10}499 dy="0.8em"500 textAnchor="middle"501 fill="currentColor"502 className="text-gray-400 text-2xs"503 >504 {formatter(tick)}505 </text>506 <line507 y1={0}508 y2={8}509 stroke="var(--gray-300)"510 data-testid="tickmark"511 />512 {gridLineHeight ? 
(513 <line514 y1={0}515 y2={-gridLineHeight}516 stroke="var(--gray-200)"517 strokeOpacity={0.8}518 data-testid="gridline"519 />520 ) : null}521 </g>522 )523 })}524 </g>525 )526}527528// fetch our data from CSV and translate to JSON529const useMovieData = () => {530 const [data, setData] = React.useState(undefined)531532 React.useEffect(() => {533 fetch('/datasets/tmdb_1000_movies_small.csv')534 // fetch('/datasets/tmdb_5000_movies.csv')535 .then((response) => response.text())536 .then((csvString) => {537 const data = csvParse(csvString, (row) => {538 return {539 budget: +row.budget,540 vote_average: +row.vote_average,541 vote_count: +row.vote_count,542 genres: JSON.parse(row.genres),543 primary_genre: JSON.parse(row.genres)[0]?.name,544 revenue: +row.revenue,545 original_title: row.original_title,546 }547 })548 .filter((d) => d.revenue > 0)549 .slice(0, 30)550551 console.log('[data]', data)552553 setData(data)554 })555 }, [])556557 return data558}559560// very lazy large number money formatter ($1.5M, $1.65B etc)561const bigMoneyFormat = (value) => {562 if (value == null) return value563 const formatted = format('$~s')(value)564 return formatted.replace(/G$/, 'B')565}566567// metrics (numeric) + dimensions (non-numeric) = fields568const fields = {569 revenue: {570 accessor: (d) => d.revenue,571 title: 'Revenue',572 formatter: bigMoneyFormat,573 },574 budget: {575 accessor: (d) => d.budget,576 title: 'Budget',577 formatter: bigMoneyFormat,578 },579 vote_average: {580 accessor: (d) => d.vote_average,581 title: 'Vote Average out of 10',582 formatter: format('.1f'),583 },584 vote_count: {585 accessor: (d) => d.vote_count,586 title: 'Vote Count',587 formatter: format('.1f'),588 },589 primary_genre: {590 accessor: (d) => d.primary_genre,591 title: 'Primary Genre',592 formatter: (d) => d,593 },594 original_title: {595 accessor: (d) => d.original_title,596 title: 'Original Title',597 formatter: (d) => d,598 },599}600