ひょんなことから、ブラウザに書いた手書きの数字の認識を目指すことにしました。
今回は、OpenAI APIを試してみます。
前回同様、手書きで数字を書く仕組みは、かつて作った、Canvasに線を引けるWebサイトを流用します。
また、OpenAIのAPIに画像の内容をテキスト化してもらうプロトタイプも、かつて作っているので、今回はこれらをうまいこと組み合わせればすぐに完成することでしょう。
今回はNext.jsのPages Routerを使って実装しました。
ソースコード(抜粋)
src/pages/api/openai.ts
import OpenAI from 'openai';
import type { NextApiRequest, NextApiResponse } from 'next';

/** JSON envelope returned by this API route. */
type Data = {
  /** 'ok' on success, 'error' otherwise. */
  state: string;
  /** On success, the first completion choice returned by OpenAI. */
  data?: unknown;
}

/**
 * POST /api/openai
 *
 * Forwards the image found in `req.body.base64` to the OpenAI chat
 * completions API together with a prompt asking for the handwritten number,
 * and responds with the first choice of the completion.
 *
 * Responses:
 * - 200 `{ state: 'ok', data }` when the completion succeeds
 * - 400 when `base64` is missing or not a non-empty string
 * - 405 for any method other than POST
 * - 500 when the OpenAI request fails
 */
export default async function handler(
  req: NextApiRequest,
  res: NextApiResponse<Data>
) {
  // Always answer: without this branch a non-POST request never receives a response.
  if (req.method !== 'POST') {
    res.setHeader('Allow', 'POST');
    res.status(405).json({ state: 'error' });
    return;
  }

  // The body comes from the client, so check the shape before using it.
  const base64: unknown = req.body?.base64;
  if (typeof base64 !== 'string' || base64.length === 0) {
    res.status(400).json({ state: 'error' });
    return;
  }

  try {
    const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
    const chatCompletion = await openai.chat.completions.create({
      model: 'gpt-4-turbo',
      messages: [{
        role: 'user',
        content: [{
          type: 'text',
          text: '画像には手書きで数字が書いてあります。書いてある数字を数字のみで回答してください。数字以外の文字が書かれていた際や、読み取れない場合は?のみを返してください。'
        }, {
          type: 'image_url',
          image_url: { url: base64 }
        }]
      }],
      max_tokens: 300
    });

    res.status(200).json({ state: 'ok', data: chatCompletion.choices[0] });
  } catch (error: unknown) {
    // Report upstream failures as JSON instead of leaving an unhandled rejection.
    console.error(error);
    res.status(500).json({ state: 'error' });
  }
}
前回からの差分は、
- モデルをgpt-4-vision-previewからgpt-4-turboに変更
- プロンプトの変更
- image_urlの渡し方を変更
の3点です。
src/components/templates/IndexPageTemplate.tsx
import axios from 'axios';
import styled from 'styled-components';
import { useEffect, useRef, useState } from 'react';
import { setup } from '@/scripts/canvas.js';

/**
 * Page template with a drawable canvas, a "clear" button, a "check" button
 * that posts the canvas image to `/api/openai`, and a paragraph showing the
 * text returned by that endpoint. A full-screen overlay with a spinner is
 * shown (and pointer events are disabled) while the request is in flight.
 */
export function IndexPageTemplate() {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const [ text, setText ] = useState<string>('');
  const [ isLoading, setIsLoading ] = useState<boolean>(false);
  // Holds the `clear` callback returned by canvas.js once the canvas is mounted.
  const [ events, setEvents ] = useState({ clear: () => {} });

  useEffect(() => {
    const canvas = canvasRef.current;

    // Do not hand a null ref to setup(); it calls canvas.getContext immediately.
    if (!canvas) {
      return;
    }

    const { clear } = setup(canvas);

    setEvents({ clear });
  }, []);

  /** Sends the current canvas content as a PNG data URL and displays the answer. */
  async function handleClickBtnSubmit() {
    const base64 = canvasRef.current?.toDataURL('image/png');

    if (!base64) {
      return;
    }

    setIsLoading(true);

    try {
      const { data } = await axios.post('/api/openai', { base64 });

      setText(data?.data?.message?.content ?? '');
    } catch (error: unknown) {
      console.error(error);
      setText('');
    } finally {
      // Must run on failure too: while isLoading is true the wrapper ignores pointer events.
      setIsLoading(false);
    }
  }

  return (
    <Wrapper data-is-loading={ isLoading }>
      <canvas ref={ canvasRef } />
      <button onClick={ events.clear }>clear</button>
      <button onClick={ handleClickBtnSubmit }>check</button>
      <p>{ text }</p>
      <div className="overlay">
        <svg width="25" height="24" viewBox="0 0 25 24" fill="none" xmlns="http://www.w3.org/2000/svg">
          <circle opacity="0.5" cx="12.5" cy="12" r="10" stroke="white" strokeWidth="4"/>
          <path d="M22.5 12C22.5 6.47715 18.0228 2 12.5 2" stroke="white" strokeWidth="4" strokeLinecap="round">
            <animateTransform attributeName="transform" type="rotate" repeatCount="indefinite" dur=".8s" from="0 12 12" to="360 12 12" />
          </path>
        </svg>
      </div>
    </Wrapper>
  );
}

