Responsive via ResizeObserver

Description

This approach to resizing our charts to fit their container (aka "responsive sizing") makes use of a more modern web platform tool: the ResizeObserver. It's likely to be more performant than the previously covered approaches using AutoSizer and withAutoSizer. The basic idea is that we get notified whenever the container changes size and then update some local state, which we pass as the width and height props of our Scatterplot component. I haven't used this much myself, but it seems pretty solid. Depending on the browsers you target, you may need a polyfill.

Be sure to check out Amelia Wattenberger's post on React and D3, which covers her preferred usage of this pattern, and also check out react-measure, which may have some useful ideas for this problem too.

Code

1import * as React from 'react' // v17.0.2
2import { extent } from 'd3-array' // v^2.12.1
3import { csvParse } from 'd3-dsv' // v^2.0.0
4import { format } from 'd3-format' // v^2.0.0
5import { lab } from 'd3-color' // v^2.0.0
6import { scaleLinear, scaleSequential, scaleSqrt } from 'd3-scale' // v^3.2.4
7import { interpolateCividis } from 'd3-scale-chromatic' // v^2.0.0
8import { pointer } from 'd3-selection' // v^2.0.0
9import { groupBy, mean, n, summarize, tidy } from '@tidyjs/tidy' // v^2.1.0
10
11const ContainerResizeObserver = ({}) => {
12 // we call a custom hook to measure our container and pass
13 // the measured width and height in as props.
14 const ref = React.useRef(null)
15 const { width, height } = useResizeObserver(ref)
16
17 return (
18 <div
19 className="border border-dashed border-cyan-500"
20 ref={ref}
21 style={{ height: 400 }}
22 >
23 <Scatterplot width={width} height={height} />
24 </div>
25 )
26}
27
/**
 * Custom hook that uses a ResizeObserver to track the content-box size of
 * the element held in `ref`. Reports { width: 0, height: 0 } until the first
 * measurement arrives. Older browsers may need a ResizeObserver polyfill;
 * without one the hook logs a warning and keeps reporting 0x0 instead of
 * throwing.
 *
 * @param {{ current: Element | null }} ref React ref attached to the element to observe
 * @returns {{ width: number, height: number }} latest content-box size in px
 */
