43 lines
1.2 KiB
Go
43 lines
1.2 KiB
Go
|
|
package websocket
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestWebsocketOriginAllowedDefaultsToSameHost(t *testing.T) {
|
||
|
|
t.Setenv("WEBSOCKET_ALLOWED_ORIGINS", "")
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "http://example.com/ws", nil)
|
||
|
|
req.Host = "example.com:8080"
|
||
|
|
req.Header.Set("Origin", "https://example.com")
|
||
|
|
|
||
|
|
if !websocketOriginAllowed(req) {
|
||
|
|
t.Fatal("expected same-host websocket origin to be allowed by default")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWebsocketOriginAllowedRejectsCrossOriginByDefault(t *testing.T) {
|
||
|
|
t.Setenv("WEBSOCKET_ALLOWED_ORIGINS", "")
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "http://example.com/ws", nil)
|
||
|
|
req.Host = "example.com:8080"
|
||
|
|
req.Header.Set("Origin", "https://attacker.example")
|
||
|
|
|
||
|
|
if websocketOriginAllowed(req) {
|
||
|
|
t.Fatal("expected cross-origin websocket request to be rejected by default")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWebsocketOriginAllowedHonorsExplicitAllowlist(t *testing.T) {
|
||
|
|
t.Setenv("WEBSOCKET_ALLOWED_ORIGINS", "https://app.example, https://ops.example")
|
||
|
|
|
||
|
|
req := httptest.NewRequest("GET", "http://example.com/ws", nil)
|
||
|
|
req.Host = "example.com:8080"
|
||
|
|
req.Header.Set("Origin", "https://ops.example")
|
||
|
|
|
||
|
|
if !websocketOriginAllowed(req) {
|
||
|
|
t.Fatal("expected allowlisted websocket origin to be accepted")
|
||
|
|
}
|
||
|
|
}
|