// Layout plus the loading overlay; `data-is-loading='true'` reveals the overlay and blocks input.
const Wrapper = styled.div` padding: 16px; canvas { display: block; border: 1px solid #000; max-width: 100%; height: auto; } button { + button { margin-left: 4px; } } .overlay { position: fixed; top: 0; bottom: 0; left: 0; right: 0; background: rgba(0, 0, 0, 0.4); opacity: 0; pointer-events: none; } svg { position: fixed; top: 50%; left: 50%; transform: translate(-50%, -50%); } &[data-is-loading='true'] { pointer-events: none; .overlay { opacity: 1; } } `;
レガシーな記法で書かれた手書き入力の実装(canvas.js)を強引に組み込みました。
src/scripts/canvas.js
// NameSpace
// Module-private container for the helper constructors below.
const ns = {};

// Util
(() => {
  ns.Util = ns.Util || {};

  /**
   * Builds a moving-average function over the most recent values.
   * @param {number} [opt_maxLength] Window size; defaults to 10.
   * @returns {function(number): number} Stores the value in the window and
   *     returns the average of everything currently stored.
   */
  function buildGetAve(opt_maxLength) {
    const array = [];
    const maxLength = opt_maxLength || 10;
    let index = 0;

    function _ave() {
      const length = array.length;
      let sum = 0;

      for (let i = 0; i < length; ++i) {
        sum += array[i];
      }

      return sum / length;
    }

    function getAve(val) {
      // Ring buffer: overwrite the oldest slot once the window is full.
      array[index] = val;
      index = (index + 1) % maxLength;

      return _ave();
    }

    return getAve;
  }

  ns.Util = { buildGetAve };
})();

// EventDispatcher
(() => {
  /** Minimal event emitter; listeners are stored per event name. */
  function EventDispatcher() {
    this._events = {};
  }

  /** @returns {boolean} Whether a listener list exists for `eventName`. */
  EventDispatcher.prototype.hasEventListener = function(eventName) {
    return !!this._events[eventName];
  };

  /** Registers `callback` for `eventName`; registering the same callback twice is a no-op. */
  EventDispatcher.prototype.addEventListener = function(eventName, callback) {
    if (this.hasEventListener(eventName)) {
      const events = this._events[eventName];
      const length = events.length;

      for (let i = 0; i < length; i++) {
        if (events[i] === callback) {
          return;
        }
      }

      events.push(callback);
    } else {
      this._events[eventName] = [callback];
    }

    return this;
  };

  /** Unregisters `callback` from `eventName` if it is registered. */
  EventDispatcher.prototype.removeEventListener = function(eventName, callback) {
    if (!this.hasEventListener(eventName)) {
      return;
    } else {
      const events = this._events[eventName];
      let i = events.length;
      let index;

      while (i--) {
        if (events[i] === callback) {
          index = i;
        }
      }

      if (typeof index === 'number') {
        events.splice(index, 1);
      }
    }

    return this;
  };

  /**
   * Calls every listener of `eventName`.
   * @param {string} eventName
   * @param {Object} [opt_this] `this` for the listeners; defaults to the dispatcher.
   * @param {...*} [opt_arg] Arguments forwarded to the listeners.
   */
  EventDispatcher.prototype.fireEvent = function(eventName, opt_this, opt_arg) {
    if (!this.hasEventListener(eventName)) {
      return;
    } else {
      const events = this._events[eventName];
      // Iterate over a copy so listeners may add/remove listeners while firing.
      const copyEvents = [ ...events ];
      const arg = [ ...arguments ];
      const length = events.length;

      // Drop eventName and opt_this; the rest is forwarded to the listeners.
      arg.splice(0, 2);

      for (let i = 0; i < length; i++) {
        copyEvents[i].apply(opt_this || this, arg);
      }
    }
  };

  ns.EventDispatcher = EventDispatcher;
})();

