Skip to content

Commit

Permalink
adding training animation
Browse files Browse the repository at this point in the history
  • Loading branch information
bleeptrack committed Jun 5, 2024
1 parent 97bc07d commit 89fb9c4
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 6 deletions.
14 changes: 11 additions & 3 deletions Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ def trainModel(self, progress_callback=None):
loss_list = loss_list[1:]

print("Epoch:", epoch, "Loss:", running_loss, l1, l2, m, w)
if progress_callback:
progress_callback(l1.item())

if math.isnan(running_loss):
die()
Expand All @@ -355,7 +353,17 @@ def trainModel(self, progress_callback=None):
if epoch % 100 == 0:
torch.save(self.model.state_dict(), self.model_path)
print("saving...")



if progress_callback:

if epoch % 15 == 0:
print("extracting vectors for animation")
vectors, originpoints = self.extractOriginLineVectors()

progress_callback(self, vectors)
#else:
# progress_callback(self, l1.item())

torch.save(self.model.state_dict(), self.model_path)

Expand Down
17 changes: 14 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import random
import numpy as np
from pathlib import Path
import torch

app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
Expand Down Expand Up @@ -79,9 +80,19 @@ def new_dataset(data):
lineTrainer = LineTrainer(data['name'])
lineTrainer.trainModel(send_progress)

def send_progress(text):
print("sending progress", text)
emit('progress', text)
def send_progress(trainer, text):

if isinstance(text, list):
pointlist = []
for z in text:
print(z)
tensor = trainer.decode_latent_vector(z)
pointlist.append(tensor2Points(tensor))
#print(pointlist)
emit('progress', {'lines': pointlist} )
else:
print("sending progress", text)
emit('progress', {'percent':text} )

@socketio.on('generate')
def generate(data):
Expand Down
134 changes: 134 additions & 0 deletions static/PaperCanvas.js
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,140 @@ export class PaperCanvas extends HTMLElement {
let l = this.originalLines.pop()
l.remove()
}

drawLine(baseLine, color, smoothing) {
let points
if (Array.isArray(baseLine)) {
points = baseLine
} else {
points = baseLine.points
}

let path = new Path({segments: points})
path.strokeColor = color
path.pivot = path.firstSegment.point

console.log("base", baseLine.position)

if (!Array.isArray(baseLine)) {
path.position = new Point(baseLine.position.x * this.config["max_dist"], baseLine.position.y * this.config["max_dist"] )
console.log("pos", path.position)
path.scale(baseLine.scale)
path.rotate(baseLine.rotation * 360)
} else {
console.error("no normapization info", baseLine)
}

if(baseLine.reference_id){
console.log("moving", originalLines[baseLine.reference_id].firstSegment.point)
path.translate(originalLines[baseLine.reference_id].firstSegment.point)
}else{
path.translate(view.center)
}

this.processLine(path)

if (smoothing) {
path.simplify()
}
return path
}

processLine(path) {
let [segmentedPath, scale, angle] = this.createSegments(path)
path.scale(scale, path.firstSegment.point)
path.rotate(angle*360, path.firstSegment.point)
console.log(scale, angle)

let points = this.segments2points(segmentedPath)
//let group = drawLine(points, "red")
//pointlist.push(points)
this.linelist.push({
points: points,
scale: scale,
rotation: angle,
})
this.originalLines.push(path)

}

createSegments(path) {
//scale up to normalized size
let largeDir = Math.max(path.bounds.width, path.bounds.height)
let baseSize = this.config["stroke_normalizing_size"]
path.scale(baseSize/largeDir, path.firstSegment.point)
let scale = largeDir/baseSize

let currAngle = path.lastSegment.point.subtract(
path.firstSegment.point
).angle + 180

let angle = currAngle/360
path.rotate(-currAngle, path.firstSegment.point)


let segmentedPath = new Path()

let dist = path.length / (this.config.nrPoints - 1)
for (let i = 0; i < this.config.nrPoints - 1; i++) {
let p = path.getPointAt(dist * i).round()
segmentedPath.addSegment(p)
}
segmentedPath.addSegment(path.lastSegment.point.round())

return [segmentedPath, scale, angle]
}

segments2points(path) {
return path.segments.map((seg) => {
return {x: seg.point.x, y: seg.point.y}
})
}

trainingEpoch(data){
if(!this.trainingLines){
this.trainingLines = []
this.animLines = []
}
let iteration = []
for(let [idx,line] of Object.entries(data.lines)){
let paperline = this.drawLine(line, "grey")
paperline.strokeWidth = 5
paperline.opacity = 0.5
paperline.position = this.originalLines[idx].firstSegment.point
paperline.scale(this.linelist[idx].scale, paperline.firstSegment.point)
paperline.rotate(this.linelist[idx].rotation*360, paperline.firstSegment.point)
paperline.remove()

iteration.push(paperline)
if(this.animLines.length <= idx){
let animLine = new Path()
animLine.strokeColor = "blue"
animLine.strokeWidth = 8
animLine.opacity = 0.7
animLine.strokeCap = 'round'
animLine.strokeJoin = 'round'
this.animLines.push( animLine )
console.log("push", animLine)
}

if(this.trainingLines[this.trainingLines.length -1]){
let lastItem = this.trainingLines[this.trainingLines.length -1][idx]
//let animPath = lastItem.clone()
//animPath.insertAbove(this.originalLines[idx])
//animPath.strokeColor = "blue"

this.animLines[idx].tween(1000).onUpdate = (event) => {
this.animLines[idx].interpolate(lastItem, paperline, event.factor)
console.log("tweeeeeen")
}

}


}
this.trainingLines.push(iteration)
}

}

Expand Down
4 changes: 4 additions & 0 deletions static/PatternTrainer.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export class PatternTrainer extends HTMLElement {

this.socket.on("progress", (text) => {
console.log("progress received", text)
if(text.lines){
this.canvas.trainingEpoch(text)

}
})


Expand Down

0 comments on commit 89fb9c4

Please sign in to comment.