みかづきブログ・カスタム

基本的にはちょちょいのほいです。

ブラウザに書いた手書きの数字の認識を目指す(OpenAI API編) ✏️

ひょんなことから、ブラウザに書いた手書きの数字の認識を目指すことにしました。
今回は、OpenAI APIを試してみます。

前回同様、手書きで数字を書く仕組みは、かつて作った、Canvasに線を引けるWebサイトを流用します。

blog.kimizuka.org

また、OpenAIのAPIに画像の内容をテキスト化してもらうプロトタイプも、かつて作っているので、今回はこれらをうまいこと組み合わせればすぐに完成することでしょう。

blog.kimizuka.org

今回はNext.jsのPages Routerを使って実装しました。

ソースコード(抜粋)

src/pages/api/openai.ts

import OpenAI from 'openai';
import type { NextApiRequest, NextApiResponse } from 'next';

/**
 * JSON body returned by the /api/openai route.
 * `state` is 'ok' on success; `data` carries the route's payload (the first
 * chat-completion choice on success). `unknown` rather than `any` so server
 * code must narrow before reading it.
 */
type Data = {
  state: string;
  data?: unknown;
}

/**
 * POST /api/openai
 *
 * Body: `{ base64: string }` — an image URL for the vision model (the client
 * sends the canvas as a `data:image/png;base64,...` data URL).
 *
 * Responds with `{ state: 'ok', data: <first chat-completion choice> }`.
 * Every other path also sends a response so the request never hangs:
 * non-POST -> 405, missing/empty `base64` -> 400, OpenAI failure -> 500.
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse<Data>
) {
  if (req.method !== 'POST') {
    res.setHeader('Allow', 'POST');
    res.status(405).json({ state: 'error', data: 'Method Not Allowed' });
    return;
  }

  const base64: unknown = req.body?.base64;

  // Without this check a missing field would be sent to OpenAI as the
  // literal string "undefined".
  if (typeof base64 !== 'string' || base64 === '') {
    res.status(400).json({ state: 'error', data: 'base64 is required' });
    return;
  }

  try {
    // Constructed inside try: the client throws when no API key is configured.
    const openai = new OpenAI({
      apiKey: process.env.OPENAI_API_KEY
    });

    const chatCompletion = await openai.chat.completions.create({
      model: 'gpt-4-turbo',
      messages: [{
        role: 'user',
        content: [{
          type: 'text',
          text: '画像には手書きで数字が書いてあります。書いてある数字を数字のみで回答してください。数字以外の文字が書かれていた際や、読み取れない場合は?のみを返してください。'
        },{
          type: 'image_url',
          image_url: {
            url: base64
          }
        }]
      }],
      max_tokens: 300
    });

    res.status(200).json({ state: 'ok', data: chatCompletion.choices[0] });
  } catch (err: unknown) {
    console.error(err);
    res.status(500).json({ state: 'error', data: 'OpenAI request failed' });
  }
}

以前作った、画像の内容をテキスト化するプロトタイプからの差分は、

  1. モデルをgpt-4-vision-previewからgpt-4-turboに変更
  2. プロンプトの変更
  3. image_urlの渡し方を変更

の3点です。

src/components/templates/IndexPageTemplate.tsx

import axios from 'axios';
import styled from 'styled-components';
import { useEffect, useRef, useState } from 'react';
import { setup } from '@/scripts/canvas.js';

/**
 * Index page: a canvas to draw a digit on, a "clear" button, and a "check"
 * button that posts the drawing to /api/openai and shows the answer text.
 * While a request is in flight the Wrapper is flagged with
 * `data-is-loading` to show the spinner overlay and block input.
 */
