Refactoring React Three Fiber Drawing Tools: A Modular Approach

Jeremy Atkinson
Developer

Learn how to transform a monolithic drawing application into a clean, extensible tool system using object-oriented design patterns, custom hooks, and reusable components.

Refactoring React Three Fiber Drawing Tools: From Monolith to Modular Architecture

In my previous post, we built a functional drawing application with React Three Fiber. However, the code suffered from a common problem: monolithic architecture. All tool logic was hardcoded into conditional statements, making it difficult to maintain and extend.

Today, we'll refactor this application into a clean, modular system that's easy to extend, test, and maintain.

The Problem with Monolithic Drawing Tools

Our original implementation had several issues:

// The old way - hardcoded conditionals everywhere
if (currentTool === 'line') {
  // 20+ lines of line tool logic
} else if (currentTool === 'shape') {
  // 30+ lines of shape tool logic  
} else if (currentTool === 'rectangle') {
  // 25+ lines of rectangle tool logic
}

Problems:

  • Rigid: Adding new tools requires modifying existing code
  • Duplicated Logic: Similar patterns repeated across tools
  • Hard to Test: Tool logic mixed with UI concerns
  • Poor Separation: State management scattered throughout components

Demo: Before vs After

"use client"
import { Canvas, useFrame } from '@react-three/fiber'
import { Line, Text } from '@react-three/drei'
import { Leva, useControls, button } from 'leva'
import { StrictMode, useRef, useState, useCallback, useEffect } from 'react'
import * as THREE from 'three'

// Utility function to calculate distance between two points
const calculateDistance = (p1, p2) => {
  return Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2))
}

export default function App() {
  return (
    <StrictMode>
      <Leva flat />
      <Canvas orthographic camera={{ zoom: 50, near: 0.1, far: 200, position: [0, 0, 15] }}>
        <Experience />
      </Canvas>
    </StrictMode>
  )
}

export function Experience() {
  const [drawnObjects, setDrawnObjects] = useState([])
  const [currentTool, setCurrentTool] = useState('line')
  const [isDrawing, setIsDrawing] = useState(false)
  const [drawingState, setDrawingState] = useState(null)
  const [cursorPosition, setCursorPosition] = useState(null)

  const { snapEnabled, snapValue, showLengths } = useControls('Drawing Tools', {
    tool: { 
      value: 'line', 
      options: { Line: 'line', Shape: 'shape', Rectangle: 'rectangle', Delete: 'delete' },
      onChange: (value) => {
        setCurrentTool(value)
        setIsDrawing(false)
        setDrawingState(null)
      }
    },
    snapEnabled: true,
    snapValue: { value: 0.5, min: 0.1, max: 2, step: 0.1 },
    showLengths: false,
    clear: button(() => {
      setDrawnObjects([])
      setIsDrawing(false)
      setDrawingState(null)
    })
  })

  const snapToGrid = useCallback((point) => {
    if (!snapEnabled) return point
    return {
      x: Math.round(point.x / snapValue) * snapValue,
      y: Math.round(point.y / snapValue) * snapValue,
      z: 0
    }
  }, [snapEnabled, snapValue])

  const handleObjectDelete = useCallback((objectId) => {
    setDrawnObjects(prev => prev.filter(obj => obj.id !== objectId))
  }, [])

  const handleCanvasClick = useCallback((point) => {
    if (currentTool === 'delete') return // Don't handle canvas clicks in delete mode
    
    const snappedPoint = snapToGrid(point)
    
    if (currentTool === 'line') {
      if (!isDrawing) {
        setIsDrawing(true)
        setDrawingState({ start: snappedPoint, end: snappedPoint })
      } else {
        setDrawnObjects(prev => [...prev, {
          type: 'line',
          id: Date.now(),
          start: drawingState.start,
          end: snappedPoint
        }])
        setIsDrawing(false)
        setDrawingState(null)
      }
    } else if (currentTool === 'shape') {
      if (!isDrawing) {
        setIsDrawing(true)
        setDrawingState({ points: [snappedPoint] })
      } else {
        const newPoints = [...drawingState.points, snappedPoint]
        setDrawingState({ points: newPoints })
      }
    } else if (currentTool === 'rectangle') {
      if (!isDrawing) {
        setIsDrawing(true)
        setDrawingState({ corner1: snappedPoint, corner2: snappedPoint })
      } else {
        setDrawnObjects(prev => [...prev, {
          type: 'rectangle',
          id: Date.now(),
          corner1: drawingState.corner1,
          corner2: snappedPoint
        }])
        setIsDrawing(false)
        setDrawingState(null)
      }
    }
  }, [currentTool, isDrawing, drawingState, snapToGrid])

  const handleMouseMove = useCallback((point) => {
    const snappedPoint = snapToGrid(point)
    setCursorPosition(snappedPoint)
    
    if (!isDrawing || !drawingState) return
    
    if (currentTool === 'line') {
      setDrawingState(prev => ({ ...prev, end: snappedPoint }))
    } else if (currentTool === 'rectangle') {
      setDrawingState(prev => ({ ...prev, corner2: snappedPoint }))
    }
  }, [isDrawing, drawingState, currentTool, snapToGrid])

  const handleShapeComplete = useCallback(() => {
    if (currentTool === 'shape' && isDrawing && drawingState?.points.length >= 3) {
      setDrawnObjects(prev => [...prev, {
        type: 'shape',
        id: Date.now(),
        points: drawingState.points
      }])
      setIsDrawing(false)
      setDrawingState(null)
    }
  }, [currentTool, isDrawing, drawingState])

  useEffect(() => {
    const handleKeyDown = (e) => {
      if (e.key === 'Escape') {
        if (currentTool === 'shape') {
          handleShapeComplete()
        } else {
          setIsDrawing(false)
          setDrawingState(null)
        }
      }
    }
    
    window.addEventListener('keydown', handleKeyDown)
    return () => window.removeEventListener('keydown', handleKeyDown)
  }, [currentTool, handleShapeComplete])

  return (
    <>
      <directionalLight position={[1, 2, 3]} intensity={1.5} />
      <ambientLight intensity={0.5} />
      
      <DrawingCanvas 
        onCanvasClick={handleCanvasClick}
        onMouseMove={handleMouseMove}
        snapEnabled={snapEnabled}
        snapValue={snapValue}
      />
      
      {/* Render completed objects */}
      {drawnObjects.map(obj => (
        <DrawnObject 
          key={obj.id} 
          object={obj} 
          isDeleteMode={currentTool === 'delete'}
          showLengths={showLengths}
          onDelete={() => handleObjectDelete(obj.id)}
        />
      ))}
      
      {/* Render current drawing preview */}
      {isDrawing && drawingState && (
        <DrawingPreview 
          tool={currentTool} 
          state={drawingState} 
          cursorPosition={cursorPosition}
          showLengths={showLengths}
          onShapeComplete={handleShapeComplete}
        />
      )}
      
      <GridHelper snapValue={snapValue} visible={snapEnabled} />
    </>
  )
}

export function DrawingCanvas({ onCanvasClick, onMouseMove, snapEnabled, snapValue }) {
  const cursorRef = useRef()
  const planeRef = useRef()

  return (
    <mesh
      ref={planeRef}
      position={[0, 0, 0]}
      onPointerMove={(e) => {
        const localPoint = planeRef.current.worldToLocal(e.point.clone())
        
        if (snapEnabled) {
          localPoint.x = Math.round(localPoint.x / snapValue) * snapValue
          localPoint.y = Math.round(localPoint.y / snapValue) * snapValue
        }
        
        if (cursorRef.current) {
          cursorRef.current.position.set(localPoint.x, localPoint.y, 0.01)
          cursorRef.current.visible = true
        }
        
        // Drawing coordinates are directly X,Y with Z=0
        const drawingPoint = { x: localPoint.x, y: localPoint.y, z: 0 }
        onMouseMove(drawingPoint)
      }}
      onPointerOut={() => {
        if (cursorRef.current) {
          cursorRef.current.visible = false
        }
      }}
      onClick={(e) => {
        const localPoint = planeRef.current.worldToLocal(e.point.clone())
        
        if (snapEnabled) {
          localPoint.x = Math.round(localPoint.x / snapValue) * snapValue
          localPoint.y = Math.round(localPoint.y / snapValue) * snapValue
        }
        
        // Drawing coordinates are directly X,Y with Z=0
        const drawingPoint = { x: localPoint.x, y: localPoint.y, z: 0 }
        onCanvasClick(drawingPoint)
      }}
    >
      <planeGeometry args={[20, 20]} />
      <meshBasicMaterial transparent opacity={0.1} color="lightblue" />
      
      <CursorIndicator ref={cursorRef} />
    </mesh>
  )
}

export const CursorIndicator = ({ color = "orange", size = 0.1, ...props }) => {
  const ref = useRef()
  
  useFrame(() => {
    if (ref.current) {
      ref.current.rotation.x += 0.01
      ref.current.rotation.y += 0.01
    }
  })
  
  return (
    <mesh ref={ref} raycast={() => null} visible={false} {...props}>
      <sphereGeometry args={[size]} />
      <meshBasicMaterial color={color} />
    </mesh>
  )
}

export function DrawnObject({ object, isDeleteMode, showLengths, onDelete }) {
  const [isHovered, setIsHovered] = useState(false)
  
  const commonProps = isDeleteMode ? {
    onPointerEnter: () => setIsHovered(true),
    onPointerLeave: () => setIsHovered(false),
    onClick: (e) => {
      e.stopPropagation()
      onDelete()
    },
    style: { cursor: 'pointer' }
  } : {}
  
  const color = isDeleteMode && isHovered ? 'red' : undefined
  const opacity = isDeleteMode && isHovered ? 0.8 : 1
  
  switch (object.type) {
    case 'line':
      return <DrawnLine start={object.start} end={object.end} color={color} opacity={opacity} showLengths={showLengths} {...commonProps} />
    case 'shape':
      return <DrawnShape points={object.points} color={color} opacity={opacity} showLengths={showLengths} {...commonProps} />
    case 'rectangle':
      return <DrawnRectangle corner1={object.corner1} corner2={object.corner2} color={color} opacity={opacity} showLengths={showLengths} {...commonProps} />
    default:
      return null
  }
}

export function DrawingPreview({ tool, state, cursorPosition, showLengths, onShapeComplete }) {
  switch (tool) {
    case 'line':
      return state.start && state.end ? (
        <DrawnLine start={state.start} end={state.end} color="orange" opacity={0.7} showLengths={showLengths} />
      ) : null
    case 'shape':
      return (
        <>
          {state.points?.map((point, i) => (
            <mesh key={i} position={[point.x, point.y, 0]}>
              <sphereGeometry args={[0.05]} />
              <meshBasicMaterial color="orange" />
            </mesh>
          ))}
          {state.points?.length > 1 && (
            <>
              <Line 
                points={state.points.map(p => [p.x, p.y, 0])}
                color="orange"
                lineWidth={2}
              />
              {/* Show lengths on all drawn segments */}
              {showLengths && state.points.slice(0, -1).map((point, i) => {
                const nextPoint = state.points[i + 1]
                const segmentLength = calculateDistance(point, nextPoint)
                const midpoint = {
                  x: (point.x + nextPoint.x) / 2,
                  y: (point.y + nextPoint.y) / 2
                }
                return (
                  <Text
                    key={i}
                    position={[midpoint.x, midpoint.y, 0.1]}
                    fontSize={0.2}
                    color="black"
                    anchorX="center"
                    anchorY="middle"
                  >
                    {segmentLength.toFixed(2)}
                  </Text>
                )
              })}
            </>
          )}
          {/* Preview line from last point to cursor */}
          {state.points?.length > 0 && cursorPosition && (
            <>
              <Line 
                points={[
                  [state.points[state.points.length - 1].x, state.points[state.points.length - 1].y, 0],
                  [cursorPosition.x, cursorPosition.y, 0]
                ]}
                color="orange"
                lineWidth={1}
                transparent
                opacity={0.5}
              />
              {showLengths && (() => {
                const lastPoint = state.points[state.points.length - 1]
                const previewLength = calculateDistance(lastPoint, cursorPosition)
                const midpoint = {
                  x: (lastPoint.x + cursorPosition.x) / 2,
                  y: (lastPoint.y + cursorPosition.y) / 2
                }
                return (
                  <Text
                    position={[midpoint.x, midpoint.y, 0.1]}
                    fontSize={0.2}
                    color="orange"
                    anchorX="center"
                    anchorY="middle"
                  >
                    {previewLength.toFixed(2)}
                  </Text>
                )
              })()}
            </>
          )}
          {state.points?.length >= 3 && (
            <mesh 
              position={[state.points[0].x, state.points[0].y, 0]}
              onClick={onShapeComplete}
            >
              <sphereGeometry args={[0.1]} />
              <meshBasicMaterial color="green" />
            </mesh>
          )}
        </>
      )
    case 'rectangle':
      return state.corner1 && state.corner2 ? (
        <DrawnRectangle corner1={state.corner1} corner2={state.corner2} color="orange" opacity={0.7} showLengths={showLengths} />
      ) : null
    default:
      return null
  }
}

export function DrawnLine({ start, end, color = "black", opacity = 1, showLengths = false, ...props }) {
  const points = [[start.x, start.y, 0], [end.x, end.y, 0]]
  const length = calculateDistance(start, end)
  const midpoint = {
    x: (start.x + end.x) / 2,
    y: (start.y + end.y) / 2
  }
  
  return (
    <group {...props}>
      <Line points={points} color={color} lineWidth={3} />
      <mesh position={[start.x, start.y, 0]}>
        <sphereGeometry args={[0.05]} />
        <meshBasicMaterial color={color} transparent opacity={opacity} />
      </mesh>
      <mesh position={[end.x, end.y, 0]}>
        <sphereGeometry args={[0.05]} />
        <meshBasicMaterial color={color} transparent opacity={opacity} />
      </mesh>
      {showLengths && (
        <Text
          position={[midpoint.x, midpoint.y, 0.1]}
          fontSize={0.2}
          color="black"
          anchorX="center"
          anchorY="middle"
        >
          {length.toFixed(2)}
        </Text>
      )}
    </group>
  )
}

export function DrawnShape({ points, color = "crimson", opacity = 1, showLengths = false, ...props }) {
  if (points.length < 3) return null
  
  const shape = new THREE.Shape()
  shape.moveTo(points[0].x, points[0].y)
  // Start from the second point: moveTo already placed the pen on the first one
  points.slice(1).forEach(p => shape.lineTo(p.x, p.y))
  shape.closePath()
  
  // Create segments for length display
  const segments = []
  for (let i = 0; i < points.length; i++) {
    const start = points[i]
    const end = points[(i + 1) % points.length] // Wrap around to close the shape
    segments.push({ start, end })
  }
  
  return (
    <group {...props}>
      <mesh position={[0, 0, 0]}>
        <shapeGeometry args={[shape]} />
        <meshBasicMaterial 
          color={color} 
          transparent 
          opacity={opacity} 
          side={THREE.DoubleSide} 
        />
      </mesh>
      {/* Edge lines for definition */}
      {segments.map((segment, i) => (
        <Line 
          key={`edge-${i}`}
          points={[[segment.start.x, segment.start.y, 0.01], [segment.end.x, segment.end.y, 0.01]]}
          color="black"
          lineWidth={1}
        />
      ))}
      {showLengths && segments.map((segment, i) => {
        const length = calculateDistance(segment.start, segment.end)
        const midpoint = {
          x: (segment.start.x + segment.end.x) / 2,
          y: (segment.start.y + segment.end.y) / 2
        }
        return (
          <Text
            key={i}
            position={[midpoint.x, midpoint.y, 0.1]}
            fontSize={0.2}
            color="black"
            anchorX="center"
            anchorY="middle"
          >
            {length.toFixed(2)}
          </Text>
        )
      })}
    </group>
  )
}

export function DrawnRectangle({ corner1, corner2, color = "blue", opacity = 1, showLengths = false, ...props }) {
  const width = Math.abs(corner2.x - corner1.x)
  const height = Math.abs(corner2.y - corner1.y)
  const centerX = (corner1.x + corner2.x) / 2
  const centerY = (corner1.y + corner2.y) / 2
  
  // Calculate the four corners
  const minX = Math.min(corner1.x, corner2.x)
  const maxX = Math.max(corner1.x, corner2.x)
  const minY = Math.min(corner1.y, corner2.y)
  const maxY = Math.max(corner1.y, corner2.y)
  
  // Define only two sides to avoid redundant labels
  const sides = [
    { start: { x: minX, y: minY }, end: { x: maxX, y: minY }, length: width }, // bottom
    { start: { x: maxX, y: minY }, end: { x: maxX, y: maxY }, length: height }, // right
  ]
  
  return (
    <group {...props}>
      <mesh position={[centerX, centerY, 0]}>
        <planeGeometry args={[width, height]} />
        <meshBasicMaterial 
          color={color} 
          transparent 
          opacity={opacity} 
          side={THREE.DoubleSide} 
        />
      </mesh>
      {/* Edge lines for definition - all four sides */}
      <Line 
        points={[
          [minX, minY, 0.01], [maxX, minY, 0.01], // bottom
          [maxX, maxY, 0.01], [minX, maxY, 0.01], // top
          [minX, minY, 0.01] // close the rectangle
        ]}
        color="black"
        lineWidth={1}
      />
      {showLengths && sides.map((side, i) => {
        const midpoint = {
          x: (side.start.x + side.end.x) / 2,
          y: (side.start.y + side.end.y) / 2
        }
        return (
          <Text
            key={i}
            position={[midpoint.x, midpoint.y, 0.1]}
            fontSize={0.2}
            color="black"
            anchorX="center"
            anchorY="middle"
          >
            {side.length.toFixed(2)}
          </Text>
        )
      })}
    </group>
  )
}

export function GridHelper({ snapValue, visible }) {
  if (!visible) return null
  
  const lines = []
  const size = 20
  // Whole grid steps between the center and the plane edge (the plane spans ±size/2)
  const steps = Math.floor(size / 2 / snapValue)
  
  for (let i = -steps; i <= steps; i++) {
    const pos = i * snapValue
    lines.push(
      <Line 
        key={`h${i}`}
        points={[[-size/2, pos, 0], [size/2, pos, 0]]}
        color="gray"
        lineWidth={0.5}
      />,
      <Line 
        key={`v${i}`}
        points={[[pos, -size/2, 0], [pos, size/2, 0]]}
        color="gray"
        lineWidth={0.5}
      />
    )
  }
  
  return <group>{lines}</group>
}

Original monolithic implementation

"use client"
import { Canvas, useFrame } from '@react-three/fiber'
import { Line } from '@react-three/drei'
import { Leva, useControls, button } from 'leva'
import { StrictMode, useRef, memo } from 'react'

// Import our new modular architecture
import { useDrawingState, useSnapToGrid, useCoordinateTransform, useCursorIndicator, useKeyboardEvents } from './hooks.js'
import { DrawableObject } from './DrawableObjects.jsx'
import { ToolPreview } from './ToolPreview.jsx'

export default function App() {
  return (
    <StrictMode>
      <Leva flat />
      <Canvas orthographic camera={{ zoom: 50, near: 0.1, far: 200, position: [0, 0, 15] }}>
        <Experience />
      </Canvas>
    </StrictMode>
  )
}

export function Experience() {
  // Use our custom hooks for state management
  const {
    drawnObjects,
    currentTool,
    isDrawing,
    drawingState,
    cursorPosition,
    handleToolChange,
    handleCanvasClick,
    handleMouseMove,
    handleKeyPress,
    handleObjectDelete,
    clearAll,
    toolOptions,
    toolSpecificData
  } = useDrawingState()

  const { snapEnabled, snapValue, setSnapEnabled, setSnapValue } = useSnapToGrid(0.5)
  const { transformPoint } = useCoordinateTransform()

  // Leva controls
  const { showLengths } = useControls('Drawing Tools (Refactored)', {
    tool: {
      value: 'line',
      options: toolOptions,
      onChange: handleToolChange
    },
    snap: {
      value: true,
      onChange: setSnapEnabled
    },
    snapVal: {
      value: 0.5,
      min: 0.1,
      max: 2,
      step: 0.1,
      onChange: setSnapValue
    },
    showLengths: false,
    clear: button(clearAll)
  })

  // Set up keyboard event handling
  useKeyboardEvents(handleKeyPress)

  return (
    <>
      <directionalLight position={[1, 2, 3]} intensity={1.5} />
      <ambientLight intensity={0.5} />

      <DrawingCanvas
        onCanvasClick={handleCanvasClick}
        onMouseMove={handleMouseMove}
        snapEnabled={snapEnabled}
        snapValue={snapValue}
        transformPoint={transformPoint}
      />

      {/* Render completed objects */}
      {drawnObjects.map(obj => (
        <DrawableObject
          key={obj.id}
          object={obj}
          isDeleteMode={currentTool === 'delete'}
          showLengths={showLengths}
          onDelete={() => handleObjectDelete(obj.id)}
        />
      ))}

      {/* Render current drawing preview */}
      {isDrawing && drawingState && (
        <ToolPreview
          tool={currentTool}
          state={drawingState}
          cursorPosition={cursorPosition}
          showLengths={showLengths}
          onShapeComplete={() => handleKeyPress('Escape')}
        />
      )}

      <GridHelper snapValue={snapValue} visible={snapEnabled} />
    </>
  )
}

// Drawing canvas component
export function DrawingCanvas({ onCanvasClick, onMouseMove, snapEnabled, snapValue, transformPoint }) {
  const { cursorRef, showCursor, hideCursor } = useCursorIndicator()
  const planeRef = useRef()

  return (
    <mesh
      ref={planeRef}
      position={[0, 0, 0]}
      onPointerMove={(e) => {
        if (!planeRef.current) return
        const localPoint = planeRef.current.worldToLocal(e.point.clone())
        const transformedPoint = transformPoint(localPoint, snapEnabled, snapValue)

        showCursor(transformedPoint)
        onMouseMove(transformedPoint, snapEnabled, snapValue)
      }}
      onPointerOut={hideCursor}
      onClick={(e) => {
        if (!planeRef.current) return
        const localPoint = planeRef.current.worldToLocal(e.point.clone())
        const transformedPoint = transformPoint(localPoint, snapEnabled, snapValue)
        onCanvasClick(transformedPoint, snapEnabled, snapValue)
      }}
    >
      <planeGeometry args={[20, 20]} />
      <meshBasicMaterial transparent opacity={0.1} color="lightblue" />

      <CursorIndicator ref={cursorRef} />
    </mesh>
  )
}

// Cursor indicator component
export const CursorIndicator = ({ color = "orange", size = 0.1, ...props }) => {
  const ref = useRef()

  useFrame(() => {
    if (ref.current) {
      ref.current.rotation.x += 0.01
      ref.current.rotation.y += 0.01
    }
  })

  return (
    <mesh ref={ref} raycast={() => null} visible={false} {...props}>
      <sphereGeometry args={[size]} />
      <meshBasicMaterial color={color} />
    </mesh>
  )
}

// Grid helper component with React.memo to prevent unnecessary re-renders
export const GridHelper = memo(function GridHelper({ snapValue, visible }) {
  if (!visible) return null
  
  const lines = []
  const size = 20
  // Whole grid steps between the center and the plane edge (the plane spans ±size/2)
  const steps = Math.floor(size / 2 / snapValue)
  
  for (let i = -steps; i <= steps; i++) {
    const pos = i * snapValue
    lines.push(
      <Line 
        key={`h${i}`}
        points={[[-size/2, pos, 0], [size/2, pos, 0]]}
        color="gray"
        lineWidth={0.5}
      />,
      <Line 
        key={`v${i}`}
        points={[[pos, -size/2, 0], [pos, size/2, 0]]}
        color="gray"
        lineWidth={0.5}
      />
    )
  }
  
  return <group>{lines}</group>
})

Refactored modular architecture

The Refactored Architecture

Our new system uses several design patterns to create a clean, extensible architecture:

1. Tool Strategy Pattern

Instead of hardcoded conditionals, we use the Strategy Pattern with a plugin-based tool system:

// Abstract base class defines the interface
export class BaseTool {
  constructor(name) {
    this.name = name
    this.isDrawing = false
    this.state = null
  }

  // Shared entry point: snaps the point, then dispatches to the subclass
  onCanvasClick(point, snapEnabled, snapValue) {
    const snappedPoint = GeometryUtils.snapToGrid(point, snapEnabled, snapValue)
    
    if (!this.isDrawing) {
      return this.startDrawing(snappedPoint)
    } else {
      return this.continueDrawing(snappedPoint)
    }
  }

  // Abstract methods - must be implemented by subclasses
  startDrawing(point) {
    throw new Error('startDrawing must be implemented by subclass')
  }

  continueDrawing(point) {
    throw new Error('continueDrawing must be implemented by subclass')
  }

  // Back to idle; tools call this once an object is complete
  reset() {
    this.isDrawing = false
    this.state = null
  }
}

Each tool becomes a focused class that implements this interface:

export class LineTool extends BaseTool {
  constructor() {
    super('line')
  }

  startDrawing(point) {
    this.isDrawing = true
    this.state = { start: point, end: point }
    return { action: 'startDrawing', state: this.state }
  }

  continueDrawing(point) {
    const object = {
      type: 'line',
      id: GeometryUtils.generateId(),
      start: this.state.start,
      end: point
    }
    
    this.reset()
    return { action: 'complete', object }
  }
}
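
To see how the pieces fit, here is a minimal sketch of the two-click flow, run outside React. It restates trimmed versions of BaseTool and LineTool so it runs on its own; the snapToGrid helper, the counter-based nextId, and the body of reset() are stand-ins assumed for illustration rather than the project's actual implementations.

```javascript
// Trimmed restatements of the classes above so the sketch is self-contained.
// snapToGrid, nextId and reset() are illustrative assumptions.
const snapToGrid = (point, snapEnabled, snapValue) =>
  snapEnabled
    ? {
        x: Math.round(point.x / snapValue) * snapValue,
        y: Math.round(point.y / snapValue) * snapValue,
        z: 0
      }
    : point

let nextId = 1

class BaseTool {
  constructor(name) {
    this.name = name
    this.isDrawing = false
    this.state = null
  }

  // First click starts a drawing, second click finishes it
  onCanvasClick(point, snapEnabled, snapValue) {
    const snapped = snapToGrid(point, snapEnabled, snapValue)
    return this.isDrawing ? this.continueDrawing(snapped) : this.startDrawing(snapped)
  }

  reset() {
    this.isDrawing = false
    this.state = null
  }
}

class LineTool extends BaseTool {
  constructor() {
    super('line')
  }

  startDrawing(point) {
    this.isDrawing = true
    this.state = { start: point, end: point }
    return { action: 'startDrawing', state: this.state }
  }

  continueDrawing(point) {
    const object = { type: 'line', id: nextId++, start: this.state.start, end: point }
    this.reset()
    return { action: 'complete', object }
  }
}

const tool = new LineTool()
const first = tool.onCanvasClick({ x: 0.2, y: 0.9 }, true, 0.5)
const second = tool.onCanvasClick({ x: 2.1, y: 3.2 }, true, 0.5)

console.log(first.action)   // 'startDrawing'
console.log(second.object)  // a line from (0, 1) to (2, 3) after snapping
console.log(tool.isDrawing) // false: the tool is ready for the next line
```

The calling code never asks which tool it is talking to. It only looks at the action field of the returned object, which is what lets the conditionals disappear from the component.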

2. Tool Manager (Registry Pattern)

The ToolManager acts as a registry and delegates operations to the active tool:

export class ToolManager {
  constructor() {
    this.tools = new Map()
    this.activeTool = null
    
    // Register default tools
    this.registerTool('line', new LineTool())
    this.registerTool('shape', new ShapeTool())
    this.registerTool('rectangle', new RectangleTool())
  }

  // Add (or replace) a tool under the given name
  registerTool(name, tool) {
    this.tools.set(name, tool)
  }

  handleCanvasClick(point, snapEnabled, snapValue) {
    if (!this.activeTool) return null
    return this.activeTool.onCanvasClick(point, snapEnabled, snapValue)
  }

  setActiveTool(toolName) {
    this.activeTool = this.tools.get(toolName)
  }
}

Now our main component becomes much simpler:

// Instead of giant conditional blocks:
const result = toolManager.handleCanvasClick(point, snapEnabled, snapValue)

// The tool manager delegates to the appropriate tool automatically
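
The delegation can be tried in isolation. The sketch below is a stripped-down registry with toy tools that only report which one handled the click; makeToyTool is a hypothetical helper, and registerTool is assumed to be a thin wrapper around Map.set. It also shows what happens for a name with no registered tool, such as a delete mode that is handled by the rendered objects rather than by a tool class.

```javascript
// A stripped-down registry. The toy tools just report who handled the click.
class ToolManager {
  constructor() {
    this.tools = new Map()
    this.activeTool = null
  }

  registerTool(name, tool) {
    this.tools.set(name, tool)
  }

  setActiveTool(name) {
    // Map.get returns undefined for unknown names; normalize that to null
    this.activeTool = this.tools.get(name) ?? null
  }

  handleCanvasClick(point) {
    if (!this.activeTool) return null
    return this.activeTool.onCanvasClick(point)
  }
}

// Hypothetical stand-in for a real tool class
const makeToyTool = (name) => ({
  onCanvasClick: (point) => ({ handledBy: name, point })
})

const manager = new ToolManager()
manager.registerTool('line', makeToyTool('line'))
manager.registerTool('rectangle', makeToyTool('rectangle'))

const click = { x: 1, y: 2 }
const beforeSelection = manager.handleCanvasClick(click) // null: nothing active yet

manager.setActiveTool('rectangle')
const handled = manager.handleCanvasClick(click) // handled by the rectangle tool

manager.setActiveTool('delete') // not registered, so canvas clicks are ignored
const unregistered = manager.handleCanvasClick(click) // null

console.log(beforeSelection, handled.handledBy, unregistered)
```

Returning null for "no active tool" keeps the caller simple: it can treat a missing result the same way as an action it does not recognize.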

3. Custom Hooks for State Management

We extract state logic into focused custom hooks:

export function useDrawingState() {
  const [drawnObjects, setDrawnObjects] = useState([])
  const [isDrawing, setIsDrawing] = useState(false)
  // Lazy initializer: the ToolManager is created once and reused across renders
  const [toolManager] = useState(() => new ToolManager())
  
  const handleCanvasClick = useCallback((point, snapEnabled, snapValue) => {
    const result = toolManager.handleCanvasClick(point, snapEnabled, snapValue)
    
    // Handle result based on action type
    switch (result?.action) {
      case 'complete':
        setDrawnObjects(prev => [...prev, result.object])
        setIsDrawing(false)
        break
      case 'startDrawing':
        setIsDrawing(true)
        break
    }
  }, [toolManager])

  return {
    drawnObjects,
    isDrawing,
    handleCanvasClick,
    toolManager,
    // ... other state and actions
  }
}
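
The switch on result.action is essentially a reducer. One way to keep it testable without rendering anything is to pull it into a pure function. The sketch below assumes the two action names shown above and a state shape (drawnObjects, isDrawing, drawingState) borrowed from the original component; applyToolResult is a hypothetical name, not part of the project's code.

```javascript
// Hypothetical pure reducer for tool results. It returns a new state object
// and leaves the input untouched, so it can back useReducer or plain tests.
function applyToolResult(state, result) {
  switch (result?.action) {
    case 'startDrawing':
      return { ...state, isDrawing: true, drawingState: result.state }
    case 'complete':
      return {
        drawnObjects: [...state.drawnObjects, result.object],
        isDrawing: false,
        drawingState: null
      }
    default:
      // null results (no active tool) and unknown actions change nothing
      return state
  }
}

const initial = { drawnObjects: [], isDrawing: false, drawingState: null }

const started = applyToolResult(initial, {
  action: 'startDrawing',
  state: { start: { x: 0, y: 0 }, end: { x: 0, y: 0 } }
})

const finished = applyToolResult(started, {
  action: 'complete',
  object: { type: 'line', id: 1, start: { x: 0, y: 0 }, end: { x: 2, y: 0 } }
})

console.log(started.isDrawing)                            // true
console.log(finished.drawnObjects.length)                 // 1
console.log(applyToolResult(finished, null) === finished) // true
```

Because the function never mutates its input, the same logic could be handed to useReducer unchanged if the hook grows more actions later.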

4. Reusable Rendering Components

Common functionality is extracted into reusable components:

// Length display component
export function LengthDisplay({ start, end, color = "black", visible = true }) {
  if (!visible || !start || !end) return null
  
  const length = GeometryUtils.calculateDistance(start, end)
  const midpoint = GeometryUtils.calculateMidpoint(start, end)
  
  return (
    <Text
      position={[midpoint.x, midpoint.y, 0.1]}
      fontSize={0.2}
      color={color}
      anchorX="center"
      anchorY="middle"
    >
      {length.toFixed(2)}
    </Text>
  )
}

// Edge lines component
export function EdgeLines({ segments, color = "black", lineWidth = 1 }) {
  return (
    <>
      {segments.map((segment, i) => (
        <Line 
          key={i}
          points={[
            [segment.start.x, segment.start.y, 0.01], 
            [segment.end.x, segment.end.y, 0.01]
          ]}
          color={color}
          lineWidth={lineWidth}
        />
      ))}
    </>
  )
}

5. Utility Library

Common calculations are centralized in a utility library:

export const GeometryUtils = {
  calculateDistance: (p1, p2) => {
    return Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2))
  },

  calculateMidpoint: (p1, p2) => ({
    x: (p1.x + p2.x) / 2,
    y: (p1.y + p2.y) / 2,
    z: 0
  }),

  snapToGrid: (point, snapEnabled, snapValue) => {
    if (!snapEnabled) return point
    return {
      x: Math.round(point.x / snapValue) * snapValue,
      y: Math.round(point.y / snapValue) * snapValue,
      z: 0
    }
  }
}
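
A few concrete inputs make the behavior of these helpers easier to check. The block below repeats the three functions verbatim so it runs on its own; the sample points are arbitrary values chosen for illustration.

```javascript
// Standalone copies of the three helpers above, followed by concrete inputs.
const GeometryUtils = {
  calculateDistance: (p1, p2) =>
    Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2)),

  calculateMidpoint: (p1, p2) => ({
    x: (p1.x + p2.x) / 2,
    y: (p1.y + p2.y) / 2,
    z: 0
  }),

  snapToGrid: (point, snapEnabled, snapValue) => {
    if (!snapEnabled) return point
    return {
      x: Math.round(point.x / snapValue) * snapValue,
      y: Math.round(point.y / snapValue) * snapValue,
      z: 0
    }
  }
}

const a = { x: 0, y: 0 }
const b = { x: 3, y: 4 }
const raw = { x: 1.26, y: -0.74 }

const distance = GeometryUtils.calculateDistance(a, b)      // 5 (a 3-4-5 triangle)
const midpoint = GeometryUtils.calculateMidpoint(a, b)      // { x: 1.5, y: 2, z: 0 }
const snapped = GeometryUtils.snapToGrid(raw, true, 0.5)    // { x: 1.5, y: -0.5, z: 0 }
const unsnapped = GeometryUtils.snapToGrid(raw, false, 0.5) // the same object as raw

console.log(distance, midpoint, snapped, unsnapped === raw)
```

Note that the disabled branch hands back the original object untouched, so z is only added when snapping is on. Callers that depend on z should set it themselves, as DrawingCanvas does.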

Benefits of the New Architecture

1. Extensibility

Adding a new tool is now trivial:

class CircleTool extends BaseTool {
  constructor() {
    super('circle')
  }

  startDrawing(point) {
    this.isDrawing = true
    this.state = { center: point, radius: 0 }
    return { action: 'startDrawing', state: this.state }
  }

  continueDrawing(point) {
    const radius = GeometryUtils.calculateDistance(this.state.center, point)
    const object = {
      type: 'circle',
      id: GeometryUtils.generateId(),
      center: this.state.center,
      radius
    }
    
    this.reset()
    return { action: 'complete', object }
  }
}

// Register the new tool
toolManager.registerTool('circle', new CircleTool())

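The existing tool classes and the click handling stay untouched. The only other addition is a rendering component for the new circle object type.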

2. Maintainability

Each tool is self-contained with clear responsibilities:

  • Tool classes handle their own state and logic
  • Rendering components focus only on visualization
  • Hooks manage state transitions
  • Utilities provide pure functions

3. Testability

Individual components can be tested in isolation:

// Test a tool in isolation
const lineTool = new LineTool()
const result = lineTool.startDrawing({ x: 0, y: 0 })
expect(result.action).toBe('startDrawing')
expect(result.state.start).toEqual({ x: 0, y: 0 })

4. Reusability

Components can be used across different applications:

  • LengthDisplay can show measurements anywhere
  • EdgeLines can outline any shape
  • GeometryUtils provides universal geometry functions

Key Design Patterns Used

  1. Strategy Pattern: Tool classes with common interface
  2. Registry Pattern: ToolManager for tool registration/delegation
  3. Command Pattern: Tools return action objects describing what happened
  4. Composition: Complex functionality built from simple, focused components
  5. Shared Utilities: Tools call a common GeometryUtils module instead of reimplementing geometry (passing it in through the constructor would turn this into full dependency injection)

Performance Considerations

The refactored architecture adds very little runtime overhead:

  • Tool switching has minimal overhead (just swapping object references)
  • Rendering uses the same optimized React Three Fiber components
  • State updates are granular and focused
  • Memory usage is efficient with tool instance reuse

Lessons Learned

1. Start with Working Code

We refactored a working application. This allowed us to:

  • Understand the exact requirements
  • Maintain the same functionality
  • Compare performance before/after

2. Abstractions Should Solve Real Problems

Each abstraction addresses specific issues:

  • BaseTool: Eliminates tool-specific conditionals
  • ToolManager: Centralizes tool switching logic
  • Custom hooks: Separate state concerns from UI
  • Utility functions: Eliminate code duplication

3. Incremental Refactoring Works

We could refactor piece by piece:

  1. Extract utility functions first
  2. Create tool classes second
  3. Build reusable components third
  4. Integrate with hooks last

Next Steps

This architecture opens up many possibilities:

Advanced Tools

  • Text tool for adding labels
  • Dimension tool for automatic dimensioning
  • Arc tool for curved segments
  • Freehand tool for sketch-like drawing

Enhanced Features

  • Undo/Redo system with command pattern
  • Layer management for organizing drawings
  • Import/Export to common formats (SVG, DXF)
  • Keyboard shortcuts for tool switching
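
As a rough idea of how undo/redo could sit on top of the completed objects that tools already return, here is a minimal sketch. History and addObjectCommand are hypothetical names, and the approach (two stacks of commands that each know how to apply and revert themselves) is one common way to do this, not something the current code implements.

```javascript
// Hypothetical history stack: each command knows how to apply and revert itself.
class History {
  constructor() {
    this.undoStack = []
    this.redoStack = []
  }

  execute(command, objects) {
    this.undoStack.push(command)
    this.redoStack = [] // a new action invalidates the redo branch
    return command.apply(objects)
  }

  undo(objects) {
    const command = this.undoStack.pop()
    if (!command) return objects
    this.redoStack.push(command)
    return command.revert(objects)
  }

  redo(objects) {
    const command = this.redoStack.pop()
    if (!command) return objects
    this.undoStack.push(command)
    return command.apply(objects)
  }
}

// Wraps a completed object (as carried by a 'complete' result) in a command
const addObjectCommand = (object) => ({
  apply: (objects) => [...objects, object],
  revert: (objects) => objects.filter((o) => o.id !== object.id)
})

const history = new History()
let objects = []
objects = history.execute(addObjectCommand({ type: 'line', id: 1 }), objects)
objects = history.execute(addObjectCommand({ type: 'rectangle', id: 2 }), objects)
const afterAdds = objects.map((o) => o.id) // [1, 2]

objects = history.undo(objects)
const afterUndo = objects.map((o) => o.id) // [1]

objects = history.redo(objects)
const afterRedo = objects.map((o) => o.id) // [1, 2]

console.log(afterAdds, afterUndo, afterRedo)
```

Deleting an object would be a second command type with apply and revert swapped, which is why the command shape is kept separate from the history itself.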

Performance Optimizations

  • Spatial indexing for efficient object selection
  • Virtualization for large drawings
  • WebWorkers for complex calculations

Conclusion

Refactoring our drawing tools from a monolithic structure to a modular architecture demonstrates the power of good software design. The new system is:

  • Easier to extend (a new tool is one class plus one registerTool call)
  • Simpler to maintain (focused, single-responsibility components)
  • More testable (isolated, pure functions and classes)
  • Highly reusable (components work across applications)

Most importantly, it maintains the same user experience while dramatically improving the developer experience.

The key insight is that abstractions should solve real problems. Each pattern we applied addressed specific pain points in the original code, making the refactored version not just cleaner, but genuinely easier to work with.

Whether you're building drawing tools, game engines, or any complex interactive application, these patterns will help you create systems that can grow and evolve with your requirements.

Jeremy Atkinson

Jeremy is a structural engineer, researcher, and developer from BC. He works on Calcs.app and writes at Kinson.io
