test(09-04): harden db-backed regressions
This commit is contained in:
parent
3a3ecf5803
commit
cf07c29ae5
3 changed files with 31 additions and 14 deletions
|
|
@ -171,6 +171,17 @@ func FileUploadHandler(deps FilesDeps) http.HandlerFunc {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
if header.Size > maxBytes {
|
||||||
|
fileList, _ := deps.Queries.ListFilesByTablo(r.Context(), tablo.ID)
|
||||||
|
if fileList == nil {
|
||||||
|
fileList = []sqlc.TabloFile{}
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||||
|
errMsg := "File too large (max " + strconv.Itoa(deps.MaxUploadMB) + " MB)."
|
||||||
|
_ = templates.UploadErrorFragment(tablo, fileList, csrf.Token(r), errMsg).Render(r.Context(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
filename := strings.TrimSpace(header.Filename)
|
filename := strings.TrimSpace(header.Filename)
|
||||||
if filename == "" {
|
if filename == "" {
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
@ -99,12 +100,12 @@ func TestFileUpload(t *testing.T) {
|
||||||
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: sessionVal}
|
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: sessionVal}
|
||||||
|
|
||||||
// Build multipart body with a CSRF token.
|
// Build multipart body with a CSRF token.
|
||||||
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
|
||||||
|
|
||||||
var body bytes.Buffer
|
var body bytes.Buffer
|
||||||
mw := multipart.NewWriter(&body)
|
mw := multipart.NewWriter(&body)
|
||||||
// Add CSRF field.
|
// Add CSRF field.
|
||||||
_ = mw.WriteField("gorilla.csrf.Token", csrfToken)
|
_ = mw.WriteField("_csrf", csrfToken)
|
||||||
// Add file field.
|
// Add file field.
|
||||||
fw, err := mw.CreateFormFile("file", "hello.txt")
|
fw, err := mw.CreateFormFile("file", "hello.txt")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -191,12 +192,12 @@ func TestFileUploadTooLarge(t *testing.T) {
|
||||||
}
|
}
|
||||||
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: sessionVal}
|
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: sessionVal}
|
||||||
|
|
||||||
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
|
||||||
|
|
||||||
// Build a multipart body that exceeds 1 MB.
|
// Build a multipart body that exceeds 1 MB.
|
||||||
var body bytes.Buffer
|
var body bytes.Buffer
|
||||||
mw := multipart.NewWriter(&body)
|
mw := multipart.NewWriter(&body)
|
||||||
_ = mw.WriteField("gorilla.csrf.Token", csrfToken)
|
_ = mw.WriteField("_csrf", csrfToken)
|
||||||
fw, err := mw.CreateFormFile("file", "big.bin")
|
fw, err := mw.CreateFormFile("file", "big.bin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreateFormFile: %v", err)
|
t.Fatalf("CreateFormFile: %v", err)
|
||||||
|
|
@ -495,7 +496,7 @@ func TestFileDelete(t *testing.T) {
|
||||||
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
||||||
|
|
||||||
var body strings.Builder
|
var body strings.Builder
|
||||||
body.WriteString("gorilla.csrf.Token=" + csrfToken)
|
body.WriteString(url.Values{"_csrf": {csrfToken}}.Encode())
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, deleteURL, strings.NewReader(body.String()))
|
req := httptest.NewRequest(http.MethodPost, deleteURL, strings.NewReader(body.String()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
@ -573,7 +574,7 @@ func TestFileDelete_S3Failure(t *testing.T) {
|
||||||
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
||||||
|
|
||||||
var body strings.Builder
|
var body strings.Builder
|
||||||
body.WriteString("gorilla.csrf.Token=" + csrfToken)
|
body.WriteString(url.Values{"_csrf": {csrfToken}}.Encode())
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, deleteURL, strings.NewReader(body.String()))
|
req := httptest.NewRequest(http.MethodPost, deleteURL, strings.NewReader(body.String()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
@ -662,9 +663,9 @@ func TestFileOwnership(t *testing.T) {
|
||||||
|
|
||||||
// Non-owner: POST delete → 404
|
// Non-owner: POST delete → 404
|
||||||
// We need a CSRF token from non-owner's session to get past CSRF middleware.
|
// We need a CSRF token from non-owner's session to get past CSRF middleware.
|
||||||
csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie})
|
csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
|
||||||
var delBody strings.Builder
|
var delBody strings.Builder
|
||||||
delBody.WriteString("gorilla.csrf.Token=" + csrfToken)
|
delBody.WriteString(url.Values{"_csrf": {csrfToken}}.Encode())
|
||||||
reqDel := httptest.NewRequest(http.MethodPost, base+"/delete", strings.NewReader(delBody.String()))
|
reqDel := httptest.NewRequest(http.MethodPost, base+"/delete", strings.NewReader(delBody.String()))
|
||||||
reqDel.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
reqDel.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
reqDel.AddCookie(sessionCookie)
|
reqDel.AddCookie(sessionCookie)
|
||||||
|
|
|
||||||
|
|
@ -68,13 +68,13 @@ func TestTasksKanbanRenders(t *testing.T) {
|
||||||
}
|
}
|
||||||
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieVal}
|
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieVal}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String(), nil)
|
req := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String()+"/tasks", nil)
|
||||||
req.AddCookie(sessionCookie)
|
req.AddCookie(sessionCookie)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
router.ServeHTTP(rec, req)
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("GET /tablos/{id} status = %d; want 200", rec.Code)
|
t.Fatalf("GET /tablos/{id}/tasks status = %d; want 200", rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
body := rec.Body.String()
|
body := rec.Body.String()
|
||||||
|
|
@ -92,12 +92,12 @@ func TestTasksKanbanRenders(t *testing.T) {
|
||||||
nonOwner := preInsertUser(t, ctx, q, "kanban-other@example.com", "correct-horse-12")
|
nonOwner := preInsertUser(t, ctx, q, "kanban-other@example.com", "correct-horse-12")
|
||||||
nonOwnerCookieVal, _, _ := store.Create(ctx, nonOwner.ID)
|
nonOwnerCookieVal, _, _ := store.Create(ctx, nonOwner.ID)
|
||||||
nonOwnerCookie := &http.Cookie{Name: auth.SessionCookieName, Value: nonOwnerCookieVal}
|
nonOwnerCookie := &http.Cookie{Name: auth.SessionCookieName, Value: nonOwnerCookieVal}
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String(), nil)
|
req2 := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String()+"/tasks", nil)
|
||||||
req2.AddCookie(nonOwnerCookie)
|
req2.AddCookie(nonOwnerCookie)
|
||||||
rec2 := httptest.NewRecorder()
|
rec2 := httptest.NewRecorder()
|
||||||
router.ServeHTTP(rec2, req2)
|
router.ServeHTTP(rec2, req2)
|
||||||
if rec2.Code != http.StatusNotFound {
|
if rec2.Code != http.StatusNotFound {
|
||||||
t.Errorf("non-owner GET /tablos/{id}: status = %d; want 404", rec2.Code)
|
t.Errorf("non-owner GET /tablos/{id}/tasks: status = %d; want 404", rec2.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1236,6 +1236,7 @@ func TestTaskOwnership(t *testing.T) {
|
||||||
t.Fatalf("store.Create for userB: %v", storeErr)
|
t.Fatalf("store.Create for userB: %v", storeErr)
|
||||||
}
|
}
|
||||||
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieVal}
|
sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieVal}
|
||||||
|
csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie})
|
||||||
|
|
||||||
routes := []struct {
|
routes := []struct {
|
||||||
method string
|
method string
|
||||||
|
|
@ -1246,8 +1247,12 @@ func TestTaskOwnership(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
req := httptest.NewRequest(route.method, route.path, nil)
|
form := url.Values{"_csrf": {csrfToken}}
|
||||||
req.AddCookie(sessionCookie)
|
req := httptest.NewRequest(route.method, route.path, strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
for _, c := range csrfCookies {
|
||||||
|
req.AddCookie(c)
|
||||||
|
}
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
router.ServeHTTP(rec, req)
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue