feat: replace custom Sankey SVG with MUI X Charts Pro Sankey

This commit is contained in:
Usman Baig
2026-03-12 22:07:20 +01:00
parent 4b10f8c1fc
commit 281a9f237a
3 changed files with 1126 additions and 372 deletions

View File

@@ -1,8 +1,18 @@
'use client' 'use client'
import { useMemo, useState } from 'react' import { useMemo } from 'react'
import { useTheme } from '@ciphera-net/ui' import { useTheme } from '@ciphera-net/ui'
import { TreeStructure } from '@phosphor-icons/react' import { TreeStructure } from '@phosphor-icons/react'
import { createTheme, ThemeProvider } from '@mui/material/styles'
import {
SankeyDataProvider,
SankeyLinkPlot,
SankeyNodePlot,
SankeyNodeLabelPlot,
SankeyTooltip,
} from '@mui/x-charts-pro/SankeyChart'
import { ChartsWrapper } from '@mui/x-charts-pro/ChartsWrapper'
import { ChartsSurface } from '@mui/x-charts-pro/ChartsSurface'
import type { PathTransition } from '@/lib/api/journeys' import type { PathTransition } from '@/lib/api/journeys'
// ─── Types ────────────────────────────────────────────────────────── // ─── Types ──────────────────────────────────────────────────────────
@@ -14,246 +24,69 @@ interface SankeyDiagramProps {
onNodeClick?: (path: string) => void onNodeClick?: (path: string) => void
} }
interface PositionedNode { // ─── Data transformation ────────────────────────────────────────────
id: string // "col:path"
path: string
column: number
flow: number
x: number
y: number
height: number
}
interface PositionedLink { const NODE_COLOR = '#FD5E0F'
id: string const EXIT_COLOR = '#595b63'
fromNode: PositionedNode
toNode: PositionedNode
sessionCount: number
sourceY: number
targetY: number
width: number
}
// ─── Layout constants ─────────────────────────────────────────────── function transformToSankeyData(transitions: PathTransition[], depth: number) {
const PADDING_X = 60
const PADDING_Y = 40
const NODE_WIDTH = 8
const NODE_GAP = 6
const MIN_NODE_HEIGHT = 4
const LABEL_MAX_LENGTH = 24
const EXIT_LABEL = '(exit)'
// ─── Helpers ────────────────────────────────────────────────────────
function truncatePath(path: string, maxLen: number): string {
if (path.length <= maxLen) return path
return path.slice(0, maxLen - 1) + '\u2026'
}
function buildSankeyLayout(
transitions: PathTransition[],
depth: number,
svgWidth: number,
svgHeight: number,
) {
if (!transitions.length) return { nodes: [], links: [] }
// ── 1. Build columns ──────────────────────────────────────────────
// columns[colIndex] = Map<path, { inFlow, outFlow }>
const numColumns = depth + 1 const numColumns = depth + 1
const columns: Map<string, { inFlow: number; outFlow: number }>[] = Array.from( const nodeMap = new Map<string, { id: string; label: string; color: string }>()
{ length: numColumns }, const links: { source: string; target: string; value: number }[] = []
() => new Map(),
) // Track flow in/out per node to compute exits
const flowIn = new Map<string, number>()
const flowOut = new Map<string, number>()
for (const t of transitions) { for (const t of transitions) {
const fromCol = t.step_index if (t.step_index >= numColumns || t.step_index + 1 >= numColumns) continue
const toCol = t.step_index + 1
if (fromCol >= numColumns || toCol >= numColumns) continue
// from node const fromId = `${t.step_index}:${t.from_path}`
const fromEntry = columns[fromCol].get(t.from_path) ?? { inFlow: 0, outFlow: 0 } const toId = `${t.step_index + 1}:${t.to_path}`
fromEntry.outFlow += t.session_count
columns[fromCol].set(t.from_path, fromEntry)
// to node if (!nodeMap.has(fromId)) {
const toEntry = columns[toCol].get(t.to_path) ?? { inFlow: 0, outFlow: 0 } nodeMap.set(fromId, { id: fromId, label: t.from_path, color: NODE_COLOR })
toEntry.inFlow += t.session_count }
columns[toCol].set(t.to_path, toEntry) if (!nodeMap.has(toId)) {
nodeMap.set(toId, { id: toId, label: t.to_path, color: NODE_COLOR })
}
links.push({ source: fromId, target: toId, value: t.session_count })
flowOut.set(fromId, (flowOut.get(fromId) ?? 0) + t.session_count)
flowIn.set(toId, (flowIn.get(toId) ?? 0) + t.session_count)
} }
// For column 0, nodes that have no inFlow — use outFlow as total flow // Add exit nodes for flows that don't continue
// For other columns, use max(inFlow, outFlow) for (const [nodeId, node] of nodeMap) {
// Also ensure column 0 nodes get their inFlow from the fact they are entry points const totalIn = flowIn.get(nodeId) ?? 0
const totalOut = flowOut.get(nodeId) ?? 0
const flow = Math.max(totalIn, totalOut)
const exitCount = flow - totalOut
// ── 2. Add exit nodes ───────────────────────────────────────────── if (exitCount > 0) {
// For each node, exitCount = inFlow - outFlow (if positive) const col = parseInt(nodeId.split(':')[0], 10)
// For column 0, exitCount = outFlow - outFlow = handled differently: if (col < numColumns - 1) {
// column 0 nodes: flow = outFlow, and if they also appear as to_path, inFlow is set const exitId = `${col + 1}:(exit)`
// Actually for column 0 the total flow IS outFlow (they are entry points) if (!nodeMap.has(exitId)) {
nodeMap.set(exitId, { id: exitId, label: '(exit)', color: EXIT_COLOR })
// Build exit transitions for each column (except last, which is all exit) }
const exitTransitions: { fromCol: number; fromPath: string; exitCount: number }[] = [] links.push({ source: nodeId, target: exitId, value: exitCount })
for (let col = 0; col < numColumns; col++) {
for (const [path, entry] of columns[col]) {
const totalFlow = col === 0 ? entry.outFlow : Math.max(entry.inFlow, entry.outFlow)
const exitCount = totalFlow - entry.outFlow
if (exitCount > 0) {
exitTransitions.push({ fromCol: col, fromPath: path, exitCount })
} }
} }
} }
// For the last column, ALL flow is exit (no outgoing transitions)
// We don't add extra exit nodes for the last column since those nodes are already endpoints
// Add exit nodes to columns (they sit in the same column, below the real nodes,
// or we add them as virtual nodes in col+1). Actually per spec: "Add virtual (exit) nodes
// at the right end of flows that don't continue" — this means we add them as targets in
// the next column. But we only do this for non-last columns.
const exitLinks: { fromCol: number; fromPath: string; exitCount: number }[] = []
for (const et of exitTransitions) {
if (et.fromCol < numColumns - 1) {
const exitCol = et.fromCol + 1
const exitEntry = columns[exitCol].get(EXIT_LABEL) ?? { inFlow: 0, outFlow: 0 }
exitEntry.inFlow += et.exitCount
columns[exitCol].set(EXIT_LABEL, exitEntry)
exitLinks.push(et)
}
}
// ── 3. Sort nodes per column and assign positions ─────────────────
const availableWidth = svgWidth - PADDING_X * 2
const availableHeight = svgHeight - PADDING_Y * 2
const colSpacing = numColumns > 1 ? availableWidth / (numColumns - 1) : 0
const positionedNodes: Map<string, PositionedNode> = new Map()
for (let col = 0; col < numColumns; col++) {
const entries = Array.from(columns[col].entries()).map(([path, entry]) => ({
path,
flow: col === 0 ? entry.outFlow : Math.max(entry.inFlow, entry.outFlow),
}))
// Sort by flow descending, but keep (exit) at bottom
entries.sort((a, b) => {
if (a.path === EXIT_LABEL) return 1
if (b.path === EXIT_LABEL) return -1
return b.flow - a.flow
})
const totalFlow = entries.reduce((sum, e) => sum + e.flow, 0)
const totalGaps = Math.max(0, entries.length - 1) * NODE_GAP
const usableHeight = availableHeight - totalGaps
let y = PADDING_Y
const x = PADDING_X + col * colSpacing
for (const entry of entries) {
const proportion = totalFlow > 0 ? entry.flow / totalFlow : 1 / entries.length
const nodeHeight = Math.max(MIN_NODE_HEIGHT, proportion * usableHeight)
const id = `${col}:${entry.path}`
positionedNodes.set(id, {
id,
path: entry.path,
column: col,
flow: entry.flow,
x,
y,
height: nodeHeight,
})
y += nodeHeight + NODE_GAP
}
}
// ── 4. Build positioned links ─────────────────────────────────────
// Track how much vertical space has been used at each node's source/target side
const sourceOffsets: Map<string, number> = new Map()
const targetOffsets: Map<string, number> = new Map()
const allLinks: {
fromId: string
toId: string
sessionCount: number
}[] = []
// Regular transitions
for (const t of transitions) {
const fromCol = t.step_index
const toCol = t.step_index + 1
if (fromCol >= numColumns || toCol >= numColumns) continue
allLinks.push({
fromId: `${fromCol}:${t.from_path}`,
toId: `${toCol}:${t.to_path}`,
sessionCount: t.session_count,
})
}
// Exit links
for (const et of exitLinks) {
allLinks.push({
fromId: `${et.fromCol}:${et.fromPath}`,
toId: `${et.fromCol + 1}:${EXIT_LABEL}`,
sessionCount: et.exitCount,
})
}
// Sort links by session count descending for better visual stacking
allLinks.sort((a, b) => b.sessionCount - a.sessionCount)
const positionedLinks: PositionedLink[] = []
for (const link of allLinks) {
const fromNode = positionedNodes.get(link.fromId)
const toNode = positionedNodes.get(link.toId)
if (!fromNode || !toNode) continue
const linkWidth = Math.max(
1,
fromNode.flow > 0 ? (link.sessionCount / fromNode.flow) * fromNode.height : 1,
)
const sourceOffset = sourceOffsets.get(link.fromId) ?? 0
const targetOffset = targetOffsets.get(link.toId) ?? 0
positionedLinks.push({
id: `${link.fromId}->${link.toId}`,
fromNode,
toNode,
sessionCount: link.sessionCount,
sourceY: fromNode.y + sourceOffset,
targetY: toNode.y + targetOffset,
width: linkWidth,
})
sourceOffsets.set(link.fromId, sourceOffset + linkWidth)
targetOffsets.set(link.toId, targetOffset + linkWidth)
}
return { return {
nodes: Array.from(positionedNodes.values()), nodes: Array.from(nodeMap.values()),
links: positionedLinks, links,
} }
} }
function buildLinkPath(link: PositionedLink): string { const valueFormatter = (value: number, context: { type: string }) => {
const sx = link.fromNode.x + NODE_WIDTH if (context.type === 'link') {
const sy = link.sourceY return `${value.toLocaleString()} sessions`
const tx = link.toNode.x }
const ty = link.targetY return `${value.toLocaleString()} sessions total`
const w = link.width
const midX = (sx + tx) / 2
return [
`M ${sx},${sy}`,
`C ${midX},${sy} ${midX},${ty} ${tx},${ty}`,
`L ${tx},${ty + w}`,
`C ${midX},${ty + w} ${midX},${sy + w} ${sx},${sy + w}`,
'Z',
].join(' ')
} }
// ─── Component ────────────────────────────────────────────────────── // ─── Component ──────────────────────────────────────────────────────
@@ -266,17 +99,21 @@ export default function SankeyDiagram({
}: SankeyDiagramProps) { }: SankeyDiagramProps) {
const { resolvedTheme } = useTheme() const { resolvedTheme } = useTheme()
const isDark = resolvedTheme === 'dark' const isDark = resolvedTheme === 'dark'
const [hoveredLink, setHoveredLink] = useState<string | null>(null)
const svgWidth = 1000 const muiTheme = useMemo(
const svgHeight = 500 () =>
createTheme({
palette: { mode: isDark ? 'dark' : 'light' },
}),
[isDark],
)
const { nodes, links } = useMemo( const data = useMemo(
() => buildSankeyLayout(transitions, depth, svgWidth, svgHeight), () => transformToSankeyData(transitions, depth),
[transitions, depth], [transitions, depth],
) )
if (!transitions.length || !links.length) { if (!transitions.length || !data.links.length) {
return ( return (
<div className="h-[400px] flex flex-col items-center justify-center text-center px-6 py-8 gap-3"> <div className="h-[400px] flex flex-col items-center justify-center text-center px-6 py-8 gap-3">
<div className="rounded-full bg-neutral-100 dark:bg-neutral-800 p-4"> <div className="rounded-full bg-neutral-100 dark:bg-neutral-800 p-4">
@@ -292,122 +129,40 @@ export default function SankeyDiagram({
) )
} }
const numColumns = depth + 1
const isLastColumn = (col: number) => col === numColumns - 1
// Colors — brand orange nodes, subtle neutral links
const brandOrange = '#FD5E0F'
const labelFill = isDark ? '#d4d4d4' : '#525252'
const linkDefault = isDark ? 'rgba(255, 255, 255, 0.06)' : 'rgba(0, 0, 0, 0.06)'
const linkHover = isDark ? 'rgba(253, 94, 15, 0.35)' : 'rgba(253, 94, 15, 0.25)'
const linkDimmed = isDark ? 'rgba(255, 255, 255, 0.02)' : 'rgba(0, 0, 0, 0.02)'
const exitNodeFill = isDark ? '#404040' : '#d4d4d4'
// Fade node opacity from 1.0 (entry) to 0.5 (deepest)
const nodeOpacity = (col: number) => {
if (numColumns <= 1) return 1
return 1 - (col / (numColumns - 1)) * 0.5
}
return ( return (
<svg <ThemeProvider theme={muiTheme}>
viewBox={`0 0 ${svgWidth} ${svgHeight}`} <div style={{ width: '100%', height: 500 }}>
preserveAspectRatio="xMidYMid meet" <SankeyDataProvider
className="w-full" series={[
role="img" {
aria-label="User journey Sankey diagram" type: 'sankey' as const,
> data,
{/* Links */} valueFormatter,
<g> nodeOptions: {
{links.map((link) => { sort: 'auto',
const isHovered = hoveredLink === link.id padding: 20,
const hasSomeHovered = hoveredLink !== null width: 9,
const pct = totalSessions > 0 showLabels: true,
? ((link.sessionCount / totalSessions) * 100).toFixed(1) },
: '0' linkOptions: {
color: 'source',
let fill: string opacity: 0.6,
if (isHovered) fill = linkHover curveCorrection: 0,
else if (hasSomeHovered) fill = linkDimmed },
else fill = linkDefault },
]}
return ( margin={{ top: 10, bottom: 10, left: 10, right: 10 }}
<path >
key={link.id} <ChartsWrapper>
d={buildLinkPath(link)} <ChartsSurface>
fill={fill} <SankeyNodePlot />
style={{ transition: 'fill 0.15s ease' }} <SankeyLinkPlot />
onMouseEnter={() => setHoveredLink(link.id)} <SankeyNodeLabelPlot />
onMouseLeave={() => setHoveredLink(null)} </ChartsSurface>
className="cursor-default" <SankeyTooltip trigger="item" />
> </ChartsWrapper>
<title> </SankeyDataProvider>
{link.fromNode.path} {link.toNode.path}: {link.sessionCount.toLocaleString()} sessions ({pct}%) </div>
</title> </ThemeProvider>
</path>
)
})}
</g>
{/* Nodes */}
<g>
{nodes.map((node) => {
const isExit = node.path === EXIT_LABEL
return (
<rect
key={node.id}
x={node.x}
y={node.y}
width={NODE_WIDTH}
height={node.height}
rx={3}
ry={3}
fill={isExit ? exitNodeFill : brandOrange}
opacity={isExit ? 0.4 : nodeOpacity(node.column)}
className={onNodeClick && !isExit ? 'cursor-pointer' : 'cursor-default'}
onClick={() => {
if (onNodeClick && !isExit) onNodeClick(node.path)
}}
>
<title>{node.path} {node.flow.toLocaleString()} sessions</title>
</rect>
)
})}
</g>
{/* Labels */}
<g>
{nodes.map((node) => {
const isLast = isLastColumn(node.column)
const labelX = isLast ? node.x - 6 : node.x + NODE_WIDTH + 6
const labelY = node.y + node.height / 2
const anchor = isLast ? 'end' : 'start'
const displayLabel = truncatePath(node.path, LABEL_MAX_LENGTH)
// Only show labels for nodes tall enough to fit text
if (node.height < 10) return null
return (
<text
key={`label-${node.id}`}
x={labelX}
y={labelY}
dy="0.35em"
textAnchor={anchor}
fill={labelFill}
fontSize={11}
className={onNodeClick && node.path !== EXIT_LABEL ? 'cursor-pointer' : 'cursor-default'}
onClick={() => {
if (onNodeClick && node.path !== EXIT_LABEL) onNodeClick(node.path)
}}
>
{displayLabel}
<title>{node.path}</title>
</text>
)
})}
</g>
</svg>
) )
} }

1049
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -14,6 +14,10 @@
"dependencies": { "dependencies": {
"@ciphera-net/ui": "^0.2.5", "@ciphera-net/ui": "^0.2.5",
"@ducanh2912/next-pwa": "^10.2.9", "@ducanh2912/next-pwa": "^10.2.9",
"@emotion/react": "^11.14.0",
"@emotion/styled": "^11.14.1",
"@mui/material": "^7.3.9",
"@mui/x-charts-pro": "^8.27.5",
"@phosphor-icons/react": "^2.1.10", "@phosphor-icons/react": "^2.1.10",
"@simplewebauthn/browser": "^13.2.2", "@simplewebauthn/browser": "^13.2.2",
"@stripe/react-stripe-js": "^5.6.0", "@stripe/react-stripe-js": "^5.6.0",