const useResizeObserver = (ref) => {
  const [size, setSize] = React.useState({ width: 0, height: 0 })

  React.useLayoutEffect(() => {
    const node = ref.current
    if (node == null) return

    // An uncaught ReferenceError here would tear down the whole React tree;
    // degrade to a 0x0 size and tell the developer what is missing instead.
    if (typeof ResizeObserver === 'undefined') {
      console.warn(
        'useResizeObserver: ResizeObserver is not available (a polyfill may be needed)'
      )
      return
    }

    const observer = new ResizeObserver((entries) => {
      if (!entries.length) return

      // contentRect is the content box (excludes border and padding)
      const { width, height } = entries[0].contentRect
      // keep the previous state object when nothing changed so React can skip the re-render
      setSize((prev) =>
        prev.width === width && prev.height === height ? prev : { width, height }
      )
    })

    observer.observe(node)

    // disconnect() stops every observation made by this observer
    return () => observer.disconnect()
  }, [ref])

  return { width: size.width, height: size.height }
}
62
63export default ContainerResizeObserver
64
65const Scatterplot = ({ width = 650, height = 400 }) => {
66 const data = useMovieData()
67
68 const margin = { top: 10, right: 10, bottom: 30, left: 50 }
69 const innerWidth = width - margin.left - margin.right
70 const innerHeight = height - margin.top - margin.bottom
71
72 // read from pre-defined metric/dimension ("fields") bundles
73 const xField = fields.revenue
74 const yField = fields.vote_average
75 const rField = fields.count
76 const colorField = fields.count
77 const labelField = fields.primary_genre
78
79 // optionally pull out values into local variables
80 const { accessor: xAccessor, title: xTitle, formatter: xFormatter } = xField
81 const { accessor: yAccessor, title: yTitle, formatter: yFormatter } = yField
82 const { accessor: rAccessor } = rField
83 const { accessor: colorAccessor } = colorField
84
85 // memoize creating our scales so we can optimize re-renders with React.memo
86 // (e.g. <Points> only re-renders when its props change)
87 const { xScale, yScale, rScale, colorScale } = React.useMemo(() => {
88 if (!data) return {}
89 const xExtent = extent(data, xAccessor)
90 const xDomain = padExtent(xExtent, 0.125)
91 const yExtent = extent(data, yAccessor)
92 const yDomain = padExtent(yExtent, 0.125)
93 const rExtent = extent(data, rAccessor)
94 const colorExtent = extent(data, colorAccessor)
95
96 const xScale = scaleLinear().domain(xDomain).range([0, innerWidth])
97 const yScale = scaleLinear().domain(yDomain).range([innerHeight, 0])
98 const rScale = scaleSqrt().domain(rExtent).range([2, 16])
99 const colorScale = scaleSequential(interpolateCividis).domain(colorExtent)
100
101 return {
102 xScale,
103 yScale,
104 rScale,
105 colorScale,
106 }
107 }, [
108 colorAccessor,
109 data,
110 innerHeight,
111 innerWidth,
112 rAccessor,
113 xAccessor,
114 yAccessor,
115 ])
116
117 // interaction setup
118 const interactionRef = React.useRef(null)
119 const hoverPoint = useClosestHoverPoint({
120 interactionRef,
121 data,
122 xScale,
123 xAccessor,
124 yScale,
125 yAccessor,
126 radius: 60,
127 })
128
129 if (!data) return <div style={{ width, height }} />
130
131 return (
132 <div style={{ width }} className="relative">
133 <svg width={width} height={height}>
134 <g transform={`translate(${margin.left} ${margin.top})`}>
135 <XAxis
136 xScale={xScale}
137 formatter={xFormatter}
138 title={xTitle}
139 innerHeight={innerHeight}
140 gridLineHeight={innerHeight}
141 />
142
143 <YAxis
144 gridLineWidth={innerWidth}
145 yScale={yScale}
146 formatter={yFormatter}
147 title={yTitle}
148 />
149 <XAxisTitle
150 title={xTitle}
151 xScale={xScale}
152 innerHeight={innerHeight}
153 />
154 <YAxisTitle title={yTitle} />
155 <Points
156 data={data}
157 xScale={xScale}
158 xAccessor={xAccessor}
159 yScale={yScale}
160 yAccessor={yAccessor}
161 rScale={rScale}
162 rAccessor={rAccessor}
163 colorScale={colorScale}
164 colorAccessor={colorAccessor}
165 />
166 <HoverPoint
167 labelField={labelField}
168 xScale={xScale}
169 xField={xField}
170 yScale={yScale}
171 yField={yField}
172 rScale={rScale}
173 rField={rField}
174 colorScale={colorScale}
175 colorField={colorField}
176 hoverPoint={hoverPoint}
177 />
178
179 <rect
180 /* this node absorbs all mouse events */
181 ref={interactionRef}
182 width={innerWidth}
183 height={innerHeight}
184 x={0}
185 y={0}
186 fill="tomato"
187 fillOpacity={0}
188 />
189 </g>
190 </svg>
191 </div>
192 )
193}
194
/**
 * Widen an extent by `paddingFactor` times its span on each side: the
 * first value is decreased and the second increased by
 * paddingFactor * |max - min|. Example: padExtent([0, 8], 0.125) -> [-1, 9].
 *
 * Variant: to treat [0, 1] as hard bounds, return
 * [min === 0 ? 0 : min - padding, max === 1 ? 1 : max + padding] instead.
 */
function padExtent(range, paddingFactor) {
  const [lo, hi] = range
  const padding = paddingFactor * Math.abs(hi - lo)
  const paddedLo = lo - padding
  const paddedHi = hi + padding
  return [paddedLo, paddedHi]
}
203
/**
 * Custom hook to get the closest point to the mouse, found by iterating
 * through all points (see findClosestPoint). Points farther than `radius`
 * pixels from the mouse are ignored. You must provide a ref to a DOM node
 * that can be used to capture the mouse, typically a <rect> or <g> that
 * covers the entire visualization. Updates are throttled to at most one
 * per animation frame.
 *
 * @returns the closest datum, or undefined when none is in range or the
 *   mouse has left the interaction node
 */