export function IndexPageTemplate() {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const [ text, setText ] = useState<string>('');
  const [ isLoading, setIsLoading ] = useState<boolean>(false);
  const [ events, setEvents ] = useState({ clear: () => {} });

  useEffect(() => {
    const canvas = canvasRef.current;

    // The ref should be populated after mount; bail out rather than let
    // setup() crash on a null canvas.
    if (!canvas) {
      return;
    }

    // setup() wires the drawing handlers and returns a clear() function.
    const { clear } = setup(canvas);

    setEvents({
      clear
    });
  }, []);

  async function handleClickBtnSubmit() {
    const base64 = canvasRef.current?.toDataURL('image/png');

    if (!base64) {
      return;
    }

    setIsLoading(true);

    try {
      const { data } = await axios.post('/api/openai', { base64 });
      const content = data?.data?.message?.content;

      setText(typeof content === 'string' ? content : '');
    } catch (err: unknown) {
      console.error(err);
      setText('error');
    } finally {
      // Always release the loading state: otherwise a failed request leaves
      // the overlay up and the page stuck with pointer-events disabled.
      setIsLoading(false);
    }
  }

  return (
    <Wrapper data-is-loading={ isLoading }>
      <canvas ref={ canvasRef } />
      <button onClick={ events.clear }>clear</button>
      <button onClick={ handleClickBtnSubmit }>check</button>
      <p>{ text }</p>
      <div className="overlay">
        <svg width="25" height="24" viewBox="0 0 25 24" fill="none" xmlns="http://www.w3.org/2000/svg">
          <circle opacity="0.5" cx="12.5" cy="12" r="10" stroke="white" strokeWidth="4"/>
          <path d="M22.5 12C22.5 6.47715 18.0228 2 12.5 2" stroke="white" strokeWidth="4" strokeLinecap="round">
            <animateTransform
              attributeName="transform"
              type="rotate"
              repeatCount="indefinite"
              dur=".8s"
              from="0 12 12"
              to="360 12 12"
            />
          </path>
        </svg>
      </div>
    </Wrapper>
  );
}

/**
 * Page layout for IndexPageTemplate.
 *
 * - The canvas gets a 1px black border and shrinks to fit narrow screens
 *   (max-width: 100%, height: auto).
 * - `.overlay` is a full-viewport dimmed layer holding the spinner SVG; it is
 *   invisible (opacity 0) and click-through by default.
 * - When `data-is-loading="true"` is set on the wrapper, the overlay becomes
 *   visible and the whole wrapper stops receiving pointer events, so no
 *   button can be pressed while a request is in flight.
 */
const Wrapper = styled.div`
  padding: 16px;

  canvas {
    display: block;
    border: 1px solid #000;
    max-width: 100%;
    height: auto;
  }

  button {
    + button {
      margin-left: 4px;
    }
  }
  .overlay {
    position: fixed;
    top: 0; bottom: 0;
    left: 0; right: 0;
    background: rgba(0, 0, 0, 0.4);
    opacity: 0;
    pointer-events: none;
  }

  svg {
    position: fixed;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
  }

  &[data-is-loading='true'] {
    pointer-events: none;

    .overlay {
      opacity: 1;
    }
  }
`;

レガシーな記法で書かれた手書き描画の実装(canvas.js)を強引に組み込みました。

src/scripts/canvas.js

// NameSpace: every module below hangs its constructor/helpers off this object.
const ns = {};

// Util
(() => {
  /**
   * Create a stateful moving-average function.
   *
   * The returned getAve(val) stores `val` in a fixed-size ring buffer (the
   * oldest sample is overwritten once the buffer is full) and returns the
   * arithmetic mean of all samples currently stored.
   *
   * @param opt_maxLength  ring-buffer size; falsy values fall back to 10
   * @returns getAve(val) -> mean of the stored samples
   */
  function buildGetAve(opt_maxLength) {
    const windowSize = opt_maxLength || 10;
    const samples = [];
    let cursor = 0;

    return function getAve(val) {
      samples[cursor] = val;
      cursor = (cursor + 1) % windowSize;

      const total = samples.reduce((acc, sample) => acc + sample, 0);

      return total / samples.length;
    };
  }

  ns.Util = { buildGetAve };
})();

