Skip to content

Commit

Permalink
Merge pull request #9 from replicate/sparse-file-download
Browse files (browse the repository at this point in the history)
WIP: use less memory by downloading to sparse file
  • Loading branch information
philandstuff authored Aug 8, 2023
2 parents 4d2003f + be1a93e commit 22f5552
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 20 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ jobs:
go-version-file: go.mod
- run: script/build
- uses: ncipollo/release-action@v1
if: ${{ startsWith(github.ref, 'refs/tags') }}
if: github.ref_type=='tag' && !contains(github.ref_name, '-')
with:
artifacts: "pget"
- uses: ncipollo/release-action@v1
if: github.ref_type=='tag' && contains(github.ref_name, '-')
with:
artifacts: "pget"
prerelease: true

75 changes: 56 additions & 19 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package main

import (
"archive/tar"
"bytes"
"flag"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -34,17 +32,26 @@ func getRemoteFileSize(url string) (int64, error) {
return fileSize, nil
}

func downloadFileToBuffer(url string, concurrency int) (*bytes.Buffer, error) {
func downloadFile(url string, destFile *os.File, concurrency int) error {
fileSize, err := getRemoteFileSize(url)
if err != nil {
return nil, err
return err
}

if err != nil {
fmt.Printf("Error creating file: %v\n", err)
os.Exit(1)
}

err = destFile.Truncate(fileSize)
if err != nil {
return err
}

chunkSize := fileSize / int64(concurrency)
var wg sync.WaitGroup
wg.Add(concurrency)

data := make([]byte, fileSize)
errc := make(chan error, concurrency)
startTime := time.Now()

Expand All @@ -58,6 +65,11 @@ func downloadFileToBuffer(url string, concurrency int) (*bytes.Buffer, error) {

go func(start, end int64) {
defer wg.Done()
fh, err := os.OpenFile(destFile.Name(), os.O_RDWR, 0644)
if err != nil {
errc <- fmt.Errorf("Failed to reopen file: %v", err)
}
defer fh.Close()

retries := 5
for retries > 0 {
Expand Down Expand Up @@ -85,14 +97,22 @@ func downloadFileToBuffer(url string, concurrency int) (*bytes.Buffer, error) {
}
defer resp.Body.Close()

n, err := io.ReadFull(resp.Body, data[start:end+1])
_, err = fh.Seek(start, 0)
if err != nil {
fmt.Printf("Error seeking in file: %v\n", err)
retries--
time.Sleep(time.Millisecond * 100) // wait 100 milliseconds before retrying
continue
}

n, err := io.CopyN(fh, resp.Body, end-start+1)
if err != nil && err != io.EOF {
fmt.Printf("Error reading response: %v\n", err)
retries--
time.Sleep(time.Millisecond * 100) // wait 100 milliseconds before retrying
continue
}
if n != int(end-start+1) {
if n != end-start+1 {
fmt.Printf("Downloaded %d bytes instead of %d\n", n, end-start+1)
retries--
time.Sleep(time.Millisecond * 100) // wait 100 milliseconds before retrying
Expand All @@ -112,20 +132,19 @@ func downloadFileToBuffer(url string, concurrency int) (*bytes.Buffer, error) {
close(errc) // close the error channel
for err := range errc {
if err != nil {
return nil, err // return the first error we encounter
return err // return the first error we encounter
}
}
elapsed := time.Since(startTime).Seconds()
througput := humanize.Bytes(uint64(float64(fileSize) / elapsed))
fmt.Printf("Downloaded %s bytes in %.3fs (%s/s)\n", humanize.Bytes(uint64(fileSize)), elapsed, througput)

buffer := bytes.NewBuffer(data)
return buffer, nil
return nil
}

func extractTarFile(buffer *bytes.Buffer, destDir string) error {
func extractTarFile(input io.Reader, destDir string) error {
startTime := time.Now()
tarReader := tar.NewReader(buffer)
tarReader := tar.NewReader(input)

for {
header, err := tarReader.Next()
Expand Down Expand Up @@ -182,7 +201,7 @@ func main() {
// check required positional arguments
args := flag.Args()
if len(args) < 2 {
fmt.Println("Usage: pcurl <url> <dest> [-c concurrency] [-x]")
fmt.Println("Usage: pcurl [-c concurrency] [-x] <url> <dest>")
os.Exit(1)
}

Expand All @@ -195,26 +214,44 @@ func main() {
os.Exit(1)
}

buffer, err := downloadFileToBuffer(url, *concurrency)
// create tempfile for downloading to
cwd, err := os.Getwd()
if err != nil {
fmt.Printf("Error getting cwd: %v\n", err)
os.Exit(1)
}
destTemp, err := os.CreateTemp(cwd, dest+".partial")
if err != nil {
fmt.Printf("Failed to create temp file: %v\n", err)
os.Exit(1)
}

err = downloadFile(url, destTemp, *concurrency)
if err != nil {
fmt.Printf("Error downloading file: %v\n", err)
os.Exit(1)
}

// extract the tar file if the -x flag was provided
if *extract {
err = extractTarFile(buffer, dest)
_, err = destTemp.Seek(0, 0)
if err != nil {
fmt.Printf("Error extracting tar file: %v\n", err)
os.Exit(1)
}
err = extractTarFile(destTemp, dest)
if err != nil {
fmt.Printf("Error extracting tar file: %v\n", err)
os.Exit(1)
}
destTemp.Close()
os.Remove(destTemp.Name())
} else {
// if -x flag is not set, save the buffer to a file
err = ioutil.WriteFile(dest, buffer.Bytes(), 0644)
// move destTemp to dest
err = os.Rename(destTemp.Name(), dest)
if err != nil {
fmt.Printf("Error writing file: %v\n", err)
fmt.Printf("Error moving downloaded file to correct location: %v\n", err)
os.Exit(1)
}
}

}

0 comments on commit 22f5552

Please sign in to comment.