function useClosestHoverPoint({
  interactionRef,
  data,
  xScale,
  xAccessor,
  yScale,
  yAccessor,
  radius,
}) {
  // capture our hover point or undefined if none
  const [hoverPoint, setHoverPoint] = React.useState(undefined)

  // we throttle our updates with requestAnimationFrame (raf); this holds the pending frame id
  const rafRef = React.useRef(null)

  React.useEffect(() => {
    const interactionRect = interactionRef.current
    if (interactionRect == null) return

    const cancelPendingFrame = () => {
      if (rafRef.current != null) {
        cancelAnimationFrame(rafRef.current)
        rafRef.current = null
      }
    }

    const handleMouseMove = (evt) => {
      // here we use d3-selection's pointer. You could also try react-use useMouse.
      const [mouseX, mouseY] = pointer(evt)

      // if we already had a pending update, cancel it in favour of this one
      cancelPendingFrame()

      rafRef.current = requestAnimationFrame(() => {
        rafRef.current = null
        // naive iterate over all points method
        const newHoverPoint = findClosestPoint({
          data,
          xScale,
          xAccessor,
          yScale,
          yAccessor,
          radius,
          pixelX: mouseX,
          pixelY: mouseY,
        })

        setHoverPoint(newHoverPoint)
      })
    }

    // Clear the hover point when the mouse leaves. Any queued frame must be
    // cancelled first, otherwise it would fire afterwards and bring the
    // hover point back even though the mouse is gone.
    const handleMouseLeave = () => {
      cancelPendingFrame()
      setHoverPoint(undefined)
    }

    interactionRect.addEventListener('mousemove', handleMouseMove)
    interactionRect.addEventListener('mouseleave', handleMouseLeave)

    // cleanup our listeners and any queued frame, so a stale frame never
    // runs with outdated scales or sets state after unmount
    return () => {
      interactionRect.removeEventListener('mousemove', handleMouseMove)
      interactionRect.removeEventListener('mouseleave', handleMouseLeave)
      cancelPendingFrame()
    }
  }, [interactionRef, data, xScale, yScale, radius, xAccessor, yAccessor])

  return hoverPoint
}
271
// Simple algorithm for finding the nearest point: uses Math.hypot to compute
// the pixel distance between a target (pixelX, pixelY) and each point, and
// keeps the closest one. Only points strictly closer than `radius` pixels
// qualify; when `radius` is null/undefined there is no distance limit.
// Returns undefined when `data` is missing/empty or no point is in range.
function findClosestPoint({
  data,
  xScale,
  yScale,
  xAccessor,
  yAccessor,
  pixelX,
  pixelY,
  radius,
}) {
  if (data == null) return undefined

  let closestPoint
  // Starting the running minimum at the radius makes "closer than the best
  // so far" also enforce the cutoff; a missing radius means no cutoff.
  let minDistance = radius ?? Infinity
  for (const d of data) {
    const pointPixelX = xScale(xAccessor(d))
    const pointPixelY = yScale(yAccessor(d))
    const distance = Math.hypot(pointPixelX - pixelX, pointPixelY - pixelY)
    if (distance < minDistance) {
      closestPoint = d
      minDistance = distance
    }
  }

  return closestPoint
}
299
300/** draws our hover marks: a crosshair + point + basic tooltip */
301const HoverPoint = ({
302 hoverPoint,
303 xScale,
304 xField,
305 yField,
306 yScale,
307 rScale,
308 rField,
309 labelField,
310 color = 'cyan',
311}) => {
312 if (!hoverPoint) return null
313
314 const d = hoverPoint
315 const x = xScale(xField.accessor(d))
316 const y = yScale(yField.accessor(d))
317 const r = rScale?.(rField.accessor(d))
318 const darkerColor = darker(color)
319
320 const [xPixelMin, xPixelMax] = xScale.range()
321 const [yPixelMin, yPixelMax] = yScale.range()
322
323 return (
324 <g className="pointer-events-none">
325 <g data-testid="xCrosshair">
326 <line
327 x1={xPixelMin}
328 x2={xPixelMax}
329 y1={y}
330 y2={y}
331 stroke="#fff"
332 strokeWidth={4}
333 />
334 <line
335 x1={xPixelMin}
336 x2={xPixelMax}
337 y1={y}
338 y2={y}
339 stroke={darkerColor}
340 strokeWidth={1}
341 />
342 </g>
343 <g data-testid="yCrosshair">
344 <line
345 y1={yPixelMin}
346 y2={yPixelMax}
347 x1={x}
348 x2={x}
349 stroke="#fff"
350 strokeWidth={4}
351 />
352 <line
353 y1={yPixelMin}
354 y2={yPixelMax}
355 x1={x}
356 x2={x}
357 stroke={darkerColor}
358 strokeWidth={1}
359 />
360 </g>
361 <circle cx={x} cy={y} r={r} fill={color} stroke="#fff" strokeWidth={4} />
362 <circle
363 cx={x}
364 cy={y}
365 r={r}
366 fill={color}
367 stroke={darkerColor}
368 strokeWidth={2}
369 />
370 <g transform={`translate(${x + 8} ${y + 4})`}>
371 <OutlinedSvgText
372 stroke="#fff"
373 strokeWidth={5}
374 className="text-sm font-bold"
375 dy="0.8em"
376 >
377 {labelField.accessor(d)}
378 </OutlinedSvgText>
379 <OutlinedSvgText
380 stroke="#fff"
381 strokeWidth={5}
382 className="text-xs"
383 dy="0.8em"
384 y={16}
385 >
386 {`${xField.title}: ${xField.formatter(xField.accessor(d))}`}
387 </OutlinedSvgText>
388 <OutlinedSvgText
389 stroke="#fff"
390 strokeWidth={5}
391 className="text-xs"
392 dy="0.8em"
393 y={30}
394 >
395 {`${yField.title}: ${yField.formatter(yField.accessor(d))}`}
396 </OutlinedSvgText>
397 </g>
398 </g>
399 )
400}
401
402/**
403 * A memoized component that renders all our points, but only re-renders
404 * when its props change.
405 */
406const Points = React.memo(
407 ({
408 data,
409 xScale,
410 xAccessor,
411 yAccessor,
412 yScale,
413 rScale,
414 rAccessor,
415 radius = 8,
416 colorScale,
417 colorAccessor,
418 defaultColor = 'tomato',
419 onHover,
420 }) => {
421 return (
422 <g data-testid="Points">
423 {data.map((d, i) => {
424 // const x = (width * (d.revenue - minRevenue)) / (maxRevenue - minRevenue)
425 const x = xScale(xAccessor(d))
426 const y = yScale(yAccessor(d))
427 const r = rScale?.(rAccessor(d)) ?? radius
428 const color = colorScale?.(colorAccessor(d)) ?? defaultColor
429
430 return (
431 <circle
432 key={d.id ?? i}
433 r={r}
434 cx={x}
435 cy={y}
436 fill={color}
437 stroke={darker(color)}
438 strokeWidth={1}
439 strokeOpacity={1}
440 fillOpacity={0.5}
441 onClick={() => console.log(d)}
442 onMouseEnter={onHover ? () => onHover(d) : null}
443 onMouseLeave={onHover ? () => onHover(undefined) : null}
444 />
445 )
446 })}
447 </g>
448 )
449 }
450)
451
/** true when the colour's CIELAB lightness (L*) is below 75 */
function isDarkColor(color) {
  const lightness = lab(color).l
  return lightness < 75
}
456
/**
 * Dynamically create a darker shade of `color` by multiplying its CIELAB
 * lightness (L*) by `factor` (default 0.85). Returns a CSS colour string
 * (d3-color's toString).
 *
 * Lab is used instead of scaling the r/g/b channels because RGB does not
 * track visual perception: e.g. rgb(100, 50, 50) -> rgb(75, 25, 25) is easy
 * to compute, but is not a predictable perceptual darkening.
 */
