1package http
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "crypto/sha1"
8 "crypto/tls"
9 "encoding/base64"
10 "errors"
11 "fmt"
12 htmltemplate "html/template"
13 "io"
14 "io/fs"
15 golog "log"
16 "log/slog"
17 "net"
18 "net/http"
19 "net/http/httputil"
20 "net/textproto"
21 "net/url"
22 "os"
23 "path/filepath"
24 "sort"
25 "strings"
26 "syscall"
27 "time"
28
29 "github.com/mjl-/mox/config"
30 "github.com/mjl-/mox/dns"
31 "github.com/mjl-/mox/mlog"
32 "github.com/mjl-/mox/mox-"
33)
34
35func recvid(r *http.Request) string {
36 cid := mox.CidFromCtx(r.Context())
37 if cid <= 0 {
38 return ""
39 }
40 return " (id " + mox.ReceivedID(cid) + ")"
41}
42
43// WebHandle serves an HTTP request by going through the list of WebHandlers,
44// check if there is a domain+path match, and running the handler if so.
45// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
46// If no handler matched, false is returned.
47// WebHandle sets w.Name to that of the matching handler.
48func WebHandle(w *loggingWriter, r *http.Request, host dns.IPDomain) (handled bool) {
49 conf := mox.Conf.DynamicConfig()
50 redirects := conf.WebDNSDomainRedirects
51 handlers := conf.WebHandlers
52
53 for from, to := range redirects {
54 if host.Domain != from {
55 continue
56 }
57 u := r.URL
58 u.Scheme = "https"
59 u.Host = to.Name()
60 w.Handler = "(domainredirect)"
61 http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
62 return true
63 }
64
65 for _, h := range handlers {
66 if host.Domain != h.DNSDomain {
67 continue
68 }
69 loc := h.Path.FindStringIndex(r.URL.Path)
70 if loc == nil {
71 continue
72 }
73 s := loc[0]
74 e := loc[1]
75 path := r.URL.Path[s:e]
76
77 if r.TLS == nil && !h.DontRedirectPlainHTTP {
78 u := *r.URL
79 u.Scheme = "https"
80 u.Host = h.DNSDomain.Name()
81 w.Handler = h.Name
82 w.Compress = h.Compress
83 http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
84 return true
85 }
86
87 // We don't want the loggingWriter to override the static handler's decisions to compress.
88 w.Compress = h.Compress
89 if h.WebStatic != nil && HandleStatic(h.WebStatic, h.Compress, w, r) {
90 w.Handler = h.Name
91 return true
92 }
93 if h.WebRedirect != nil && HandleRedirect(h.WebRedirect, w, r) {
94 w.Handler = h.Name
95 return true
96 }
97 if h.WebForward != nil && HandleForward(h.WebForward, w, r, path) {
98 w.Handler = h.Name
99 return true
100 }
101 if h.WebInternal != nil && HandleInternal(h.WebInternal, w, r) {
102 w.Handler = h.Name
103 return true
104 }
105 }
106 w.Compress = false
107 return false
108}
109
// lsTemplate renders the HTML directory listing that HandleStatic serves for
// directories without index.html when ListFiles is enabled. It is executed with
// a map whose "Files" key holds the entries; each entry must have the fields
// Name, Size, SizeReadable, SizePad and Modified. Because this is an
// html/template, values such as file names are escaped for their HTML context.
var lsTemplate = htmltemplate.Must(htmltemplate.New("ls").Parse(`<!doctype html>
<html>
	<head>
		<meta charset="utf-8" />
		<meta name="viewport" content="width=device-width, initial-scale=1" />
		<title>ls</title>
		<style>
body, html { padding: 1em; font-size: 16px; }
* { font-size: inherit; font-family: ubuntu, lato, sans-serif; margin: 0; padding: 0; box-sizing: border-box; }
h1 { margin-bottom: 1ex; font-size: 1.2rem; }
table td, table th { padding: .2em .5em; }
table > tbody > tr:nth-child(odd) { background-color: #f8f8f8; }
[title] { text-decoration: underline; text-decoration-style: dotted; }
		</style>
	</head>
	<body>
		<h1>ls</h1>
		<table>
			<thead>
				<tr>
					<th>Size in MB</th>
					<th>Modified (UTC)</th>
					<th>Name</th>
				</tr>
			</thead>
			<tbody>
			{{ if not .Files }}
				<tr><td colspan="3">No files.</td></tr>
			{{ end }}
			{{ range .Files }}
				<tr>
					<td title="{{ .Size }} bytes" style="text-align: right">{{ .SizeReadable }}{{ if .SizePad }}<span style="visibility:hidden">.  </span>{{ end }}</td>
					<td>{{ .Modified }}</td>
					<td><a style="display: block" href="{{ .Name }}">{{ .Name }}</a></td>
				</tr>
			{{ end }}
			</tbody>
		</table>
	</body>
</html>
`))
151
152// HandleStatic serves static files. If a directory is requested and the URL
153// path doesn't end with a slash, a response with a redirect to the URL path with trailing
154// slash is written. If a directory is requested and an index.html exists, that
155// file is returned. Otherwise, for directories with ListFiles configured, a
156// directory listing is returned.
157func HandleStatic(h *config.WebStatic, compress bool, w http.ResponseWriter, r *http.Request) (handled bool) {
158 log := func() mlog.Log {
159 return pkglog.WithContext(r.Context())
160 }
161 if r.Method != "GET" && r.Method != "HEAD" {
162 if h.ContinueNotFound {
163 // Give another handler that is presumbly configured, for the same path, a chance.
164 // E.g. an app that may generate this file for future requests to pick up.
165 return false
166 }
167 http.Error(w, "405 - method not allowed", http.StatusMethodNotAllowed)
168 return true
169 }
170
171 var fspath string
172 if h.StripPrefix != "" {
173 if !strings.HasPrefix(r.URL.Path, h.StripPrefix) {
174 if h.ContinueNotFound {
175 // We haven't handled this request, try a next WebHandler in the list.
176 return false
177 }
178 http.NotFound(w, r)
179 return true
180 }
181 fspath = filepath.Join(h.Root, strings.TrimPrefix(r.URL.Path, h.StripPrefix))
182 } else {
183 fspath = filepath.Join(h.Root, r.URL.Path)
184 }
185 // fspath will not have a trailing slash anymore, we'll correct for it
186 // later when the path turns out to be file instead of a directory.
187
188 serveFile := func(name string, fi fs.FileInfo, content *os.File) {
189 // ServeContent only sets a content-type if not already present in the response headers.
190 hdr := w.Header()
191 for k, v := range h.ResponseHeaders {
192 hdr.Add(k, v)
193 }
194 // We transparently compress here, but still use ServeContent, because it handles
195 // conditional requests, range requests. It's a bit of a hack, but on first write
196 // to staticgzcacheReplacer where we are compressing, we write the full compressed
197 // file instead, and return an error to ServeContent so it stops. We still have all
198 // the useful behaviour (status code and headers) from ServeContent.
199 xw := w
200 if compress && acceptsGzip(r) && compressibleContent(content) {
201 xw = &staticgzcacheReplacer{w, r, content.Name(), content, fi.ModTime(), fi.Size(), 0, false}
202 } else {
203 w.(*loggingWriter).Compress = false
204 }
205 http.ServeContent(xw, r, name, fi.ModTime(), content)
206 }
207
208 f, err := os.Open(fspath)
209 if err != nil {
210 if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) {
211 if h.ContinueNotFound {
212 // We haven't handled this request, try a next WebHandler in the list.
213 return false
214 }
215 http.NotFound(w, r)
216 return true
217 } else if errors.Is(err, syscall.ENAMETOOLONG) {
218 http.NotFound(w, r)
219 return true
220 } else if os.IsPermission(err) {
221 // If we tried opening a directory, we may not have permission to read it, but
222 // still access files inside it (execute bit), such as index.html. So try to serve it.
223 index, err := os.Open(filepath.Join(fspath, "index.html"))
224 if err != nil {
225 http.Error(w, "403 - permission denied", http.StatusForbidden)
226 return true
227 }
228 defer func() {
229 err := index.Close()
230 log().Check(err, "closing index file for serving")
231 }()
232 var ifi os.FileInfo
233 ifi, err = index.Stat()
234 if err != nil {
235 log().Errorx("stat index.html in directory we cannot list", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
236 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
237 return true
238 }
239 w.Header().Set("Content-Type", "text/html; charset=utf-8")
240 serveFile("index.html", ifi, index)
241 return true
242 }
243 log().Errorx("open file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
244 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
245 return true
246 }
247 defer func() {
248 if err := f.Close(); err != nil {
249 log().Check(err, "closing file for static file serving")
250 }
251 }()
252
253 fi, err := f.Stat()
254 if err != nil {
255 log().Errorx("stat file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
256 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
257 return true
258 }
259 // Redirect if the local path is a directory.
260 if fi.IsDir() && !strings.HasSuffix(r.URL.Path, "/") {
261 http.Redirect(w, r, r.URL.Path+"/", http.StatusTemporaryRedirect)
262 return true
263 } else if !fi.IsDir() && strings.HasSuffix(r.URL.Path, "/") {
264 if h.ContinueNotFound {
265 return false
266 }
267 http.NotFound(w, r)
268 return true
269 }
270
271 if fi.IsDir() {
272 index, err := os.Open(filepath.Join(fspath, "index.html"))
273 if err != nil && os.IsPermission(err) {
274 http.Error(w, "403 - permission denied", http.StatusForbidden)
275 return true
276 } else if err != nil && os.IsNotExist(err) && !h.ListFiles {
277 if h.ContinueNotFound {
278 return false
279 }
280 http.Error(w, "403 - permission denied", http.StatusForbidden)
281 return true
282 } else if err == nil {
283 defer func() {
284 if err := index.Close(); err != nil {
285 log().Check(err, "closing index file for serving")
286 }
287 }()
288
289 var ifi os.FileInfo
290 ifi, err = index.Stat()
291 if err == nil {
292 w.Header().Set("Content-Type", "text/html; charset=utf-8")
293 serveFile("index.html", ifi, index)
294 return true
295 }
296 }
297 if !os.IsNotExist(err) {
298 log().Errorx("stat for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
299 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
300 return true
301 }
302
303 type File struct {
304 Name string
305 Size int64
306 SizeReadable string
307 SizePad bool // Whether the size needs padding because it has no decimal point.
308 Modified string
309 }
310 files := []File{}
311 if r.URL.Path != "/" {
312 files = append(files, File{"..", 0, "", false, ""})
313 }
314 for {
315 l, err := f.Readdir(1000)
316 for _, e := range l {
317 mb := float64(e.Size()) / (1024 * 1024)
318 var size string
319 var sizepad bool
320 if !e.IsDir() {
321 if mb >= 10 {
322 size = fmt.Sprintf("%d", int64(mb))
323 sizepad = true
324 } else {
325 size = fmt.Sprintf("%.2f", mb)
326 }
327 }
328 const dateTime = "2006-01-02 15:04:05" // time.DateTime, but only since go1.20.
329 modified := e.ModTime().UTC().Format(dateTime)
330 f := File{e.Name(), e.Size(), size, sizepad, modified}
331 if e.IsDir() {
332 f.Name += "/"
333 }
334 files = append(files, f)
335 }
336 if err == io.EOF {
337 break
338 } else if err != nil {
339 log().Errorx("reading directory for file listing", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
340 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
341 return true
342 }
343 }
344 sort.Slice(files, func(i, j int) bool {
345 return files[i].Name < files[j].Name
346 })
347 hdr := w.Header()
348 hdr.Set("Content-Type", "text/html; charset=utf-8")
349 for k, v := range h.ResponseHeaders {
350 if !strings.EqualFold(k, "content-type") {
351 hdr.Add(k, v)
352 }
353 }
354 err = lsTemplate.Execute(w, map[string]any{"Files": files})
355 if err != nil {
356 log().Check(err, "executing directory listing template")
357 }
358 return true
359 }
360
361 serveFile(fspath, fi, f)
362 return true
363}
364
365// HandleRedirect writes a response with an HTTP redirect.
366func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Request) (handled bool) {
367 var dstpath string
368 if h.OrigPath == nil {
369 // No path rewrite necessary.
370 dstpath = r.URL.Path
371 } else if !h.OrigPath.MatchString(r.URL.Path) {
372 http.NotFound(w, r)
373 return true
374 } else {
375 dstpath = h.OrigPath.ReplaceAllString(r.URL.Path, h.ReplacePath)
376 }
377
378 u := *r.URL
379 u.Opaque = ""
380 u.RawPath = ""
381 u.OmitHost = false
382 if h.URL != nil {
383 u.Scheme = h.URL.Scheme
384 u.Host = h.URL.Host
385 u.ForceQuery = h.URL.ForceQuery
386 u.RawQuery = h.URL.RawQuery
387 u.Fragment = h.URL.Fragment
388 if r.URL.RawQuery != "" {
389 if u.RawQuery != "" {
390 u.RawQuery += "&"
391 }
392 u.RawQuery += r.URL.RawQuery
393 }
394 }
395 u.Path = dstpath
396 code := http.StatusPermanentRedirect
397 if h.StatusCode != 0 {
398 code = h.StatusCode
399 }
400
401 // If we would be redirecting to the same scheme,host,path, we would get here again
402 // causing a redirect loop. Instead, this causes this redirect to not match,
403 // allowing to try the next WebHandler. This can be used to redirect all plain http
404 // requests to https.
405 reqscheme := "http"
406 if r.TLS != nil {
407 reqscheme = "https"
408 }
409 if reqscheme == u.Scheme && r.Host == u.Host && r.URL.Path == u.Path {
410 return false
411 }
412
413 http.Redirect(w, r, u.String(), code)
414 return true
415}
416
417// HandleInternal passes the request to an internal service.
418func HandleInternal(h *config.WebInternal, w http.ResponseWriter, r *http.Request) (handled bool) {
419 h.Handler.ServeHTTP(w, r)
420 return true
421}
422
423// HandleForward handles a request by forwarding it to another webserver and
424// passing the response on. I.e. a reverse proxy. It handles websocket
425// connections by monitoring the websocket handshake and then just passing along the
426// websocket frames.
427func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
428 log := func() mlog.Log {
429 return pkglog.WithContext(r.Context())
430 }
431
432 xr := *r
433 r = &xr
434 if h.StripPath {
435 u := *r.URL
436 u.Path = r.URL.Path[len(path):]
437 if !strings.HasPrefix(u.Path, "/") {
438 u.Path = "/" + u.Path
439 }
440 u.RawPath = ""
441 r.URL = &u
442 }
443
444 // Remove any forwarded headers passed in by client.
445 hdr := http.Header{}
446 for k, vl := range r.Header {
447 if k == "Forwarded" || k == "X-Forwarded" || strings.HasPrefix(k, "X-Forwarded-") {
448 continue
449 }
450 hdr[k] = vl
451 }
452 r.Header = hdr
453
454 // Add our own X-Forwarded headers. ReverseProxy will add X-Forwarded-For.
455 r.Header["X-Forwarded-Host"] = []string{r.Host}
456 proto := "http"
457 if r.TLS != nil {
458 proto = "https"
459 }
460 r.Header["X-Forwarded-Proto"] = []string{proto}
461 // note: We are not using "ws" or "wss" for websocket. The request we are
462 // forwarding is http(s), and we don't yet know if the backend even supports
463 // websockets.
464
465 // todo: add Forwarded header? is anyone using it?
466
467 // If we see an Upgrade: websocket, we're going to assume the client needs
468 // websocket and only attempt to talk websocket with the backend. If the backend
469 // doesn't do websocket, we'll send back a "bad request" response. For other values
470 // of Upgrade, we don't do anything special.
471 // https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
472 // Upgrade: ../rfc/9110:2798
473 // Upgrade headers are not for http/1.0, ../rfc/9110:2880
474 // Websocket client "handshake" is described at ../rfc/6455:1134
475 upgrade := r.Header.Get("Upgrade")
476 if upgrade != "" && !(r.ProtoMajor == 1 && r.ProtoMinor == 0) {
477 // Websockets have case-insensitive string "websocket".
478 for _, s := range strings.Split(upgrade, ",") {
479 if strings.EqualFold(textproto.TrimString(s), "websocket") {
480 forwardWebsocket(h, w, r, path)
481 return true
482 }
483 }
484 }
485
486 // ReverseProxy will append any remaining path to the configured target URL.
487 proxy := httputil.NewSingleHostReverseProxy(h.TargetURL)
488 proxy.FlushInterval = time.Duration(-1) // Flush after each write.
489 proxy.ErrorLog = golog.New(mlog.LogWriter(mlog.New("net/http/httputil", nil).WithContext(r.Context()), mlog.LevelDebug, "reverseproxy error"), "", 0)
490 proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
491 if errors.Is(err, context.Canceled) {
492 log().Debugx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
493 return
494 }
495 log().Errorx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
496 if os.IsTimeout(err) {
497 http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
498 } else {
499 http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
500 }
501 }
502 whdr := w.Header()
503 for k, v := range h.ResponseHeaders {
504 whdr.Add(k, v)
505 }
506 proxy.ServeHTTP(w, r)
507 return true
508}
509
var (
	// errResponseNotWebsocket is wrapped by errors for a backend response that is
	// not a valid answer to a websocket handshake request.
	errResponseNotWebsocket = errors.New("not a valid websocket response to request")

	// errNotImplemented is wrapped by errors for requests that need functionality
	// we don't have yet, e.g. a proxy for a websocket backend connection.
	errNotImplemented = errors.New("functionality not yet implemented")
)
512
513// Request has an Upgrade: websocket header. Check more websocketiness about the
514// request. If it looks good, we forward it to the backend. If the backend responds
515// with a valid websocket response, indicating it is indeed a websocket server, we
516// pass the response along and start copying data between the client and the
517// backend. We don't look at the frames and payloads. The backend already needs to
518// know enough websocket to handle the frames. It wouldn't necessarily hurt to
519// monitor the frames too, and check if they are valid, but it's quite a bit of
520// work for little benefit. Besides, the whole point of websockets is to exchange
521// bytes without HTTP being in the way, so let's do that.
522func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
523 log := func() mlog.Log {
524 return pkglog.WithContext(r.Context())
525 }
526
527 lw := w.(*loggingWriter)
528 lw.WebsocketRequest = true // For correct protocol in metrics.
529
530 // We check the requested websocket version first. A future websocket version may
531 // have different request requirements.
532 // ../rfc/6455:1160
533 wsversion := r.Header.Get("Sec-WebSocket-Version")
534 if wsversion != "13" {
535 // Indicate we only support version 13. Should get a client from the future to fall back to version 13.
536 // ../rfc/6455:1435
537 w.Header().Set("Sec-WebSocket-Version", "13")
538 http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
539 lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
540 return true
541 }
542
543 // ../rfc/6455:1143
544 if r.Method != "GET" {
545 http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
546 lw.error(fmt.Errorf("websocket request only allowed with method GET"))
547 return true
548 }
549
550 // ../rfc/6455:1153
551 var connectionUpgrade bool
552 for _, s := range strings.Split(r.Header.Get("Connection"), ",") {
553 if strings.EqualFold(textproto.TrimString(s), "upgrade") {
554 connectionUpgrade = true
555 break
556 }
557 }
558 if !connectionUpgrade {
559 http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
560 lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
561 return true
562 }
563
564 // ../rfc/6455:1156
565 wskey := r.Header.Get("Sec-WebSocket-Key")
566 key, err := base64.StdEncoding.DecodeString(wskey)
567 if err != nil || len(key) != 16 {
568 http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
569 lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
570 return true
571 }
572
573 // ../rfc/6455:1162
574 // We don't look at the origin header. The backend needs to handle it, if it thinks
575 // that helps...
576 // We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
577 // backend can set them, but it doesn't influence our forwarding of the data.
578
579 // If this is not a hijacker, there is not point in connecting to the backend.
580 hj, ok := lw.W.(http.Hijacker)
581 var cbr *bufio.ReadWriter
582 if !ok {
583 log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
584 http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
585 lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
586 return
587 }
588
589 freq := *r
590 freq.Proto = "HTTP/1.1"
591 freq.ProtoMajor = 1
592 freq.ProtoMinor = 1
593 fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
594 if err != nil {
595 if errors.Is(err, errResponseNotWebsocket) {
596 http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
597 } else if errors.Is(err, errNotImplemented) {
598 http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
599 } else if os.IsTimeout(err) {
600 http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
601 } else {
602 http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
603 }
604 lw.error(err)
605 return
606 }
607 defer func() {
608 if beconn != nil {
609 if err := beconn.Close(); err != nil {
610 log().Check(err, "closing backend websocket connection")
611 }
612 }
613 }()
614
615 // Hijack the client connection so we can write the response ourselves, and start
616 // copying the websocket frames.
617 var cconn net.Conn
618 cconn, cbr, err = hj.Hijack()
619 if err != nil {
620 log().Debugx("cannot turn http transaction into websocket connection", err)
621 http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
622 lw.error(err)
623 return
624 }
625 defer func() {
626 if cconn != nil {
627 if err := cconn.Close(); err != nil {
628 log().Check(err, "closing client websocket connection")
629 }
630 }
631 }()
632
633 // Below this point, we can no longer write to the ResponseWriter.
634
635 // Mark as websocket response, for logging.
636 lw.WebsocketResponse = true
637 lw.setStatusCode(fresp.StatusCode)
638
639 for k, v := range h.ResponseHeaders {
640 fresp.Header.Add(k, v)
641 }
642
643 // Write the response to the client, completing its websocket handshake.
644 if err := fresp.Write(cconn); err != nil {
645 lw.error(fmt.Errorf("writing websocket response to client: %w", err))
646 return
647 }
648
649 errc := make(chan error, 1)
650
651 // Copy from client to backend.
652 go func() {
653 buf, err := cbr.Peek(cbr.Reader.Buffered())
654 if err != nil {
655 errc <- err
656 return
657 }
658 if len(buf) > 0 {
659 n, err := beconn.Write(buf)
660 if err != nil {
661 errc <- err
662 return
663 }
664 lw.SizeFromClient += int64(n)
665 }
666 n, err := io.Copy(beconn, cconn)
667 lw.SizeFromClient += n
668 errc <- err
669 }()
670
671 // Copy from backend to client.
672 go func() {
673 n, err := io.Copy(cconn, beconn)
674 lw.SizeToClient = n
675 errc <- err
676 }()
677
678 // Stop and close connection on first error from either size, typically a closed
679 // connection whose closing was already announced with a websocket frame.
680 lw.error(<-errc)
681 // Close connections so other goroutine stops as well.
682 if err := cconn.Close(); err != nil {
683 log().Check(err, "closing client websocket connection")
684 }
685 if err := beconn.Close(); err != nil {
686 log().Check(err, "closing backend websocket connection")
687 }
688 // Wait for goroutine so it has updated the logWriter.Size*Client fields before we
689 // continue with logging.
690 <-errc
691 cconn = nil
692 return true
693}
694
695func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
696 log := func() mlog.Log {
697 return pkglog.WithContext(r.Context())
698 }
699
700 // Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
701 // unmodified.
702 transport := http.DefaultTransport.(*http.Transport)
703
704 // We haven't implemented using a proxy for websocket requests yet. If we need one,
705 // return an error instead of trying to connect directly, which would be a
706 // potential security issue.
707 treq := *r
708 treq.URL = targetURL
709 if purl, err := transport.Proxy(&treq); err != nil {
710 return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
711 } else if purl != nil {
712 return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
713 }
714
715 host, port, err := net.SplitHostPort(targetURL.Host)
716 if err != nil {
717 host = targetURL.Host
718 if targetURL.Scheme == "https" {
719 port = "443"
720 } else {
721 port = "80"
722 }
723 }
724 addr := net.JoinHostPort(host, port)
725 conn, err := transport.DialContext(r.Context(), "tcp", addr)
726 if err != nil {
727 return nil, nil, fmt.Errorf("dial: %w", err)
728 }
729 if targetURL.Scheme == "https" {
730 tlsconn := tls.Client(conn, transport.TLSClientConfig)
731 ctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
732 defer cancel()
733 if err := tlsconn.HandshakeContext(ctx); err != nil {
734 return nil, nil, fmt.Errorf("tls handshake: %w", err)
735 }
736 conn = tlsconn
737 }
738 defer func() {
739 if rerr != nil {
740 if xerr := conn.Close(); xerr != nil {
741 log().Check(xerr, "cleaning up websocket connection")
742 }
743 }
744 }()
745
746 // todo: make timeout configurable?
747 if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
748 log().Check(err, "set deadline for websocket request to backend")
749 }
750
751 // Set clean connection headers.
752 removeHopByHopHeaders(r.Header)
753 r.Header.Set("Connection", "Upgrade")
754 r.Header.Set("Upgrade", "websocket")
755
756 // Write the websocket request to the backend.
757 if err := r.Write(conn); err != nil {
758 return nil, nil, fmt.Errorf("writing request to backend: %w", err)
759 }
760
761 // Read response from backend.
762 br := bufio.NewReader(conn)
763 resp, err := http.ReadResponse(br, r)
764 if err != nil {
765 return nil, nil, fmt.Errorf("reading response from backend: %w", err)
766 }
767 defer func() {
768 if rerr != nil {
769 if xerr := resp.Body.Close(); xerr != nil {
770 log().Check(xerr, "closing response body after error")
771 }
772 }
773 }()
774 if err := conn.SetDeadline(time.Time{}); err != nil {
775 log().Check(err, "clearing deadline on websocket connection to backend")
776 }
777
778 // Check that the response from the backend server indicates it is websocket. If
779 // not, don't pass the backend response, but an error that websocket is not
780 // appropriate.
781 if err := checkWebsocketResponse(resp, r); err != nil {
782 return resp, nil, err
783 }
784
785 // note: net/http.Response.Body documents that it implements io.Writer for a
786 // status: 101 response. But that's not the case when the response has been read
787 // with http.ReadResponse. We'll write to the connection directly.
788
789 buf, err := br.Peek(br.Buffered())
790 if err != nil {
791 return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
792 }
793 return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
794}
795
// websocketConn is a net.Conn whose reads are served by r instead of directly by
// the embedded connection. Other methods go to the embedded net.Conn.
// websocketTransact uses it so that bytes already buffered by the bufio.Reader
// used for http.ReadResponse are read before new data from the connection.
type websocketConn struct {
	r io.Reader
	net.Conn
}