// EventDispatcher
(() => {
  /**
   * Minimal publish/subscribe helper. Point, Line and LineManager use an
   * instance as their prototype and call this constructor on themselves so
   * each object gets its own `_events` map (event name -> listener array).
   */
  function EventDispatcher() {
    this._events = {};
  }

  /**
   * Whether a listener array exists for `eventName`.
   * NOTE: stays true after all listeners were removed (the array is kept).
   */
  EventDispatcher.prototype.hasEventListener = function(eventName) {
    return !!this._events[eventName];
  };

  /**
   * Register `callback` for `eventName`. Registering the same callback twice
   * is a no-op. Always returns `this` so calls can be chained.
   */
  EventDispatcher.prototype.addEventListener = function(eventName, callback) {
    if (!this.hasEventListener(eventName)) {
      this._events[eventName] = [callback];
    } else if (!this._events[eventName].includes(callback)) {
      this._events[eventName].push(callback);
    }

    return this;
  };

  /**
   * Unregister `callback` from `eventName` if present. Always returns `this`.
   */
  EventDispatcher.prototype.removeEventListener = function(eventName, callback) {
    if (this.hasEventListener(eventName)) {
      const events = this._events[eventName];
      const index = events.indexOf(callback);

      if (index !== -1) {
        events.splice(index, 1);
      }
    }

    return this;
  };

  /**
   * Call every listener of `eventName`.
   *
   * @param eventName  event to dispatch
   * @param opt_this   `this` for the listeners (falls back to the dispatcher)
   * @param opt_arg    first of any extra arguments; everything after
   *                   `opt_this` is forwarded to each listener
   * @returns this
   */
  EventDispatcher.prototype.fireEvent = function(eventName, opt_this, opt_arg) {
    if (!this.hasEventListener(eventName)) {
      return this;
    }

    // Snapshot so listeners that add/remove listeners don't disturb this dispatch.
    const listeners = [ ...this._events[eventName] ];
    // Drop eventName and opt_this; the rest are the listener arguments.
    const args = [ ...arguments ].slice(2);
    const context = opt_this || this;

    for (const listener of listeners) {
      listener.apply(context, args);
    }

    return this;
  };

  ns.EventDispatcher = EventDispatcher;
})();

// Throttle
(() => {
  /**
   * Rate-limits calls to a single callback.
   *
   * fireEvent() runs the callback immediately when more than `interval` ms
   * have passed since the last run; otherwise it (re)schedules one trailing
   * run after interval / 10 ms, cancelling any earlier pending run.
   *
   * @param opt_interval  minimum spacing in ms (falsy -> 500)
   * @param opt_callback  function invoked as callback.call(throttle, arg)
   */
  function Throttle(opt_interval, opt_callback) {
    this._timer = null;
    this._lastEventTime = 0;
    this._interval = opt_interval || 500;
    this._callback = opt_callback || function() {};
  }

  Throttle.prototype.setInterval = function(ms) {
    this._interval = ms;
  };

  /** Replace the callback (only one callback is kept). */
  Throttle.prototype.addEvent = function(fn) {
    this._callback = fn;
  };

  /**
   * Offer an event to the throttle.
   *
   * @param opt_arg  passed to the callback (falsy values become null)
   */
  Throttle.prototype.fireEvent = function(opt_arg) {
    const _this = this;
    const currentTime = Date.now();
    const timerInterval = this._interval / 10;

    // This event supersedes any trailing call that is still pending.
    clearTimeout(this._timer);
    this._timer = null;

    if (currentTime - this._lastEventTime > this._interval) {
      _fire();
    } else {
      this._timer = setTimeout(_fire, timerInterval);
    }

    function _fire() {
      _this._timer = null;
      _this._callback.call(_this, opt_arg || null);
      // NOTE: records when fireEvent() was called, not when a deferred run executed.
      _this._lastEventTime = currentTime;
    }
  };

  ns.Throttle = Throttle;
})();

