Skip to content

Commit

Permalink
feat: add live proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
Neemias Almeida committed Jan 5, 2024
1 parent d25a6af commit 83a34bf
Show file tree
Hide file tree
Showing 7 changed files with 438 additions and 3 deletions.
6 changes: 6 additions & 0 deletions air_example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,9 @@ clean_on_exit = true
[screen]
clear_on_rebuild = true
keep_scroll = true

# Enable live-reloading on the browser. This is useful when developing UI applications.
[proxy]
enabled = true
proxy_port = 8090
app_port = 8080
12 changes: 9 additions & 3 deletions runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Config struct {
Log cfgLog `toml:"log"`
Misc cfgMisc `toml:"misc"`
Screen cfgScreen `toml:"screen"`
Proxy cfgProxy `toml:"proxy"`
}

type cfgBuild struct {
Expand Down Expand Up @@ -96,6 +97,12 @@ type cfgScreen struct {
KeepScroll bool `toml:"keep_scroll"`
}

type cfgProxy struct {
Enabled bool `toml:"enabled"`
Port int `toml:"proxy_port"`
AppPort int `toml:"app_port"`
}

type sliceTransformer struct{}

func (t sliceTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {
Expand Down Expand Up @@ -350,10 +357,9 @@ func (c *Config) killDelay() time.Duration {
// interpret as milliseconds if less than the value of 1 millisecond
if c.Build.KillDelay < time.Millisecond {
return c.Build.KillDelay * time.Millisecond
} else {
// normalize kill delay to milliseconds
return time.Duration(c.Build.KillDelay.Milliseconds()) * time.Millisecond
}
// normalize kill delay to milliseconds
return time.Duration(c.Build.KillDelay.Milliseconds()) * time.Millisecond
}

func (c *Config) binPath() string {
Expand Down
20 changes: 20 additions & 0 deletions runner/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
// Engine ...
type Engine struct {
config *Config
proxy *Proxy
logger *logger
watcher filenotify.FileWatcher
debugMode bool
Expand Down Expand Up @@ -48,6 +49,7 @@ func NewEngineWithConfig(cfg *Config, debugMode bool) (*Engine, error) {
}
e := Engine{
config: cfg,
proxy: NewProxy(&cfg.Proxy),
logger: logger,
watcher: watcher,
debugMode: debugMode,
Expand Down Expand Up @@ -310,6 +312,13 @@ func (e *Engine) isModified(filename string) bool {

// Endless loop and never return
func (e *Engine) start() {
if e.config.Proxy.Enabled {
go func() {
e.mainLog("Proxy server listening on %s", e.proxy.server.Addr)
e.proxy.Run()
}()
}

e.running = true
firstRunCh := make(chan bool, 1)
firstRunCh <- true
Expand Down Expand Up @@ -347,6 +356,9 @@ func (e *Engine) start() {
}
}

if e.config.Proxy.Enabled {
e.proxy.Reload()
}
e.mainLog("%s has changed", e.config.rel(filename))
case <-firstRunCh:
// go down
Expand Down Expand Up @@ -535,6 +547,9 @@ func (e *Engine) runBin() error {
cmd, stdout, stderr, _ := e.startCmd(command)
processExit := make(chan struct{})
e.mainDebug("running process pid %v", cmd.Process.Pid)
if e.proxy.config.Enabled {
e.proxy.Reload()
}

wg.Add(1)
atomic.AddUint64(&e.round, 1)
Expand Down Expand Up @@ -579,6 +594,11 @@ func (e *Engine) cleanup() {
e.mainLog("cleaning...")
defer e.mainLog("see you again~")

if e.config.Proxy.Enabled {
e.mainDebug("powering down the proxy...")
e.proxy.Stop()
}

e.withLock(func() {
close(e.binStopCh)
e.binStopCh = make(chan bool)
Expand Down
155 changes: 155 additions & 0 deletions runner/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package runner

import (
"bytes"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"syscall"
"time"
)

type Reloader interface {
AddSubscriber() *Subscriber
RemoveSubscriber(id int)
Reload()
Stop()
}

type Proxy struct {
server *http.Server
config *cfgProxy
stream Reloader
}

func NewProxy(cfg *cfgProxy) *Proxy {
p := &Proxy{
config: cfg,
server: &http.Server{
Addr: fmt.Sprintf("localhost:%d", cfg.Port),
},
stream: NewProxyStream(),
}
return p
}

func (p *Proxy) Run() {
http.HandleFunc("/", p.proxyHandler)
http.HandleFunc("/internal/reload", p.reloadHandler)
log.Fatal(p.server.ListenAndServe())
}

func (p *Proxy) Stop() {
p.server.Close()
p.stream.Stop()
}

func (p *Proxy) Reload() {
p.stream.Reload()
}

func (p *Proxy) injectLiveReload(origURL string, respBody io.ReadCloser) string {
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(respBody); err != nil {
panic("failed to convert request body to bytes buffer")
}
s := buf.String()

body := strings.LastIndex(s, "</body>")
if body == -1 {
panic("invalid html")
}
script := `
<script>
const parser = new DOMParser();
const proxyURL = "http://localhost:%d";
new EventSource(proxyURL + "/internal/reload").onmessage = () => {
fetch(proxyURL + "%s").then(res => res.text()).then(resStr => {
const newPage = parser.parseFromString(resStr, "text/html");
document.replaceChild(newPage.documentElement, document.documentElement);
});
};
</script>
`
parsedScript := fmt.Sprintf(script, p.config.Port, origURL)

s = s[:body] + parsedScript + s[body:]
return s
}

func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) {
url := fmt.Sprintf("http://localhost:%d", p.config.AppPort)
req, err := http.NewRequest(r.Method, url, r.Body)
if err != nil {
panic(err)
}
req.Header.Set("X-Forwarded-For", r.RemoteAddr)

client := &http.Client{}
var resp *http.Response
for i := 0; i < 10; i++ {
resp, err = client.Do(req)
if err == nil {
break
}
if !errors.Is(err, syscall.ECONNREFUSED) {
log.Fatalf("failed to call http://localhost:%d, err: %+v", p.config.AppPort, err)
}
time.Sleep(100 * time.Millisecond)
}
defer resp.Body.Close()

// copy all headers except Content-Length
for k, vv := range resp.Header {
for _, v := range vv {
if k == "Content-Length" {
continue
}
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)

if strings.Contains(resp.Header.Get("Content-Type"), "text/html") {
s := p.injectLiveReload(r.URL.String(), resp.Body)
w.Header().Set("Content-Length", strconv.Itoa((len([]byte(s)))))
if _, err := io.WriteString(w, s); err != nil {
panic("failed to write injected payload")
}
} else {
w.Header().Set("Content-Length", resp.Header.Get("Content-Length"))
if _, err := io.Copy(w, resp.Body); err != nil {
panic("failed to write normal payload")
}
}
}

func (p *Proxy) reloadHandler(w http.ResponseWriter, r *http.Request) {
flusher, err := w.(http.Flusher)
if !err {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

sub := p.stream.AddSubscriber()
go func() {
<-r.Context().Done()
p.stream.RemoveSubscriber(sub.id)
}()

w.WriteHeader(http.StatusOK)
flusher.Flush()

for range sub.reloadCh {
fmt.Fprintf(w, "data: reload\n\n")
flusher.Flush()
}
}
50 changes: 50 additions & 0 deletions runner/proxy_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package runner

import (
"sync"
)

type ProxyStream struct {
sync.Mutex
subscribers map[int]*Subscriber
count int
}

type Subscriber struct {
id int
reloadCh chan struct{}
}

func NewProxyStream() *ProxyStream {
return &ProxyStream{subscribers: make(map[int]*Subscriber)}
}

func (stream *ProxyStream) Stop() {
for id := range stream.subscribers {
stream.RemoveSubscriber(id)
}
stream.count = 0
}

func (stream *ProxyStream) AddSubscriber() *Subscriber {
stream.Lock()
defer stream.Unlock()
stream.count++

sub := &Subscriber{id: stream.count, reloadCh: make(chan struct{})}
stream.subscribers[stream.count] = sub
return sub
}

func (stream *ProxyStream) RemoveSubscriber(id int) {
stream.Lock()
defer stream.Unlock()
close(stream.subscribers[id].reloadCh)
delete(stream.subscribers, id)
}

func (stream *ProxyStream) Reload() {
for _, sub := range stream.subscribers {
sub.reloadCh <- struct{}{}
}
}
66 changes: 66 additions & 0 deletions runner/proxy_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package runner

import (
"sync"
"testing"
)

func find(s map[int]*Subscriber, id int) bool {
for _, sub := range s {
if sub.id == id {
return true
}
}
return false
}

func TestProxyStream(t *testing.T) {
stream := NewProxyStream()

var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
_ = stream.AddSubscriber()
}(i)
}
wg.Wait()

if got, exp := len(stream.subscribers), 10; got != exp {
t.Errorf("expected %d but got %d", exp, got)
}

go func() {
stream.Reload()
}()

reloadCount := 0
for _, sub := range stream.subscribers {
wg.Add(1)
go func(sub *Subscriber) {
defer wg.Done()
<-sub.reloadCh
reloadCount++
}(sub)
}
wg.Wait()

if got, exp := reloadCount, 10; got != exp {
t.Errorf("expected %d but got %d", exp, got)
}

stream.RemoveSubscriber(2)
stream.AddSubscriber()
if got, exp := find(stream.subscribers, 2), false; got != exp {
t.Errorf("expected subscriber found to be %t but got %t", exp, got)
}
if got, exp := find(stream.subscribers, 11), true; got != exp {
t.Errorf("expected subscriber found to be %t but got %t", exp, got)
}

stream.Stop()
if got, exp := len(stream.subscribers), 0; got != exp {
t.Errorf("expected %d but got %d", exp, got)
}
}
Loading

0 comments on commit 83a34bf

Please sign in to comment.