From 89fb9c47cec3e509fa5efec6636d05df1d05fbd8 Mon Sep 17 00:00:00 2001 From: bleeptrack Date: Wed, 5 Jun 2024 16:57:15 +0200 Subject: [PATCH] adding training animation --- Model.py | 14 +++- server.py | 17 ++++- static/PaperCanvas.js | 134 +++++++++++++++++++++++++++++++++++++++ static/PatternTrainer.js | 4 ++ 4 files changed, 163 insertions(+), 6 deletions(-) diff --git a/Model.py b/Model.py index e8083c1..fcdd721 100644 --- a/Model.py +++ b/Model.py @@ -338,8 +338,6 @@ def trainModel(self, progress_callback=None): loss_list = loss_list[1:] print("Epoch:", epoch, "Loss:", running_loss, l1, l2, m, w) - if progress_callback: - progress_callback(l1.item()) if math.isnan(running_loss): die() @@ -355,7 +353,17 @@ def trainModel(self, progress_callback=None): if epoch % 100 == 0: torch.save(self.model.state_dict(), self.model_path) print("saving...") - + + + if progress_callback: + + if epoch % 15 == 0: + print("extracting vectors for animation") + vectors, originpoints = self.extractOriginLineVectors() + + progress_callback(self, vectors) + #else: + # progress_callback(self, l1.item()) torch.save(self.model.state_dict(), self.model_path) diff --git a/server.py b/server.py index 5643b8e..3ae43f7 100644 --- a/server.py +++ b/server.py @@ -9,6 +9,7 @@ import random import numpy as np from pathlib import Path +import torch app = Flask(__name__) app.config['SECRET_KEY'] = 'secret!' @@ -79,9 +80,19 @@ def new_dataset(data): lineTrainer = LineTrainer(data['name']) lineTrainer.trainModel(send_progress) -def send_progress(text): - print("sending progress", text) - emit('progress', text) +def send_progress(trainer, text): + + if isinstance(text, list): + pointlist = [] + for z in text: + print(z) + tensor = trainer.decode_latent_vector(z) + pointlist.append(tensor2Points(tensor)) + #print(pointlist) + emit('progress', {'lines': pointlist} ) + else: + print("sending progress", text) + emit('progress', {'percent':text} ) @socketio.on('generate') def generate(data): diff --git a/static/PaperCanvas.js b/static/PaperCanvas.js index 6580536..8e1e817 100644 --- a/static/PaperCanvas.js +++ b/static/PaperCanvas.js @@ -119,6 +119,140 @@ export class PaperCanvas extends HTMLElement { let l = this.originalLines.pop() l.remove() } + + drawLine(baseLine, color, smoothing) { + let points + if (Array.isArray(baseLine)) { + points = baseLine + } else { + points = baseLine.points + } + + let path = new Path({segments: points}) + path.strokeColor = color + path.pivot = path.firstSegment.point + + console.log("base", baseLine.position) + + if (!Array.isArray(baseLine)) { + path.position = new Point(baseLine.position.x * this.config["max_dist"], baseLine.position.y * this.config["max_dist"] ) + console.log("pos", path.position) + path.scale(baseLine.scale) + path.rotate(baseLine.rotation * 360) + } else { + console.error("no normapization info", baseLine) + } + + if(baseLine.reference_id){ + console.log("moving", originalLines[baseLine.reference_id].firstSegment.point) + path.translate(originalLines[baseLine.reference_id].firstSegment.point) + }else{ + path.translate(view.center) + } + + this.processLine(path) + + if (smoothing) { + path.simplify() + } + return path + } + + processLine(path) { + let [segmentedPath, scale, angle] = this.createSegments(path) + path.scale(scale, path.firstSegment.point) + path.rotate(angle*360, path.firstSegment.point) + console.log(scale, angle) + + let points = this.segments2points(segmentedPath) + //let group = drawLine(points, "red") + //pointlist.push(points) + this.linelist.push({ + points: points, + scale: scale, + rotation: angle, + }) + this.originalLines.push(path) + + } + + createSegments(path) { + //scale up to normalized size + let largeDir = Math.max(path.bounds.width, path.bounds.height) + let baseSize = this.config["stroke_normalizing_size"] + path.scale(baseSize/largeDir, path.firstSegment.point) + let scale = largeDir/baseSize + + let currAngle = path.lastSegment.point.subtract( + path.firstSegment.point + ).angle + 180 + + let angle = currAngle/360 + path.rotate(-currAngle, path.firstSegment.point) + + + let segmentedPath = new Path() + + let dist = path.length / (this.config.nrPoints - 1) + for (let i = 0; i < this.config.nrPoints - 1; i++) { + let p = path.getPointAt(dist * i).round() + segmentedPath.addSegment(p) + } + segmentedPath.addSegment(path.lastSegment.point.round()) + + return [segmentedPath, scale, angle] + } + + segments2points(path) { + return path.segments.map((seg) => { + return {x: seg.point.x, y: seg.point.y} + }) + } + + trainingEpoch(data){ + if(!this.trainingLines){ + this.trainingLines = [] + this.animLines = [] + } + let iteration = [] + for(let [idx,line] of Object.entries(data.lines)){ + let paperline = this.drawLine(line, "grey") + paperline.strokeWidth = 5 + paperline.opacity = 0.5 + paperline.position = this.originalLines[idx].firstSegment.point + paperline.scale(this.linelist[idx].scale, paperline.firstSegment.point) + paperline.rotate(this.linelist[idx].rotation*360, paperline.firstSegment.point) + paperline.remove() + + iteration.push(paperline) + if(this.animLines.length <= idx){ + let animLine = new Path() + animLine.strokeColor = "blue" + animLine.strokeWidth = 8 + animLine.opacity = 0.7 + animLine.strokeCap = 'round' + animLine.strokeJoin = 'round' + this.animLines.push( animLine ) + console.log("push", animLine) + } + + if(this.trainingLines[this.trainingLines.length -1]){ + let lastItem = this.trainingLines[this.trainingLines.length -1][idx] + //let animPath = lastItem.clone() + //animPath.insertAbove(this.originalLines[idx]) + //animPath.strokeColor = "blue" + + this.animLines[idx].tween(1000).onUpdate = (event) => { + this.animLines[idx].interpolate(lastItem, paperline, event.factor) + console.log("tweeeeeen") + } + + } + + + } + this.trainingLines.push(iteration) + } } diff --git a/static/PatternTrainer.js b/static/PatternTrainer.js index 23fe015..409f096 100644 --- a/static/PatternTrainer.js +++ b/static/PatternTrainer.js @@ -18,6 +18,10 @@ export class PatternTrainer extends HTMLElement { this.socket.on("progress", (text) => { console.log("progress received", text) + if(text.lines){ + this.canvas.trainingEpoch(text) + + } })