// Point
(() => {
  /**
   * A 2D point on the canvas carrying a brush size.
   *
   * @param x         x coordinate in canvas px
   * @param y         y coordinate in canvas px
   * @param opt_size  brush size; falsy values fall back to 1
   */
  function Point(x, y, opt_size) {
    // Give this instance its own listener map (see EventDispatcher).
    ns.EventDispatcher.call(this);

    this.x = x;
    this.y = y;
    this.size = opt_size || 1;
  }

  Point.prototype = new ns.EventDispatcher();
  Point.prototype.constructor = Point;

  Point.prototype.setSize = function(size) {
    this.size = size;
  };

  /**
   * Fill a circle of radius `opt_size` (else this.size, else 1) at the point.
   * Returns this for chaining.
   */
  Point.prototype.draw = function(ctx, opt_size) {
    const radius = opt_size || this.size || 1;

    ctx.save();
    ctx.beginPath();
    ctx.arc(this.x, this.y, radius, 0, Math.PI * 2, false);
    ctx.fill();
    ctx.moveTo(this.x, this.y);
    ctx.restore();

    return this;
  };

  /** Euclidean distance between two {x, y} points. */
  Point.getDistance = function(pointA, pointB) {
    const dx = pointB.x - pointA.x;
    const dy = pointB.y - pointA.y;

    return Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2));
  };

  ns.Point = Point;
})();

// Line
(() => {
  const Point = ns.Point;

  /**
   * One stroke: an ordered list of Points.
   *
   * @param opt_point  optional first point of the stroke
   */
  function Line(opt_point) {
    // Give this instance its own listener map (see EventDispatcher).
    ns.EventDispatcher.call(this);
    this.pointList = [];

    if (opt_point) {
      this.push(opt_point);
    }
  }

  Line.prototype = new ns.EventDispatcher();
  Line.prototype.constructor = Line;

  Line.prototype.push = function(pointModel) {
    this.pointList.push(pointModel);
  };

  /**
   * Draw the stroke as straight segments, each with the width of its end
   * point; a single-point stroke is drawn as a dot. No-op when empty.
   */
  Line.prototype.drawLine = function(ctx) {
    const pointList = this.pointList;
    const length = pointList.length;

    // A Line may be constructed without a point; there is nothing to draw.
    if (length === 0) {
      return;
    }

    ctx.save();
      ctx.lineCap  = 'round';
      ctx.lineJoin = 'round';

      if (length > 1) {
        for (let i = 1; i < length; ++i) {
          ctx.beginPath();
            ctx.moveTo(pointList[i - 1].x, pointList[i - 1].y);
            ctx.lineTo(pointList[i].x, pointList[i].y);
            ctx.lineWidth = pointList[i].size;
          ctx.stroke();
        }
      } else {
        pointList[0].draw(ctx);
      }

    ctx.restore();
  };

  /**
   * Draw the stroke as smoothed quadratic curves. The helper list
   * interleaves raw points with the midpoints between consecutive raw points
   * ([p0, mid01, p1, mid12, p2, ...]). Each segment curves from one midpoint
   * to the next, using the raw point between them as the control point, and
   * is stroked with that raw point's width. A single-point stroke is drawn as
   * a dot. No-op when empty.
   */
  Line.prototype.drawQuadraticCurve = function(ctx) {
    const pointList = this.pointList;
    const length = pointList.length;
    const quadraticPointList = [];
    let lastIndex = 0;

    // A Line may be constructed without a point; there is nothing to draw.
    if (length === 0) {
      return;
    }

    ctx.save();
      ctx.lineCap  = 'round';
      ctx.lineJoin = 'round';

      if (length > 1) {
        quadraticPointList[lastIndex] = pointList[0];

        ctx.beginPath();
          ctx.moveTo(quadraticPointList[0].x, quadraticPointList[0].y);

          for (let i = 1; i < length; ++i) {
            // The left-hand index is evaluated first, so `lastIndex - 1` on
            // the right refers to the previous raw point.
            quadraticPointList[++lastIndex] = new Point(
              (quadraticPointList[lastIndex - 1].x + pointList[i].x) / 2,
              (quadraticPointList[lastIndex - 1].y + pointList[i].y) / 2
            );
            quadraticPointList[++lastIndex] = (pointList[i]);
            ctx.quadraticCurveTo(
              quadraticPointList[i * 2 - 2].x,
              quadraticPointList[i * 2 - 2].y,
              quadraticPointList[i * 2 - 1].x,
              quadraticPointList[i * 2 - 1].y
            );
            ctx.lineWidth = pointList[i].size;
            ctx.stroke();
            ctx.beginPath();
            ctx.moveTo(quadraticPointList[i * 2 - 1].x, quadraticPointList[i * 2 - 1].y);
          }

          // Close the gap between the last midpoint and the final raw point.
          ctx.lineTo(quadraticPointList[lastIndex].x, quadraticPointList[lastIndex].y);
        ctx.stroke();
      } else {
        pointList[lastIndex].draw(ctx);
      }
    ctx.restore();
  };

  ns.Line = Line;
})();