function darker(color, factor = 0.85) {
  const shade = lab(color)
  shade.l = shade.l * factor
  return String(shade)
}
473
474/** fancier way of getting a nice svg text stroke */
475const OutlinedSvgText = ({ stroke, strokeWidth, children, ...other }) => {
476 return (
477 <>
478 <text stroke={stroke} strokeWidth={strokeWidth} {...other}>
479 {children}
480 </text>
481 <text {...other}>{children}</text>
482 </>
483 )
484}
485
/**
 * Number of axis ticks that fit in a span of pixels, allowing
 * `pixelsPerTick` pixels (default 70) per tick. The sign of the span is
 * ignored, so inverted ranges such as [innerHeight, 0] work too.
 */
function numTicksForPixels(pixelsAvailable, pixelsPerTick = 70) {
  const span = Math.abs(pixelsAvailable)
  const ticksThatFit = span / pixelsPerTick
  return Math.floor(ticksThatFit)
}
490
491const YAxisTitle = ({ title }) => {
492 return (
493 <OutlinedSvgText
494 stroke="#fff"
495 strokeWidth={2.5}
496 dx={4}
497 dy="0.8em"
498 fill="var(--gray-600)"
499 className="font-semibold text-2xs"
500 >
501 {title}
502 </OutlinedSvgText>
503 )
504}
505
506/** Y-axis with title and grid lines */
507const YAxis = ({ yScale, formatter, gridLineWidth }) => {
508 const [yMin, yMax] = yScale.range()
509 const ticks = yScale.ticks(numTicksForPixels(yMax - yMin, 50))
510
511 return (
512 <g data-testid="YAxis">
513 <line x1={0} x2={0} y1={yMin} y2={yMax} stroke="var(--gray-400)" />
514 {ticks.map((tick) => {
515 const y = yScale(tick)
516 return (
517 <g key={tick} transform={`translate(0 ${y})`}>
518 <text
519 dy="0.34em"
520 textAnchor="end"
521 dx={-12}
522 fill="currentColor"
523 className="text-gray-400 text-2xs"
524 >
525 {formatter(tick)}
526 </text>
527 <line
528 x1={0}
529 x2={-8}
530 stroke="var(--gray-300)"
531 data-testid="tickmark"
532 />
533 {gridLineWidth ? (
534 <line
535 x1={0}
536 x2={gridLineWidth}
537 stroke="var(--gray-200)"
538 strokeOpacity={0.8}
539 data-testid="gridline"
540 />
541 ) : null}
542 </g>
543 )
544 })}
545 </g>
546 )
547}
548
549const XAxisTitle = ({ xScale, title, innerHeight }) => {
550 const [, xMax] = xScale.range()
551 return (
552 <text
553 x={xMax}
554 y={innerHeight}
555 textAnchor="end"
556 dy={-4}
557 fill="var(--gray-600)"
558 className="font-semibold text-2xs text-shadow-white-stroke"
559 >
560 {title}
561 </text>
562 )
563}
564
565/** X-axis with title and grid lines */
566const XAxis = ({ xScale, title, formatter, innerHeight, gridLineHeight }) => {
567 const [xMin, xMax] = xScale.range()
568 const ticks = xScale.ticks(numTicksForPixels(xMax - xMin))
569
570 return (
571 <g data-testid="XAxis" transform={`translate(0 ${innerHeight})`}>
572 <line x1={xMin} x2={xMax} y1={0} y2={0} stroke="var(--gray-400)" />
573 {ticks.map((tick) => {
574 const x = xScale(tick)
575 return (
576 <g key={tick} transform={`translate(${x} 0)`}>
577 <text
578 y={10}
579 dy="0.8em"
580 textAnchor="middle"
581 fill="currentColor"
582 className="text-gray-400 text-2xs"
583 >
584 {formatter(tick)}
585 </text>
586 <line
587 y1={0}
588 y2={8}
589 stroke="var(--gray-300)"
590 data-testid="tickmark"
591 />
592 {gridLineHeight ? (
593 <line
594 y1={0}
595 y2={-gridLineHeight}
596 stroke="var(--gray-200)"
597 strokeOpacity={0.8}
598 data-testid="gridline"
599 />
600 ) : null}
601 </g>
602 )
603 })}
604 </g>
605 )
606}
607
/**
 * Hook that fetches the TMDB movie CSV once on mount and prepares it for the chart.
 * - Parses each row and drops movies whose revenue is not positive.
 * - Groups the rest by primary genre into mean revenue, mean vote_average
 *   and a movie count (`count`).
 *
 * The request is aborted if the component unmounts first. Fetch, HTTP and
 * parse errors are logged and leave the data undefined.
 *
 * @returns {Array<object> | undefined} grouped rows, or undefined until loaded
 */
