153 lines
5.1 KiB
Go
153 lines
5.1 KiB
Go
package router
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNewRouter(t *testing.T) {
|
|
r := NewRouter()
|
|
|
|
assert.NotNil(t, r, "Router should not be nil")
|
|
assert.NotNil(t, r.router, "Chi router should not be nil")
|
|
assert.Nil(t, r.logger, "Logger should be nil initially")
|
|
assert.Nil(t, r.rateLimiter, "RateLimiter should be nil initially")
|
|
}
|
|
|
|
func TestSetupRouter(t *testing.T) {
|
|
origins := []string{"http://localhost:3000"}
|
|
r := SetupRouter(origins)
|
|
|
|
assert.NotNil(t, r, "Router should not be nil")
|
|
assert.NotNil(t, r.router, "Chi router should not be nil")
|
|
assert.NotNil(t, r.logger, "Logger should not be nil")
|
|
assert.NotNil(t, r.rateLimiter, "RateLimiter should not be nil")
|
|
}
|
|
|
|
func TestGetRouter(t *testing.T) {
|
|
r := NewRouter()
|
|
chiRouter := r.GetRouter()
|
|
|
|
assert.Equal(t, r.router, chiRouter, "GetRouter should return the router field")
|
|
}
|
|
|
|
func TestCoffeeEndpoint(t *testing.T) {
|
|
origins := []string{"http://localhost:8080"}
|
|
r := SetupRouter(origins)
|
|
|
|
// Create a test server
|
|
ts := httptest.NewServer(r.GetRouter())
|
|
defer ts.Close()
|
|
|
|
// Make request to the coffee endpoint
|
|
resp, err := http.Get(ts.URL + "/v1/api/coffee")
|
|
require.NoError(t, err, "Error making request to coffee endpoint")
|
|
defer resp.Body.Close()
|
|
|
|
// Check status code
|
|
assert.Equal(t, http.StatusTeapot, resp.StatusCode, "Should return teapot status")
|
|
|
|
// Check response body
|
|
var responseBody map[string]string
|
|
err = json.NewDecoder(resp.Body).Decode(&responseBody)
|
|
require.NoError(t, err, "Error decoding response body")
|
|
|
|
assert.Equal(t, "I'm A Teapot!", responseBody["error"], "Response should contain teapot message")
|
|
}
|
|
|
|
func TestCORSMiddleware(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
origins []string
|
|
requestOrigin string
|
|
expectedHeader string
|
|
}{
|
|
{
|
|
name: "Allowed origin",
|
|
origins: []string{"http://allowed-origin.com"},
|
|
requestOrigin: "http://allowed-origin.com",
|
|
expectedHeader: "http://allowed-origin.com",
|
|
},
|
|
{
|
|
name: "Multiple allowed origins",
|
|
origins: []string{"http://origin1.com", "http://origin2.com"},
|
|
requestOrigin: "http://origin2.com",
|
|
expectedHeader: "http://origin2.com",
|
|
},
|
|
{
|
|
name: "Wildcard origin",
|
|
origins: []string{"*"},
|
|
requestOrigin: "http://any-origin.com",
|
|
expectedHeader: "*",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
r := SetupRouter(tc.origins)
|
|
|
|
req := httptest.NewRequest("OPTIONS", "/v1/api/coffee", nil)
|
|
req.Header.Set("Origin", tc.requestOrigin)
|
|
req.Header.Set("Access-Control-Request-Method", "GET")
|
|
|
|
rr := httptest.NewRecorder()
|
|
r.GetRouter().ServeHTTP(rr, req)
|
|
|
|
// For wildcard origin or matching origin, CORS headers should be set
|
|
if tc.origins[0] == "*" || contains(tc.origins, tc.requestOrigin) {
|
|
assert.Equal(t, tc.expectedHeader, rr.Header().Get("Access-Control-Allow-Origin"),
|
|
"CORS origin header should match expected value")
|
|
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "GET",
|
|
"CORS methods should contain GET")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRouterMiddlewares(t *testing.T) {
|
|
r := SetupRouter([]string{"http://localhost:3000"})
|
|
|
|
// Check if the router has middleware registered (indirect test)
|
|
// We can't easily test the actual middleware without mocking, but we can verify
|
|
// the router itself has routes registered
|
|
|
|
chiRouter := r.GetRouter()
|
|
routes := getAllRoutes(chiRouter)
|
|
|
|
assert.Contains(t, routes, "/v1/api/coffee", "Coffee endpoint should be registered")
|
|
assert.Contains(t, routes, "/v1/api/text", "Text endpoint should be registered")
|
|
assert.Contains(t, routes, "/v1/api/visual", "Visual endpoint should be registered")
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// contains reports whether item is present in slice, comparing with exact
// (case-sensitive) string equality. A nil or empty slice never contains
// anything.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|
|
|
|
// getAllRoutes is a helper to extract all routes from a chi router
|
|
func getAllRoutes(router *chi.Mux) []string {
|
|
var routes []string
|
|
|
|
// This is a simplified way to check for routes
|
|
// Chi doesn't expose routes directly, so this is a proxy check
|
|
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
|
|
routes = append(routes, route)
|
|
return nil
|
|
}
|
|
|
|
_ = chi.Walk(router, walkFunc)
|
|
return routes
|
|
} |