// LineManager
(() => {
  let instance;

  /** Lazily create and return the single shared LineManager. */
  function getInstance() {
    if (!instance) {
      instance = new LineManager();
    }

    return instance;
  }

  /** Holds every stroke (Line) drawn so far, in drawing order. */
  function LineManager() {
    // Give this instance its own listener map (see EventDispatcher).
    ns.EventDispatcher.call(this);
    this.lineList = [];
  }

  LineManager.prototype = new ns.EventDispatcher();
  LineManager.prototype.constructor = LineManager;

  /** Start a new stroke. */
  LineManager.prototype.push = function(lineModel) {
    this.lineList.push(lineModel);
  };

  /**
   * Append a point to the stroke at `opt_index` (default: the newest stroke).
   * Does nothing when there is no such stroke, e.g. when a late move event
   * arrives after clear() has emptied the list.
   */
  LineManager.prototype.addPoint = function(pointModel, opt_index) {
    // typeof check rather than `||`, so an explicit index 0 is honored.
    const index = typeof opt_index === 'number'
      ? opt_index
      : this.lineList.length - 1;
    const line = this.lineList[index];

    if (!line) {
      return;
    }

    line.push(pointModel);
  };

  /**
   * Draw only the newest stroke; older strokes are expected to be kept
   * elsewhere (setup() snapshots them onto an off-screen canvas).
   */
  LineManager.prototype.drawQuadraticCurve = function(ctx) {
    const lineList = this.lineList;
    const length = lineList.length;

    if (length) {
      lineList[length - 1].drawQuadraticCurve(ctx);
    }
  };

  /** Forget every stroke. */
  LineManager.prototype.clear = function() {
    this.lineList = [];
  };

  ns.LineManager = {
    getInstance: getInstance
  };
})();

/**
 * Turn `canvas` into a 200x200 freehand drawing surface.
 *
 * Mouse and touch strokes are recorded by the shared LineManager. An
 * off-screen canvas keeps a snapshot of everything drawn up to the end of the
 * last stroke. Every animation frame, the visible canvas is reset, the
 * snapshot is blitted onto it, and the in-progress stroke is drawn on top.
 *
 * @param canvas  the HTMLCanvasElement to draw on
 * @returns {{ clear: () => void }}  clear() erases all strokes
 */