// Throttle
(() => {
  /**
   * Rate limiter for a single callback.
   * @param {number} [opt_interval] Minimum milliseconds between immediate calls; defaults to 500.
   * @param {Function} [opt_callback] Callback to invoke; defaults to a no-op.
   */
  function Throttle(opt_interval, opt_callback) {
    this._timer = null;
    this._lastEventTime = 0;
    this._interval = opt_interval || 500;
    this._callback = opt_callback || function() {};
  }

  Throttle.prototype.setInterval = function(ms) {
    this._interval = ms;
  };

  Throttle.prototype.addEvent = function(fn) {
    this._callback = fn;
  };

  /**
   * Invokes the callback immediately when the interval has elapsed since the
   * last invocation; otherwise schedules one trailing invocation, replacing
   * any previously scheduled one.
   * @param {*} [opt_arg] Passed to the callback (`null` when falsy).
   */
  Throttle.prototype.fireEvent = function(opt_arg) {
    const _this = this;
    const currentTime = Date.now();
    const timerInterval = this._interval / 10;

    // Use the same `_timer` field the constructor initialises.
    clearTimeout(_this._timer);

    if (currentTime - _this._lastEventTime > _this._interval) {
      _fire();
    } else {
      _this._timer = setTimeout(_fire, timerInterval);
    }

    function _fire() {
      _this._callback.call(_this, opt_arg || null);
      _this._lastEventTime = currentTime;
    }
  };

  ns.Throttle = Throttle;
})();

// Point
(() => {
  /**
   * A point of a stroke.
   * @param {number} x
   * @param {number} y
   * @param {number} [opt_size] Pen size at this point; defaults to 1.
   */
  function Point(x, y, opt_size) {
    const _this = this;

    _init();

    function _init() {
      ns.EventDispatcher.call(_this);
    }

    _this.x = x;
    _this.y = y;
    _this.size = opt_size || 1;
  }

  Point.prototype = new ns.EventDispatcher();
  Point.prototype.constructor = Point;

  Point.prototype.setSize = function(size) {
    this.size = size;
  };

  /** Draws the point as a filled circle of radius `opt_size` (or its own size). */
  Point.prototype.draw = function(ctx, opt_size) {
    const size = opt_size || this.size || 1;

    ctx.save();
    ctx.beginPath();
    ctx.arc(this.x, this.y, size, 0, Math.PI * 2, false);
    ctx.fill();
    ctx.moveTo(this.x, this.y);
    ctx.restore();

    return this;
  };

  /** @returns {number} Euclidean distance between the two points. */
  Point.getDistance = function(pointA, pointB) {
    return Math.sqrt(Math.pow(pointB.x - pointA.x, 2) + Math.pow(pointB.y - pointA.y, 2));
  };

  ns.Point = Point;
})();