// Read overrides the embedded net.Conn's Read and reads from the wrapped reader.
func (wc websocketConn) Read(p []byte) (int, error) {
	n, err := wc.r.Read(p)
	return n, err
}
806
807// Check that an HTTP response (from a backend) is a valid websocket response, i.e.
808// that it accepts the WebSocket "upgrade".
809// ../rfc/6455:1299
810func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
811 if resp.StatusCode != 101 {
812 return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
813 }
814 if upgrade := resp.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") {
815 return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
816 }
817 if connection := resp.Header.Get("Connection"); !strings.EqualFold(connection, "upgrade") {
818 return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
819 }
820 accept, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
821 if err != nil {
822 return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
823 }
824 exp := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
825 if !bytes.Equal(accept, exp[:]) {
826 return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
827 }
828 // We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
829 return nil
830}
831
// hopHeaders lists the hop-by-hop headers that are always removed from requests
// sent to a backend. Copied from Go 1.20.4 src/net/http/httputil/reverseproxy.go.
// As of RFC 7230, hop-by-hop headers are required to be listed in the Connection
// header. The names below are the ones defined by the obsoleted RFC 2616 (section
// 13.5.1), and are removed for backward compatibility.
// ../rfc/2616:5128
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // Non-standard, but still sent by libcurl and rejected by e.g. google.
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // Canonicalized form of "TE".
	"Trailer", // Not "Trailers", see https://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding",
	"Upgrade",
}

// removeHopByHopHeaders deletes hop-by-hop headers from h in place. First it
// removes every header named in a Connection header (RFC 7230 section 6.1,
// ../rfc/7230:2817), then the fixed RFC 2616 list in hopHeaders
// (../rfc/2616:5128). Adapted from Go 1.20.4 src/net/http/httputil/reverseproxy.go.
func removeHopByHopHeaders(h http.Header) {
	for _, line := range h["Connection"] {
		for _, name := range strings.Split(line, ",") {
			name = textproto.TrimString(name)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}
	for _, name := range hopHeaders {
		h.Del(name)
	}
}
871