diff --git a/backend/internal/web/handlers_files.go b/backend/internal/web/handlers_files.go index 32f65b8..2decced 100644 --- a/backend/internal/web/handlers_files.go +++ b/backend/internal/web/handlers_files.go @@ -171,6 +171,17 @@ func FileUploadHandler(deps FilesDeps) http.HandlerFunc { return } defer file.Close() + if header.Size > maxBytes { + fileList, _ := deps.Queries.ListFilesByTablo(r.Context(), tablo.ID) + if fileList == nil { + fileList = []sqlc.TabloFile{} + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusUnprocessableEntity) + errMsg := "File too large (max " + strconv.Itoa(deps.MaxUploadMB) + " MB)." + _ = templates.UploadErrorFragment(tablo, fileList, csrf.Token(r), errMsg).Render(r.Context(), w) + return + } filename := strings.TrimSpace(header.Filename) if filename == "" { diff --git a/backend/internal/web/handlers_files_test.go b/backend/internal/web/handlers_files_test.go index 8067b03..b2f07e6 100644 --- a/backend/internal/web/handlers_files_test.go +++ b/backend/internal/web/handlers_files_test.go @@ -18,6 +18,7 @@ import ( "mime/multipart" "net/http" "net/http/httptest" + "net/url" "os" "strings" "testing" @@ -99,12 +100,12 @@ func TestFileUpload(t *testing.T) { sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: sessionVal} // Build multipart body with a CSRF token. - csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie}) + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) var body bytes.Buffer mw := multipart.NewWriter(&body) // Add CSRF field. - _ = mw.WriteField("gorilla.csrf.Token", csrfToken) + _ = mw.WriteField("_csrf", csrfToken) // Add file field. fw, err := mw.CreateFormFile("file", "hello.txt") if err != nil { @@ -191,12 +192,12 @@ func TestFileUploadTooLarge(t *testing.T) { } sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: sessionVal} - csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie}) + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) // Build a multipart body that exceeds 1 MB. var body bytes.Buffer mw := multipart.NewWriter(&body) - _ = mw.WriteField("gorilla.csrf.Token", csrfToken) + _ = mw.WriteField("_csrf", csrfToken) fw, err := mw.CreateFormFile("file", "big.bin") if err != nil { t.Fatalf("CreateFormFile: %v", err) @@ -495,7 +496,7 @@ func TestFileDelete(t *testing.T) { csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie}) var body strings.Builder - body.WriteString("gorilla.csrf.Token=" + csrfToken) + body.WriteString(url.Values{"_csrf": {csrfToken}}.Encode()) req := httptest.NewRequest(http.MethodPost, deleteURL, strings.NewReader(body.String())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -573,7 +574,7 @@ func TestFileDelete_S3Failure(t *testing.T) { csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie}) var body strings.Builder - body.WriteString("gorilla.csrf.Token=" + csrfToken) + body.WriteString(url.Values{"_csrf": {csrfToken}}.Encode()) req := httptest.NewRequest(http.MethodPost, deleteURL, strings.NewReader(body.String())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -662,9 +663,9 @@ func TestFileOwnership(t *testing.T) { // Non-owner: POST delete → 404 // We need a CSRF token from non-owner's session to get past CSRF middleware. - csrfToken, csrfCookies := getCSRFToken(t, router, "/tablos/"+tablo.ID.String()+"/files", []*http.Cookie{sessionCookie}) + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) var delBody strings.Builder - delBody.WriteString("gorilla.csrf.Token=" + csrfToken) + delBody.WriteString(url.Values{"_csrf": {csrfToken}}.Encode()) reqDel := httptest.NewRequest(http.MethodPost, base+"/delete", strings.NewReader(delBody.String())) reqDel.Header.Set("Content-Type", "application/x-www-form-urlencoded") reqDel.AddCookie(sessionCookie) diff --git a/backend/internal/web/handlers_tasks_test.go b/backend/internal/web/handlers_tasks_test.go index 1c80b30..bbc60c7 100644 --- a/backend/internal/web/handlers_tasks_test.go +++ b/backend/internal/web/handlers_tasks_test.go @@ -68,13 +68,13 @@ func TestTasksKanbanRenders(t *testing.T) { } sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieVal} - req := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String(), nil) + req := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String()+"/tasks", nil) req.AddCookie(sessionCookie) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf("GET /tablos/{id} status = %d; want 200", rec.Code) + t.Fatalf("GET /tablos/{id}/tasks status = %d; want 200", rec.Code) } body := rec.Body.String() @@ -92,12 +92,12 @@ func TestTasksKanbanRenders(t *testing.T) { nonOwner := preInsertUser(t, ctx, q, "kanban-other@example.com", "correct-horse-12") nonOwnerCookieVal, _, _ := store.Create(ctx, nonOwner.ID) nonOwnerCookie := &http.Cookie{Name: auth.SessionCookieName, Value: nonOwnerCookieVal} - req2 := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String(), nil) + req2 := httptest.NewRequest(http.MethodGet, "/tablos/"+tablo.ID.String()+"/tasks", nil) req2.AddCookie(nonOwnerCookie) rec2 := httptest.NewRecorder() router.ServeHTTP(rec2, req2) if rec2.Code != http.StatusNotFound { - t.Errorf("non-owner GET /tablos/{id}: status = %d; want 404", rec2.Code) + t.Errorf("non-owner GET /tablos/{id}/tasks: status = %d; want 404", rec2.Code) } } @@ -1236,6 +1236,7 @@ func TestTaskOwnership(t *testing.T) { t.Fatalf("store.Create for userB: %v", storeErr) } sessionCookie := &http.Cookie{Name: auth.SessionCookieName, Value: cookieVal} + csrfToken, csrfCookies := getCSRFToken(t, router, "/", []*http.Cookie{sessionCookie}) routes := []struct { method string @@ -1246,8 +1247,12 @@ func TestTaskOwnership(t *testing.T) { } for _, route := range routes { - req := httptest.NewRequest(route.method, route.path, nil) - req.AddCookie(sessionCookie) + form := url.Values{"_csrf": {csrfToken}} + req := httptest.NewRequest(route.method, route.path, strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range csrfCookies { + req.AddCookie(c) + } rec := httptest.NewRecorder() router.ServeHTTP(rec, req)