// Line
(() => {
  let Point = ns.Point;

  /**
   * One stroke: an ordered list of points.
   * @param {Point} [opt_point] Optional first point.
   */
  function Line(opt_point) {
    const _this = this;
    const pointList = [];

    _init();

    function _init() {
      ns.EventDispatcher.call(_this);
      _this.pointList = pointList;

      if (opt_point) {
        _this.push(opt_point);
      }
    }
  }

  Line.prototype = new ns.EventDispatcher();
  Line.prototype.constructor = Line;

  Line.prototype.push = function(pointModel) {
    this.pointList.push(pointModel);
  };

  /** Draws the stroke as straight segments; a single point is drawn as a dot. */
  Line.prototype.drawLine = function(ctx) {
    const pointList = this.pointList;
    const length = pointList.length;

    ctx.save();
    ctx.lineCap = 'round';
    ctx.lineJoin = 'round';

    if (length > 1) {
      for (let i = 1; i < length; ++i) {
        ctx.beginPath();
        ctx.moveTo(pointList[i - 1].x, pointList[i - 1].y);
        ctx.lineTo(pointList[i].x, pointList[i].y);
        ctx.lineWidth = pointList[i].size;
        ctx.stroke();
      }
    } else if (length === 1) {
      // An empty line has nothing to draw; only a lone point becomes a dot.
      pointList[0].draw(ctx);
    }

    ctx.restore();
  };

  /**
   * Draws the stroke with quadratic curves, using each recorded point as a
   * curve end and the midpoint to the previous end as its control point.
   * A single point is drawn as a dot.
   */
  Line.prototype.drawQuadraticCurve = function(ctx) {
    const pointList = this.pointList;
    const length = pointList.length;
    const quadraticPointList = [];
    let lastIndex = 0;

    ctx.save();
    ctx.lineCap = 'round';
    ctx.lineJoin = 'round';

    if (length > 1) {
      quadraticPointList[lastIndex] = pointList[0];
      ctx.beginPath();
      ctx.moveTo(quadraticPointList[0].x, quadraticPointList[0].y);

      for (let i = 1; i < length; ++i) {
        quadraticPointList[++lastIndex] = new Point(
          (quadraticPointList[lastIndex - 1].x + pointList[i].x) / 2,
          (quadraticPointList[lastIndex - 1].y + pointList[i].y) / 2
        );
        quadraticPointList[++lastIndex] = (pointList[i]);

        ctx.quadraticCurveTo(
          quadraticPointList[i * 2 - 2].x,
          quadraticPointList[i * 2 - 2].y,
          quadraticPointList[i * 2 - 1].x,
          quadraticPointList[i * 2 - 1].y
        );
        // Each segment is stroked separately so its width can follow the point size.
        ctx.lineWidth = pointList[i].size;
        ctx.stroke();
        ctx.beginPath();
        ctx.moveTo(quadraticPointList[i * 2 - 1].x, quadraticPointList[i * 2 - 1].y);
      }

      ctx.lineTo(quadraticPointList[lastIndex].x, quadraticPointList[lastIndex].y);
      ctx.stroke();
    } else if (length === 1) {
      // An empty line has nothing to draw; only a lone point becomes a dot.
      pointList[0].draw(ctx);
    }

    ctx.restore();
  };

  ns.Line = Line;
})();

// LineManager
(() => {
  let instance;

  /** @returns {LineManager} The shared LineManager, created on first use. */
  function getInstance() {
    if (!instance) {
      instance = new LineManager();
    }

    return instance;
  }

  /** Holds every stroke drawn so far. */
  function LineManager() {
    const _this = this;
    const lineList = [];

    _init();

    function _init() {
      ns.EventDispatcher.call(_this);
    }

    _this.lineList = lineList;
  }

  LineManager.prototype = new ns.EventDispatcher();
  LineManager.prototype.constructor = LineManager;

  LineManager.prototype.push = function(lineModel) {
    this.lineList.push(lineModel);
  };

  /**
   * Appends a point to a line.
   * @param {Point} pointModel
   * @param {number} [opt_index] Index of the target line; defaults to the newest line.
   */
  LineManager.prototype.addPoint = function(pointModel, opt_index) {
    // `opt_index || …` would treat index 0 as "not given".
    const index = typeof opt_index === 'number' ? opt_index : this.lineList.length - 1;
    const line = this.lineList[index];

    // No such line (e.g. the list was cleared): nothing to append to.
    if (!line) {
      return;
    }

    line.push(pointModel);
  };

  /** Draws only the newest line; older lines are not redrawn here. */
  LineManager.prototype.drawQuadraticCurve = function(ctx) {
    const lineList = this.lineList;
    const length = lineList.length;

    if (length) {
      lineList[length - 1].drawQuadraticCurve(ctx);
    }
  };

  LineManager.prototype.clear = function() {
    let _this = this;

    _this.lineList = [];
  };

  ns.LineManager = { getInstance: getInstance };
})();

/**
 * Turns `canvas` into a 200x200 sketch pad driven by mouse and touch events
 * and starts a requestAnimationFrame redraw loop.
 * @param {HTMLCanvasElement} canvas Canvas element to draw on.
 * @returns {{clear: function(): void}} `clear` removes all strokes and resets
 *     both the visible canvas and the off-screen buffer.
 */
export function setup(canvas) { // ❶ ❷
  const doc = document;
  const SIZE = 10;
  const _getAve = ns.Util.buildGetAve(10);

  /**
   * Maps the distance between two consecutive points to a pen width: the
   * longer the distance, the thinner the width, smoothed by a moving average.
   */
  function _getSize(l) {
    const MAX_DISTANCE = 100;
    let width;

    if (l > MAX_DISTANCE) {
      width = 1;
    } else {
      // Without this `else` the clamp above was always overwritten and the
      // width went negative for distances beyond MAX_DISTANCE.
      width = (MAX_DISTANCE - l) * SIZE / MAX_DISTANCE;
    }

    return _getAve(width);
  }

  const lineManager = ns.LineManager.getInstance();
  const ctx = canvas.getContext('2d');
  // Off-screen buffer that accumulates finished strokes.
  const sub = doc.createElement('canvas');
  const subCtx = sub.getContext('2d');
  const START_EVENT = 'mousedown';
  const MOVE_EVENT = 'mousemove';
  const END_EVENT = 'mouseup';
  const WIDTH = 200;
  const HEIGHT = 200;

  _setup();

  async function _setup() {
    addSketchEventForCanvas(canvas);

    canvas.width = sub.width = WIDTH;
    canvas.height = sub.height = HEIGHT;

    tick();
  }

  function tick() {
    draw();
    requestAnimationFrame(tick);
  }

  function draw() {
    // Assigning width/height resets the canvas before repainting.
    canvas.width = WIDTH;
    canvas.height = HEIGHT;

    ctx.drawImage(sub, 0, 0);
    ctx.save();
    ctx.fillStyle = ctx.strokeStyle = 'rgba(50, 50, 50, 1)';
    lineManager.drawQuadraticCurve(ctx);
    ctx.restore();
  }

  function addSketchEventForCanvas(elm) {
    const throttle = new ns.Throttle(10, handleThrottleMoveEvent);
    let lastPoint = null;

    elm.addEventListener(START_EVENT, handleStartEvent, false);
    doc.addEventListener(END_EVENT, handleEndEvent, false);

    elm.addEventListener('touchstart', handleStartEvent, false);
    elm.addEventListener('touchmove', handleMoveEvent, false);
    doc.addEventListener('touchend', handleEndEvent, false);

    function handleThrottleMoveEvent(evt) {
      // A throttled move can fire after the stroke has ended (lastPoint reset to null).
      if (!lastPoint) {
        return;
      }

      const { offsetX, offsetY } = getOffset(evt);
      const currentPoint = new ns.Point(offsetX, offsetY);

      currentPoint.setSize(_getSize(ns.Point.getDistance(lastPoint, currentPoint)));
      lastPoint = currentPoint;
      lineManager.addPoint(currentPoint);
    }

    function handleStartEvent(evt) {
      const { offsetX, offsetY } = getOffset(evt);

      lastPoint = new ns.Point(offsetX, offsetY, SIZE / 2);
      lineManager.push(new ns.Line(lastPoint));

      elm.addEventListener(MOVE_EVENT, handleMoveEvent, false);
    }

    function handleMoveEvent(evt) {
      throttle.fireEvent(evt);
    }

    function handleEndEvent(evt) {
      // Bake the current frame into the off-screen buffer.
      sub.width = WIDTH;
      sub.height = HEIGHT;
      subCtx.drawImage(canvas, 0, 0);

      elm.removeEventListener(MOVE_EVENT, handleMoveEvent, false);
      lastPoint = null;
    }

    /** @returns {{offsetX: number, offsetY: number}} Pointer position relative to `elm`. */
    function getOffset(evt) {
      if (evt.touches && evt.touches.length) {
        // clientX/Y and getBoundingClientRect() are both viewport-relative, so
        // their difference is the position inside the element. The previous
        // pageX - offsetLeft - rect.left subtracted the element offset twice.
        const rect = elm.getBoundingClientRect();
        const touch = evt.touches[0];

        return {
          offsetX: touch.clientX - rect.left,
          offsetY: touch.clientY - rect.top
        };
      }

      return {
        offsetX: evt.offsetX,
        offsetY: evt.offsetY
      };
    }
  }

  return { // ❸
    clear: function() {
      lineManager.clear();

      canvas.width = sub.width = WIDTH;
      canvas.height = sub.height = HEIGHT;
    }
  }
}
かつて書いたJavaScriptをTypeScriptに変換するのが大変すぎたため、最低限の改修をしてJavaScriptのまま読み込みました。
行った改修は、
- setupのみをexportする
- setup実行時に引数にCanvasを渡せるようにする
- setupを実行した際に、clearを持ったオブジェクトを戻す
です。
DEMO
OpenAIの料金が怖いので、今回は動作の様子だけ貼っておきます。
tesseract.jsよりも格段に精度は上がりましたが、そのトレードオフとしてスピードを失いました。
通信を挟んでいるからスピード面はだいぶ不利ですね。
リポジトリ
雑感
そもそも数字の認識のためにOpenAI APIを使うのはかなりオーバースペックなので、もうちょっと身の丈にあった手段を探したいところです。
手書きの数字を機械学習させたモデルをローカルで動かすのがベストな手段な気がしていますが、それを自作するのは車輪の再発明になりそうなので、既存の手段をもうちょっと探してみようと思います。