export function setup(canvas) { // ❶ ❷
  const doc = document;
  const SIZE = 10;
  const _getAve = ns.Util.buildGetAve(10);

  /**
   * Map the distance `l` between two consecutive pointer samples to a stroke
   * width. Short hops (slow strokes) approach SIZE px and longer hops get
   * thinner. The result is smoothed with a 10-sample moving average.
   */
  function _getSize(l) {
    const MAX_DISTANCE = 100;
    // Beyond MAX_DISTANCE the linear formula goes negative, and canvas ignores
    // non-positive lineWidth values, so use a fixed 1px width there instead.
    const width = l > MAX_DISTANCE
      ? 1
      : (MAX_DISTANCE - l) * SIZE / MAX_DISTANCE;

    return _getAve(width);
  }

  const lineManager = ns.LineManager.getInstance();
  const ctx = canvas.getContext('2d');
  // Off-screen snapshot of all strokes finished so far.
  const sub = doc.createElement('canvas');
  const subCtx = sub.getContext('2d');
  const START_EVENT = 'mousedown';
  const MOVE_EVENT = 'mousemove';
  const END_EVENT = 'mouseup';
  const WIDTH = 200;
  const HEIGHT = 200;

  _setup();

  function _setup() {
    addSketchEventForCanvas(canvas);

    canvas.width = sub.width = WIDTH;
    canvas.height = sub.height = HEIGHT;

    tick();
  }

  /** Redraw on every animation frame. */
  function tick() {
    draw();
    requestAnimationFrame(tick);
  }

  function draw() {
    // Re-assigning the size clears the canvas before each frame.
    canvas.width = WIDTH;
    canvas.height = HEIGHT;

    ctx.drawImage(sub, 0, 0);

    ctx.save();
      ctx.fillStyle = ctx.strokeStyle = 'rgba(50, 50, 50, 1)';
      lineManager.drawQuadraticCurve(ctx);
    ctx.restore();
  }

  /** Attach mouse and touch handlers that record strokes drawn on `elm`. */
  function addSketchEventForCanvas(elm) {
    const throttle = new ns.Throttle(10, handleThrottleMoveEvent);
    let lastPoint = null;

    elm.addEventListener(START_EVENT, handleStartEvent, false);
    doc.addEventListener(END_EVENT, handleEndEvent, false);

    elm.addEventListener('touchstart', handleStartEvent, false);
    elm.addEventListener('touchmove', handleMoveEvent, false);
    doc.addEventListener('touchend', handleEndEvent, false);

    function handleThrottleMoveEvent(evt) {
      // A deferred (trailing) throttled call can run after the stroke has
      // ended and lastPoint was reset; there is no stroke to extend then.
      if (!lastPoint) {
        return;
      }

      const { offsetX, offsetY } = getOffset(evt);
      const currentPoint = new ns.Point(offsetX, offsetY);

      // Faster movement (larger gap to the previous sample) -> thinner line.
      currentPoint.setSize(_getSize(ns.Point.getDistance(lastPoint, currentPoint)));
      lastPoint = currentPoint;

      lineManager.addPoint(currentPoint);
    }

    function handleStartEvent(evt) {
      const { offsetX, offsetY } = getOffset(evt);

      lastPoint = new ns.Point(offsetX, offsetY, SIZE / 2);

      lineManager.push(new ns.Line(lastPoint));
      elm.addEventListener(MOVE_EVENT, handleMoveEvent, false);
    }

    function handleMoveEvent(evt) {
      throttle.fireEvent(evt);
    }

    function handleEndEvent(evt) {
      // Snapshot the visible canvas (old strokes + the one just finished).
      sub.width = WIDTH;
      sub.height = HEIGHT;
      subCtx.drawImage(canvas, 0, 0);
      elm.removeEventListener(MOVE_EVENT, handleMoveEvent, false);
      lastPoint = null;
    }

    /** Pointer position relative to the canvas's padding box, in CSS px. */
    function getOffset(evt) {
      if (evt.touches) {
        // clientX/Y and getBoundingClientRect() share the viewport origin, so
        // their difference is the position inside the element. Subtracting
        // clientLeft/clientTop (the border) matches MouseEvent.offsetX/Y.
        const rect = elm.getBoundingClientRect();
        const touch = evt.touches[0];

        return {
          offsetX: touch.clientX - rect.left - elm.clientLeft,
          offsetY: touch.clientY - rect.top - elm.clientTop
        };
      }

      return {
        offsetX: evt.offsetX,
        offsetY: evt.offsetY
      };
    }
  }

  return { // ❸
    /** Erase every stroke and the off-screen snapshot. */
    clear: function() {
      lineManager.clear();
      canvas.width = sub.width = WIDTH;
      canvas.height = sub.height = HEIGHT;
    }
  };
}

かつて書いたJavaScriptをTypeScriptに変換するのが大変すぎたため、最低限の改修をしてJavaScriptのまま読み込みました。
行った改修は、

  1. setupのみをexportする
  2. setup実行時に引数にCanvasを渡せるようにする
  3. setupを実行した際に、clearを持ったオブジェクトを戻す

です。

DEMO

OpenAIの料金が怖いので、今回は動作の様子だけ貼っておきます。

tesseract.jsよりも格段に精度は上がりましたが、そのトレードオフとしてスピードを失いました。
通信を挟んでいるからスピード面はだいぶ不利ですね。

リポジトリ

github.com

雑感

そもそも数字の認識のためにOpenAI APIを使うのはかなりオーバースペックなので、もうちょっと身の丈にあった手段を探したいところです。
手書きの数字を機械学習してローカルで動かすのがベストな手段な気がしていますが、そもそも車輪の再発明になる気がするので、もうちょっと探してみようと思います。