const useMovieData = () => {
  const [data, setData] = React.useState(undefined)

  React.useEffect(() => {
    const controller = new AbortController()

    // alternative, larger dataset: '/datasets/tmdb_5000_movies.csv'
    fetch('/datasets/tmdb_1000_movies_small.csv', { signal: controller.signal })
      .then((response) => {
        // fetch only rejects on network failure; surface HTTP errors explicitly
        if (!response.ok) {
          throw new Error(`Failed to load movie data: HTTP ${response.status}`)
        }
        return response.text()
      })
      .then((csvString) => {
        const data = csvParse(csvString, (row) => {
          // the genres column is a JSON array; parse it once and reuse it
          const genres = JSON.parse(row.genres)
          return {
            budget: +row.budget,
            vote_average: +row.vote_average,
            vote_count: +row.vote_count,
            genres,
            primary_genre: genres[0]?.name,
            revenue: +row.revenue,
            original_title: row.original_title,
          }
        }).filter((d) => d.revenue > 0)
        console.log('[data]', data)

        // group by genre and summarize
        const groupedData = tidy(
          data,
          groupBy(
            ['primary_genre'],
            [
              summarize({
                revenue: mean('revenue'),
                vote_average: mean('vote_average'),
                count: n(),
              }),
            ]
          )
        )

        console.log('groupedData', groupedData)

        // don't update state for a component that has already unmounted
        if (!controller.signal.aborted) setData(groupedData)
      })
      .catch((error) => {
        // aborting on unmount is expected, not a failure
        if (error.name === 'AbortError') return
        console.error('[useMovieData] failed to load movie data', error)
      })

    return () => controller.abort()
  }, [])

  return data
}
653
// Very lazy large-number money formatter ($1.5M, $1.65B etc). It uses d3's
// SI-prefix format with trailing zeros trimmed, then swaps the SI "G" (giga)
// suffix for the financial "B" (billion). null/undefined pass through unchanged.
const bigMoneyFormat = (value) =>
  value == null ? value : format('$~s')(value).replace(/G$/, 'B')
660
// metrics (numeric) + dimensions (non-numeric) = fields.
// Each field bundles how to read a value from a datum (accessor), a
// human-readable label (title) and how to display a value (formatter).
const fields = {
  revenue: {
    accessor: (d) => d.revenue,
    title: 'Revenue',
    formatter: bigMoneyFormat,
  },
  budget: {
    accessor: (d) => d.budget,
    title: 'Budget',
    formatter: bigMoneyFormat,
  },
  vote_average: {
    accessor: (d) => d.vote_average,
    title: 'Vote Average out of 10',
    formatter: format('.1f'),
  },
  vote_count: {
    accessor: (d) => d.vote_count,
    title: 'Vote Count',
    // vote counts are whole numbers: show "1,234" rather than "1234.0"
    formatter: format(',d'),
  },
  primary_genre: {
    accessor: (d) => d.primary_genre,
    title: 'Primary Genre',
    formatter: (d) => d,
  },
  original_title: {
    accessor: (d) => d.original_title,
    title: 'Original Title',
    formatter: (d) => d,
  },

  // aggregate produced by useMovieData's per-genre summary (n() per group)
  count: {
    accessor: (d) => d.count,
    title: 'Num Movies in Group',
    formatter: (d) => d,
  },
}
700