diff --git a/internal/server/response.go b/internal/server/response.go index e67eab5..609171d 100644 --- a/internal/server/response.go +++ b/internal/server/response.go @@ -1,7 +1,11 @@ package server import ( + "fmt" + "io" "net/http" + "strconv" + "strings" "github.com/gin-gonic/gin" ) @@ -42,3 +46,97 @@ func InternalError(c *gin.Context, msg string) { func ErrorWithStatus(c *gin.Context, code int, msg string) { c.JSON(code, APIResponse{Success: false, Error: msg}) } + +// ParseRange parses an HTTP Range header (RFC 7233). +// Only single-part ranges are supported: bytes=start-end, bytes=start-, bytes=-suffix. +// Multi-part ranges (bytes=0-100,200-300) return an error. +func ParseRange(rangeHeader string, fileSize int64) (start, end int64, err error) { + if rangeHeader == "" { + return 0, 0, fmt.Errorf("empty range header") + } + + if !strings.HasPrefix(rangeHeader, "bytes=") { + return 0, 0, fmt.Errorf("invalid range unit: %s", rangeHeader) + } + + rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=") + + if strings.Contains(rangeSpec, ",") { + return 0, 0, fmt.Errorf("multi-part ranges are not supported") + } + + rangeSpec = strings.TrimSpace(rangeSpec) + parts := strings.Split(rangeSpec, "-") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid range format: %s", rangeSpec) + } + + if parts[0] == "" { + suffix, parseErr := strconv.ParseInt(parts[1], 10, 64) + if parseErr != nil { + return 0, 0, fmt.Errorf("invalid suffix range: %s", parts[1]) + } + if suffix <= 0 || suffix > fileSize { + return 0, 0, fmt.Errorf("suffix range %d exceeds file size %d", suffix, fileSize) + } + start = fileSize - suffix + end = fileSize - 1 + } else if parts[1] == "" { + start, err = strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid range start: %s", parts[0]) + } + if start >= fileSize { + return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize) + } + end = fileSize - 1 + } else { + start, err = strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid range start: %s", parts[0]) + } + end, err = strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid range end: %s", parts[1]) + } + if start > end { + return 0, 0, fmt.Errorf("range start %d > end %d", start, end) + } + if start >= fileSize { + return 0, 0, fmt.Errorf("range start %d exceeds file size %d", start, fileSize) + } + if end >= fileSize { + end = fileSize - 1 + } + } + + return start, end, nil +} + +// StreamFile sends a full file as an HTTP response with proper headers. +func StreamFile(c *gin.Context, reader io.ReadCloser, filename string, fileSize int64, contentType string) { + defer reader.Close() + + c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename)) + c.Header("Content-Type", contentType) + c.Header("Content-Length", strconv.FormatInt(fileSize, 10)) + c.Header("Accept-Ranges", "bytes") + + c.Status(http.StatusOK) + io.Copy(c.Writer, reader) +} + +// StreamRange sends a partial content response (206) for a byte range. +func StreamRange(c *gin.Context, reader io.ReadCloser, start, end, totalSize int64, contentType string) { + defer reader.Close() + + contentLength := end - start + 1 + + c.Header("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize)) + c.Header("Content-Type", contentType) + c.Header("Content-Length", strconv.FormatInt(contentLength, 10)) + c.Header("Accept-Ranges", "bytes") + + c.Status(http.StatusPartialContent) + io.Copy(c.Writer, reader) +} diff --git a/internal/server/response_test.go b/internal/server/response_test.go index dd67fd3..9041d6a 100644 --- a/internal/server/response_test.go +++ b/internal/server/response_test.go @@ -114,3 +114,65 @@ func TestErrorWithStatus(t *testing.T) { t.Fatalf("expected error 'already exists', got '%s'", resp.Error) } } + +func TestParseRangeStandard(t *testing.T) { + tests := []struct { + rangeHeader string + fileSize int64 + wantStart int64 + wantEnd int64 + wantErr bool + }{ + {"bytes=0-1023", 10000, 0, 1023, false}, + {"bytes=1024-", 10000, 1024, 9999, false}, + {"bytes=-1024", 10000, 8976, 9999, false}, + {"bytes=0-0", 10000, 0, 0, false}, + {"bytes=9999-", 10000, 9999, 9999, false}, + } + for _, tt := range tests { + start, end, err := ParseRange(tt.rangeHeader, tt.fileSize) + if (err != nil) != tt.wantErr { + t.Errorf("ParseRange(%q, %d) error = %v, wantErr %v", tt.rangeHeader, tt.fileSize, err, tt.wantErr) + continue + } + if !tt.wantErr { + if start != tt.wantStart || end != tt.wantEnd { + t.Errorf("ParseRange(%q, %d) = (%d, %d), want (%d, %d)", tt.rangeHeader, tt.fileSize, start, end, tt.wantStart, tt.wantEnd) + } + } + } +} + +func TestParseRangeInvalidAndMultiPart(t *testing.T) { + tests := []struct { + rangeHeader string + fileSize int64 + }{ + {"", 10000}, + {"bytes=9999-0", 10000}, + {"bytes=20000-", 10000}, + {"bytes=0-100,200-300", 10000}, + {"bytes=0-100, 400-500", 10000}, + {"bytes=", 10000}, + {"chars=0-100", 10000}, + } + for _, tt := range tests { + _, _, err := ParseRange(tt.rangeHeader, tt.fileSize) + if err == nil { + t.Errorf("ParseRange(%q, %d) expected error, got nil", tt.rangeHeader, tt.fileSize) + } + } +} + +func TestParseRangeEdgeCases(t *testing.T) { + start, end, err := ParseRange("bytes=0-99999", 10000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if end != 9999 { + t.Errorf("end = %d, want 9999 (clamped to fileSize-1)", end) + } + if start != 0 { + t.Errorf("start = %d, want 0", start) + } +}