A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://coder.github.io/websocket/coverage.html below:

websocket: Go Coverage Report

//go:build !js

package websocket

import (
        "bytes"
        "context"
        "crypto/sha1"
        "encoding/base64"
        "errors"
        "fmt"
        "io"
        "log"
        "net/http"
        "net/textproto"
        "net/url"
        "path"
        "strings"

        "github.com/coder/websocket/internal/errd"
)

// AcceptOptions represents Accept's options.
//
// A nil *AcceptOptions passed to Accept behaves like a zero AcceptOptions.
type AcceptOptions struct {
	// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
	// The order of this slice is the server's order of preference: the first entry that the
	// client also offered (compared case-insensitively) is selected, and the client's spelling
	// of it is echoed back.
	// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
	// reject it, close the connection when c.Subprotocol() == "".
	Subprotocols []string

	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
	// When true, the Origin header is not checked at all.
	//
	// You probably want to use OriginPatterns instead.
	InsecureSkipVerify bool

	// OriginPatterns lists the host patterns for authorized origins.
	// The request host is always authorized (compared case-insensitively with the Origin host).
	// Use this to enable cross origin WebSockets.
	//
	// For example, JavaScript running on example.com wants to access a WebSocket server at chat.example.com.
	// In such a case, example.com is the origin and chat.example.com is the request host.
	// One would set this field to []string{"example.com"} to authorize example.com to connect.
	//
	// Each pattern is matched case insensitively against the request origin host
	// with path.Match.
	// See https://golang.org/pkg/path/#Match
	//
	// A malformed pattern is logged and the request is rejected with 403 Forbidden.
	//
	// Please ensure you understand the ramifications of enabling this.
	// If used incorrectly your WebSocket server will be open to CSRF attacks.
	//
	// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
	// to bring attention to the danger of such a setting.
	OriginPatterns []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionDisabled.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int

	// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
	//
	// The payload contains the application data of the ping frame.
	// If the callback returns false, the subsequent pong frame will not be sent.
	// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
	OnPingReceived func(ctx context.Context, payload []byte) bool

	// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
	//
	// The payload contains the application data of the pong frame.
	// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
	//
	// Unlike OnPingReceived, this callback does not return a value because a pong frame
	// is a response to a ping and does not trigger any further frame transmission.
	OnPongReceived func(ctx context.Context, payload []byte)
}

// cloneWithDefaults returns a shallow copy of opts, or a zero AcceptOptions
// when opts is nil, so callers can read fields without a nil check.
// Slice and func fields are shared with opts rather than deep-copied, and no
// default values are filled in here.
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
	if opts == nil {
		return &AcceptOptions{}
	}
	clone := *opts
	return &clone
}

// Accept accepts a WebSocket handshake from a client and upgrades the
// connection to a WebSocket.
//
// opts may be nil, in which case the zero AcceptOptions are used.
//
// Accept will not allow cross origin requests by default.
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// Accept will write a response to w on all errors.
//
// Note that using the http.Request Context after Accept returns may lead to
// unexpected behavior (see http.Hijacker).
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
	conn, err := accept(w, r, opts)
	return conn, err
}

// accept implements Accept.
//
// It validates the handshake request and checks the Origin header, unless
// opts.InsecureSkipVerify is set. It then writes the 101 Switching Protocols
// response headers, hijacks the underlying connection and wraps it in a *Conn.
//
// On every failure it attempts to write an HTTP error response to w. The
// deferred errd.Wrap adds "failed to accept WebSocket connection" context to
// any returned error.
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
	defer errd.Wrap(&err, "failed to accept WebSocket connection")

	errCode, err := verifyClientRequest(w, r)
	if err != nil {
		http.Error(w, err.Error(), errCode)
		return nil, err
	}

	opts = opts.cloneWithDefaults()
	if !opts.InsecureSkipVerify {
		err = authenticateOrigin(r, opts.OriginPatterns)
		if err != nil {
			// A malformed OriginPatterns entry is a server misconfiguration:
			// log the details and give the client only a generic 403 message.
			if errors.Is(err, path.ErrBadPattern) {
				log.Printf("websocket: %v", err)
				err = errors.New(http.StatusText(http.StatusForbidden))
			}
			http.Error(w, err.Error(), http.StatusForbidden)
			return nil, err
		}
	}

	// NOTE(review): hijacker is defined elsewhere in the package; presumably it
	// locates an http.Hijacker behind w (possibly by unwrapping) — confirm.
	hj, ok := hijacker(w)
	if !ok {
		err = errors.New("http.ResponseWriter does not implement http.Hijacker")
		http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
		return nil, err
	}

	w.Header().Set("Upgrade", "websocket")
	w.Header().Set("Connection", "Upgrade")

	// verifyClientRequest validated the whitespace-trimmed key, so derive the
	// accept value from that same trimmed key. Hashing the raw value would
	// produce a wrong Sec-WebSocket-Accept for keys with surrounding whitespace.
	key := strings.TrimSpace(r.Header.Get("Sec-WebSocket-Key"))
	w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))

	subproto := selectSubprotocol(r, opts.Subprotocols)
	if subproto != "" {
		w.Header().Set("Sec-WebSocket-Protocol", subproto)
	}

	copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
	if ok {
		w.Header().Set("Sec-WebSocket-Extensions", copts.String())
	}

	w.WriteHeader(http.StatusSwitchingProtocols)
	// Some wrapped writers (the variable name suggests gin's) only flush the
	// status line when WriteHeaderNow is called.
	// See https://github.com/nhooyr/websocket/issues/166
	if ginWriter, ok := w.(interface {
		WriteHeaderNow()
	}); ok {
		ginWriter.WriteHeaderNow()
	}

	netConn, brw, err := hj.Hijack()
	if err != nil {
		err = fmt.Errorf("failed to hijack connection: %w", err)
		// NOTE(review): the 101 status has already been written above, so this
		// error response most likely cannot change what the client sees.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return nil, err
	}

	// Keep any bytes the HTTP server already buffered from the client and read
	// the rest straight from netConn. Peeking exactly Buffered() bytes cannot
	// fail, so the error is ignored.
	// https://github.com/golang/go/issues/32314
	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))

	return newConn(connConfig{
		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
		rwc:            netConn,
		client:         false,
		copts:          copts,
		flateThreshold: opts.CompressionThreshold,
		onPingReceived: opts.OnPingReceived,
		onPongReceived: opts.OnPongReceived,

		br: brw.Reader,
		bw: brw.Writer,
	}), nil
}

// verifyClientRequest checks that r is a valid RFC 6455 opening handshake.
//
// On failure it returns the HTTP status code the caller should reply with and
// a descriptive error. For some failures it also sets response headers on w
// that tell the client what the server expects (Connection/Upgrade or
// Sec-WebSocket-Version). On success it returns 0 and nil.
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
	// advertiseUpgrade tells the client which protocol it has to upgrade to.
	advertiseUpgrade := func() {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", "websocket")
	}

	switch {
	case !r.ProtoAtLeast(1, 1):
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
	case !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade"):
		advertiseUpgrade()
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
	case !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket"):
		advertiseUpgrade()
		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
	case r.Method != "GET":
		return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
	case r.Header.Get("Sec-WebSocket-Version") != "13":
		w.Header().Set("Sec-WebSocket-Version", "13")
		return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
	}

	keys := r.Header.Values("Sec-WebSocket-Key")
	if count := len(keys); count != 1 {
		if count == 0 {
			return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
		}
		return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
	}

	// Surrounding whitespace is not part of the key per the RFC.
	key := strings.TrimSpace(keys[0])
	decoded, err := base64.StdEncoding.DecodeString(key)
	if err != nil || len(decoded) != 16 {
		return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", key)
	}
	return 0, nil
}

// authenticateOrigin reports whether the request's Origin is allowed.
//
// A request is allowed when any of the following holds:
//   - it carries no Origin header;
//   - the Origin host equals r.Host, compared case-insensitively;
//   - the Origin host matches one of originHosts via match.
//
// A malformed pattern produces an error wrapping path.ErrBadPattern.
func authenticateOrigin(r *http.Request, originHosts []string) error {
	origin := r.Header.Get("Origin")
	if origin == "" {
		// Nothing to compare against, so there is nothing to reject.
		return nil
	}

	originURL, parseErr := url.Parse(origin)
	if parseErr != nil {
		return fmt.Errorf("failed to parse Origin header %q: %w", origin, parseErr)
	}

	if strings.EqualFold(r.Host, originURL.Host) {
		return nil
	}

	for _, pattern := range originHosts {
		ok, matchErr := match(pattern, originURL.Host)
		switch {
		case matchErr != nil:
			return fmt.Errorf("failed to parse path pattern %q: %w", pattern, matchErr)
		case ok:
			return nil
		}
	}

	if originURL.Host == "" {
		return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
	}
	return fmt.Errorf("request Origin %q is not authorized for Host %q", originURL.Host, r.Host)
}

// match is a case-insensitive path.Match: both the pattern and s are
// lowercased before matching.
func match(pattern, s string) (bool, error) {
	lowerPattern, lowerS := strings.ToLower(pattern), strings.ToLower(s)
	return path.Match(lowerPattern, lowerS)
}

// selectSubprotocol picks the first of the server's subprotocols (in the
// server's order) that the client listed in Sec-WebSocket-Protocol, compared
// case-insensitively. It returns the client's spelling of the match, or "" if
// nothing matches.
func selectSubprotocol(r *http.Request, subprotocols []string) string {
	offered := headerTokens(r.Header, "Sec-WebSocket-Protocol")
	for i := range subprotocols {
		for j := range offered {
			if strings.EqualFold(subprotocols[i], offered[j]) {
				return offered[j]
			}
		}
	}
	return ""
}

// selectDeflate returns the compression options for the first
// permessage-deflate offer in extensions that acceptDeflate agrees to. It
// reports false when compression is disabled or no offer is acceptable.
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
	if mode == CompressionDisabled {
		return nil, false
	}
	// Only permessage-deflate is negotiated. We used to implement
	// x-webkit-deflate-frame too for Safari but Safari has bugs...
	// See https://github.com/nhooyr/websocket/issues/218
	for i := range extensions {
		if extensions[i].name != "permessage-deflate" {
			continue
		}
		if copts, ok := acceptDeflate(extensions[i], mode); ok {
			return copts, true
		}
	}
	return nil, false
}

// acceptDeflate evaluates a single permessage-deflate offer, starting from
// mode's default options. Context-takeover parameters from the client turn
// on the matching option.
//
// Window-bits parameters are tolerated only when no adjustment is needed:
// "client_max_window_bits" with or without a value, and
// "server_max_window_bits=15". Any other parameter makes the offer
// unacceptable, and false is returned.
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
	copts := mode.opts()
	for _, param := range ext.params {
		switch {
		case param == "client_no_context_takeover":
			copts.clientNoContextTakeover = true
		case param == "server_no_context_takeover":
			copts.serverNoContextTakeover = true
		case param == "client_max_window_bits", param == "server_max_window_bits=15":
			// Nothing to adjust.
		case strings.HasPrefix(param, "client_max_window_bits="):
			// We can't adjust the deflate window, but decoding with a larger window is acceptable.
		default:
			return nil, false
		}
	}
	return copts, true
}

// headerContainsTokenIgnoreCase reports whether any comma-separated token of
// header key equals token, ignoring case.
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
	tokens := headerTokens(h, key)
	for i := range tokens {
		if strings.EqualFold(tokens[i], token) {
			return true
		}
	}
	return false
}

// websocketExtension is one entry of a Sec-WebSocket-Extensions header:
// the extension name followed by its ";"-separated parameters, whitespace-trimmed.
type websocketExtension struct {
	name   string
	params []string
}

// websocketExtensions parses every Sec-WebSocket-Extensions value in h into
// extensions, in header order. Empty entries are skipped.
func websocketExtensions(h http.Header) []websocketExtension {
	var exts []websocketExtension
	for _, raw := range headerTokens(h, "Sec-WebSocket-Extensions") {
		if raw == "" {
			continue
		}
		parts := strings.Split(raw, ";")
		for i, part := range parts {
			parts[i] = strings.TrimSpace(part)
		}
		exts = append(exts, websocketExtension{name: parts[0], params: parts[1:]})
	}
	return exts
}

// headerTokens returns every comma-separated, whitespace-trimmed token from
// all values of header key. Empty tokens are kept. The key is canonicalized,
// so any capitalization works.
func headerTokens(h http.Header, key string) []string {
	var tokens []string
	for _, value := range h[textproto.CanonicalMIMEHeaderKey(key)] {
		for _, tok := range strings.Split(value, ",") {
			tokens = append(tokens, strings.TrimSpace(tok))
		}
	}
	return tokens
}

var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// secWebSocketAccept derives the Sec-WebSocket-Accept response value: the
// base64-encoded SHA-1 digest of the client's key concatenated with keyGUID.
func secWebSocketAccept(secWebSocketKey string) string {
	input := make([]byte, 0, len(secWebSocketKey)+len(keyGUID))
	input = append(input, secWebSocketKey...)
	input = append(input, keyGUID...)
	digest := sha1.Sum(input)
	return base64.StdEncoding.EncodeToString(digest[:])
}
//go:build !js

package websocket

import (
        "context"
        "encoding/binary"
        "errors"
        "fmt"
        "net"
        "time"

        "github.com/coder/websocket/internal/errd"
)

// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int

// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
//
// Codes 1004, 1005, 1006 and 1015 can never be encoded into or parsed from a
// close frame payload by this package, and neither can any code outside
// 1000-1014 and 3000-4999.
const (
	StatusNormalClosure   StatusCode = 1000
	StatusGoingAway       StatusCode = 1001
	StatusProtocolError   StatusCode = 1002
	StatusUnsupportedData StatusCode = 1003

	// 1004 is reserved and so unexported.
	statusReserved StatusCode = 1004

	// StatusNoStatusRcvd cannot be sent in a close message.
	// It is reserved for when a close message is received without
	// a status code.
	StatusNoStatusRcvd StatusCode = 1005

	// StatusAbnormalClosure is exported for use only with Wasm.
	// In non Wasm Go, the returned error will indicate whether the
	// connection was closed abnormally.
	StatusAbnormalClosure StatusCode = 1006

	StatusInvalidFramePayloadData StatusCode = 1007
	StatusPolicyViolation         StatusCode = 1008
	StatusMessageTooBig           StatusCode = 1009
	StatusMandatoryExtension      StatusCode = 1010
	StatusInternalError           StatusCode = 1011
	StatusServiceRestart          StatusCode = 1012
	StatusTryAgainLater           StatusCode = 1013
	StatusBadGateway              StatusCode = 1014

	// StatusTLSHandshake is only exported for use with Wasm.
	// In non Wasm Go, the returned error will indicate whether there was
	// a TLS handshake failure.
	StatusTLSHandshake StatusCode = 1015
)

// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
	Code   StatusCode
	Reason string
}

// Error renders the status code and the quoted reason.
func (ce CloseError) Error() string {
	const format = "status = %[1]v and reason = %[2]q"
	return fmt.Sprintf(format, ce.Code, ce.Reason)
}

// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
	ce := CloseError{}
	if !errors.As(err, &ce) {
		return -1
	}
	return ce.Code
}

// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
// Passing StatusNoStatusRcvd sends a close frame with an empty payload.
//
// The connection can only be closed once. Additional calls to Close or CloseNow
// do not repeat the handshake: they wait for the connection's goroutines to exit
// and then return net.ErrClosed (wrapped).
//
// The reason is limited to maxCloseReason bytes: the maximum control frame
// payload minus the 2-byte status code (RFC 6455 limits control frame payloads
// to 125 bytes, leaving 123 for the reason). Avoid sending a dynamic reason.
// A reason that is too long, or a code that cannot be sent, makes Close return
// an error without sending the close frame; the connection is still torn down.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) (err error) {
	defer errd.Wrap(&err, "failed to close WebSocket")

	// casClosing reports true if another Close/CloseNow already started the teardown.
	if c.casClosing() {
		err = c.waitGoroutines()
		if err != nil {
			return err
		}
		return net.ErrClosed
	}
	// From here on this call owns the teardown; a net.ErrClosed that surfaces
	// during it is not reported to the caller.
	defer func() {
		if errors.Is(err, net.ErrClosed) {
			err = nil
		}
	}()

	err = c.closeHandshake(code, reason)

	// Always tear down the connection, even if the handshake failed; the
	// first error encountered is the one returned.
	err2 := c.close()
	if err == nil && err2 != nil {
		err = err2
	}

	err2 = c.waitGoroutines()
	if err == nil && err2 != nil {
		err = err2
	}

	return err
}

// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
//
// Like Close, only the first call performs the teardown; later calls wait for
// the connection's goroutines and then return net.ErrClosed (wrapped).
func (c *Conn) CloseNow() (err error) {
	defer errd.Wrap(&err, "failed to immediately close WebSocket")

	if alreadyClosing := c.casClosing(); alreadyClosing {
		if waitErr := c.waitGoroutines(); waitErr != nil {
			return waitErr
		}
		return net.ErrClosed
	}
	// This call owns the teardown: our own net.ErrClosed is not an error.
	defer func() {
		if errors.Is(err, net.ErrClosed) {
			err = nil
		}
	}()

	closeErr := c.close()
	waitErr := c.waitGoroutines()
	if closeErr != nil {
		return closeErr
	}
	return waitErr
}

// closeHandshake sends our close frame and waits for the peer's.
//
// Receiving a CloseError that carries the same status code we sent is the
// expected outcome and is reported as success. Any other outcome of
// waitCloseHandshake is returned as-is.
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
	if err := c.writeClose(code, reason); err != nil {
		return err
	}
	if err := c.waitCloseHandshake(); CloseStatus(err) != code {
		return err
	}
	return nil
}

// writeClose sends a close control frame with a 5s timeout.
//
// StatusNoStatusRcvd produces an empty payload. Any other code is encoded
// with its reason, and an encoding failure is returned without writing
// anything. A net.ErrClosed from the write is ignored.
func (c *Conn) writeClose(code StatusCode, reason string) error {
	var payload []byte
	if code != StatusNoStatusRcvd {
		ce := CloseError{Code: code, Reason: reason}
		encoded, err := ce.bytes()
		if err != nil {
			return err
		}
		payload = encoded
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err := c.writeControl(ctx, opClose, payload)
	if errors.Is(err, net.ErrClosed) {
		// The connection may have closed mid-write: our frame could have gone
		// out, the peer replied, and another reader consumed the reply and
		// closed the connection. That is not a failure.
		return nil
	}
	return err
}

// waitCloseHandshake waits, for at most 5s, for the peer's side of the close
// handshake.
//
// It takes the read lock, discards c.msgReader.payloadLength pending payload
// bytes, and then keeps reading frames via readLoop, discarding their
// payloads, until readLoop returns an error. It only ever returns through an
// error. closeHandshake treats a CloseError with the status code we sent as
// success.
func (c *Conn) waitCloseHandshake() error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.readMu.lock(ctx)
	if err != nil {
		return err
	}
	defer c.readMu.unlock()

	// NOTE(review): presumably payloadLength is the unread remainder of the
	// frame a reader was consuming when Close was called — confirm in read.go.
	for i := int64(0); i < c.msgReader.payloadLength; i++ {
		_, err := c.br.ReadByte()
		if err != nil {
			return err
		}
	}

	for {
		// readLoop returns frame headers; their payloads are dropped byte by
		// byte. The loop exits only on error, e.g. the ctx deadline or
		// (presumably) a CloseError once the peer's close frame arrives.
		h, err := c.readLoop(ctx)
		if err != nil {
			return err
		}

		for i := int64(0); i < h.payloadLength; i++ {
			_, err := c.br.ReadByte()
			if err != nil {
				return err
			}
		}
	}
}

// waitGoroutines waits for the connection's background work to wind down.
//
// A single 15s timer bounds all of the waits combined. They run in order:
//  1. c.timeoutLoopDone;
//  2. c.closeReadDone, but only if c.closeReadCtx has been set (presumably
//     by CloseRead — confirm);
//  3. c.closed.
//
// It returns an error naming the first wait still pending when the timer
// fires, or nil if all of them complete.
func (c *Conn) waitGoroutines() error {
	t := time.NewTimer(time.Second * 15)
	defer t.Stop()

	select {
	case <-c.timeoutLoopDone:
	case <-t.C:
		return errors.New("failed to wait for timeoutLoop goroutine to exit")
	}

	// Only wait on closeReadDone when a close-read context exists;
	// closeReadCtx is read under its mutex.
	c.closeReadMu.Lock()
	closeRead := c.closeReadCtx != nil
	c.closeReadMu.Unlock()
	if closeRead {
		select {
		case <-c.closeReadDone:
		case <-t.C:
			return errors.New("failed to wait for close read goroutine to exit")
		}
	}

	select {
	case <-c.closed:
	case <-t.C:
		return errors.New("failed to wait for connection to be closed")
	}

	return nil
}

// parseClosePayload decodes a received close frame payload.
//
// The payload is handled as follows:
//   - an empty payload yields StatusNoStatusRcvd;
//   - a 1-byte payload is an error;
//   - otherwise the first two bytes are a big-endian status code, which must
//     be valid on the wire, and the remaining bytes are the reason.
func parseClosePayload(p []byte) (CloseError, error) {
	switch {
	case len(p) == 0:
		return CloseError{Code: StatusNoStatusRcvd}, nil
	case len(p) < 2:
		return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
	}

	code := StatusCode(binary.BigEndian.Uint16(p[:2]))
	if !validWireCloseCode(code) {
		return CloseError{}, fmt.Errorf("invalid status code %v", code)
	}
	return CloseError{Code: code, Reason: string(p[2:])}, nil
}

// validWireCloseCode reports whether code may appear in a close frame payload.
//
// 1004, 1005, 1006 and 1015 are never allowed. Apart from those, the
// protocol-defined range 1000-1014 and the application range 3000-4999 are
// allowed, and everything else is rejected.
//
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
	switch {
	case code == statusReserved, code == StatusNoStatusRcvd, code == StatusAbnormalClosure, code == StatusTLSHandshake:
		return false
	case StatusNormalClosure <= code && code <= StatusBadGateway:
		return true
	default:
		return 3000 <= code && code <= 4999
	}
}

// bytes encodes ce as a close frame payload.
//
// If ce cannot be encoded, it returns the payload for StatusInternalError
// (with an empty reason) together with the wrapped encoding error.
func (ce CloseError) bytes() ([]byte, error) {
	payload, err := ce.bytesErr()
	if err == nil {
		return payload, nil
	}
	// StatusInternalError with no reason always encodes, so its error is ignored.
	fallback, _ := CloseError{Code: StatusInternalError}.bytesErr()
	return fallback, fmt.Errorf("failed to marshal close frame: %w", err)
}

// maxCloseReason is the longest close reason that fits in a control frame
// payload next to the 2-byte status code.
const maxCloseReason = maxControlPayload - 2

// bytesErr encodes ce as a close frame payload: a big-endian uint16 status
// code followed by the raw reason bytes. It fails if the reason exceeds
// maxCloseReason or the code is not valid on the wire.
func (ce CloseError) bytesErr() ([]byte, error) {
	if reasonLen := len(ce.Reason); reasonLen > maxCloseReason {
		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, reasonLen)
	}
	if !validWireCloseCode(ce.Code) {
		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
	}

	payload := make([]byte, 2, 2+len(ce.Reason))
	binary.BigEndian.PutUint16(payload, uint16(ce.Code))
	return append(payload, ce.Reason...), nil
}

// casClosing marks the connection as closing and reports whether it had
// already been marked, i.e. whether some other caller got there first.
func (c *Conn) casClosing() bool {
	wasClosing := c.closing.Swap(true)
	return wasClosing
}

// isClosed reports, without blocking, whether a receive from c.closed is
// ready (presumably because the channel has been closed).
func (c *Conn) isClosed() bool {
	closed := false
	select {
	case <-c.closed:
		closed = true
	default:
	}
	return closed
}
//go:build !js

package websocket

import (
        "compress/flate"
        "io"
        "sync"
)

// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int

const (
	// CompressionDisabled disables the negotiation of the permessage-deflate extension.
	//
	// This is the default. Do not enable compression without benchmarking for your particular use case first.
	CompressionDisabled CompressionMode = iota

	// CompressionContextTakeover compresses each message greater than 128 bytes, reusing the 32 KB sliding window
	// from previous messages; i.e. the compression context is preserved across messages.
	//
	// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
	//
	// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB
	// flate.Readers that are used when reading and then returned.
	//
	// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
	//
	// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
	CompressionContextTakeover

	// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
	// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
	// a sync.Pool.
	//
	// This means less efficient compression, as the sliding window from previous messages will not be used, but the
	// memory overhead will be lower as there is no fixed cost for the flate.Writer nor the 32 KB sliding window.
	// This matters especially if the connections are long lived and seldom written to.
	//
	// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
	//
	// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
	CompressionNoContextTakeover
)

// opts returns the starting permessage-deflate options for m. Only
// CompressionNoContextTakeover disables context takeover, and it does so in
// both directions.
func (m CompressionMode) opts() *compressionOptions {
	noTakeover := m == CompressionNoContextTakeover
	return &compressionOptions{
		clientNoContextTakeover: noTakeover,
		serverNoContextTakeover: noTakeover,
	}
}

// compressionOptions holds the negotiated permessage-deflate parameters.
type compressionOptions struct {
	clientNoContextTakeover bool
	serverNoContextTakeover bool
}

// String renders copts as a Sec-WebSocket-Extensions value, e.g.
// "permessage-deflate; client_no_context_takeover".
func (copts *compressionOptions) String() string {
	params := [...]struct {
		enabled bool
		name    string
	}{
		{copts.clientNoContextTakeover, "client_no_context_takeover"},
		{copts.serverNoContextTakeover, "server_no_context_takeover"},
	}
	ext := "permessage-deflate"
	for _, p := range params {
		if p.enabled {
			ext += "; " + p.name
		}
	}
	return ext
}

// deflateMessageTail holds the bytes required to get flate.Reader to return.
// They are removed when sending, to avoid the overhead, because WebSocket
// framing already tells when the message has ended. They then need to be
// added back when reading; otherwise flate.Reader keeps trying to read more
// bytes.
const deflateMessageTail = "\x00\x00\xff\xff"

// trimLastFourBytesWriter forwards writes to w while always withholding the
// most recent 4 bytes in tail, so the final 4 bytes of a stream are never
// forwarded.
type trimLastFourBytesWriter struct {
	w    io.Writer
	tail []byte
}

// reset discards the withheld bytes but keeps tail's capacity. It is safe to
// call on a nil receiver and does not allocate a tail that was never created.
func (tw *trimLastFourBytesWriter) reset() {
	if tw == nil || tw.tail == nil {
		return
	}
	tw.tail = tw.tail[:0]
}

// Write forwards p to tw.w, except that the last 4 bytes seen so far are held
// back in tw.tail instead of being written. Bytes moved into tw.tail count as
// written in the returned n.
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
	// Lazily allocate the 4-byte holding area.
	if tw.tail == nil {
		tw.tail = make([]byte, 0, 4)
	}

	// extra is how many bytes overflow the 4-byte hold once p is added.
	extra := len(tw.tail) + len(p) - 4

	// Everything still fits in the hold: buffer it and forward nothing.
	if extra <= 0 {
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Now we need to write as many extra bytes as we can from the previous tail.
	if extra > len(tw.tail) {
		extra = len(tw.tail)
	}
	if extra > 0 {
		_, err := tw.w.Write(tw.tail[:extra])
		if err != nil {
			return 0, err
		}

		// Shift remaining bytes in tail over.
		n := copy(tw.tail, tw.tail[extra:])
		tw.tail = tw.tail[:n]
	}

	// If p is less than or equal to 4 bytes,
	// all of it is part of the tail.
	if len(p) <= 4 {
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Otherwise, only the last 4 bytes are.
	tw.tail = append(tw.tail, p[len(p)-4:]...)

	p = p[:len(p)-4]
	n, err := tw.w.Write(p)
	// The 4 bytes just stored in tw.tail are reported as written.
	// NOTE(review): on a short write (err != nil), n+4 can overstate how much
	// of p reached tw.w — confirm callers stop on error.
	return n + 4, err
}

var flateReaderPool sync.Pool

// getFlateReader returns a DEFLATE decompressor reading from r with the
// preset dictionary dict. It reuses a pooled reader (via flate.Resetter)
// when one is available and allocates a new one otherwise.
func getFlateReader(r io.Reader, dict []byte) io.Reader {
	if pooled, ok := flateReaderPool.Get().(io.Reader); ok {
		pooled.(flate.Resetter).Reset(r, dict)
		return pooled
	}
	return flate.NewReaderDict(r, dict)
}

// putFlateReader returns fr to the pool for reuse by getFlateReader.
func putFlateReader(fr io.Reader) {
	flateReaderPool.Put(fr)
}

var flateWriterPool sync.Pool

// getFlateWriter returns a BestSpeed DEFLATE compressor writing to w. It
// reuses a pooled writer when one is available and allocates a new one
// otherwise.
func getFlateWriter(w io.Writer) *flate.Writer {
	if pooled, ok := flateWriterPool.Get().(*flate.Writer); ok {
		pooled.Reset(w)
		return pooled
	}
	// flate.NewWriter only fails for an out-of-range level, and BestSpeed is valid.
	fresh, _ := flate.NewWriter(w, flate.BestSpeed)
	return fresh
}

// putFlateWriter returns w to the pool for reuse by getFlateWriter.
func putFlateWriter(w *flate.Writer) {
	flateWriterPool.Put(w)
}

// slidingWindow keeps at most cap(buf) of the most recently written bytes.
// NOTE(review): presumably used as the DEFLATE dictionary for context
// takeover — confirm at the call sites.
type slidingWindow struct {
	buf []byte
}

// swPool maps a window capacity to a sync.Pool of *slidingWindow values with
// that capacity. swPoolMu guards only the map; each sync.Pool is safe for
// concurrent use on its own.
var (
	swPoolMu sync.RWMutex
	swPool   = map[int]*sync.Pool{}
)

// slidingWindowPool returns the pool for windows of capacity n, creating it
// on first use. Concurrent first calls for the same n all receive the same pool.
func slidingWindowPool(n int) *sync.Pool {
	swPoolMu.RLock()
	p, ok := swPool[n]
	swPoolMu.RUnlock()
	if ok {
		return p
	}

	swPoolMu.Lock()
	defer swPoolMu.Unlock()
	// Re-check under the write lock: another goroutine may have created the
	// pool after our read lock was released. Overwriting it would orphan every
	// window already Put into it.
	if existing, found := swPool[n]; found {
		return existing
	}
	p = &sync.Pool{}
	swPool[n] = p
	return p
}

// init prepares sw to hold a window of n bytes (32768 when n is 0). It reuses
// a pooled window of the same capacity when available. It is a no-op if sw
// already has a buffer.
func (sw *slidingWindow) init(n int) {
	if sw.buf != nil {
		return
	}

	if n == 0 {
		n = 32768
	}

	p := slidingWindowPool(n)
	sw2, ok := p.Get().(*slidingWindow)
	if ok {
		*sw = *sw2
	} else {
		sw.buf = make([]byte, 0, n)
	}
}

// close empties sw and returns it to the pool matching its capacity. The pool
// keeps the pointer sw, so the caller must not use sw again after close.
// A window that was never initialized has no buffer to recycle, so closing it
// does nothing.
func (sw *slidingWindow) close() {
	if sw.buf == nil {
		// There is no capacity-keyed pool for a nil buffer; looking one up
		// would yield a nil *sync.Pool and panic on Put.
		return
	}
	sw.buf = sw.buf[:0]
	// slidingWindowPool does its own locking, and sync.Pool.Put is safe for
	// concurrent use, so no map lock is held around Put.
	slidingWindowPool(cap(sw.buf)).Put(sw)
}

// write appends p to the window, dropping the oldest bytes so that only the
// most recent cap(sw.buf) bytes are kept.
func (sw *slidingWindow) write(p []byte) {
	if len(p) >= cap(sw.buf) {
		// p alone fills the window: keep only its last cap(sw.buf) bytes.
		sw.buf = sw.buf[:cap(sw.buf)]
		p = p[len(p)-cap(sw.buf):]
		copy(sw.buf, p)
		return
	}

	left := cap(sw.buf) - len(sw.buf)
	if left < len(p) {
		// Drop the oldest spaceNeeded bytes from the front to make room for p at the end.
		spaceNeeded := len(p) - left
		copy(sw.buf, sw.buf[spaceNeeded:])
		sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
	}

	sw.buf = append(sw.buf, p...)
}
//go:build !js

package websocket

import (
        "bufio"
        "context"
        "fmt"
        "io"
        "net"
        "runtime"
        "strconv"
        "sync"
        "sync/atomic"
)

// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int

// MessageType constants.
//
// They start at iota + 1, so the zero MessageType matches neither constant.
const (
	// MessageText is for UTF-8 encoded text messages like JSON.
	MessageText MessageType = iota + 1
	// MessageBinary is for binary messages like protobufs.
	MessageBinary
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
//
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
        // noCopy is a stateless marker (see the noCopy type) signalling that
        // Conn must not be copied.
        noCopy noCopy

        // Populated by newConn from connConfig.
        subprotocol    string
        rwc            io.ReadWriteCloser
        client         bool
        copts          *compressionOptions
        flateThreshold int
        br             *bufio.Reader
        bw             *bufio.Writer

        // Contexts handed to timeoutLoop; it closes the connection once the
        // most recently received read or write context is done.
        // timeoutLoopDone is closed when timeoutLoop returns.
        readTimeout     chan context.Context
        writeTimeout    chan context.Context
        timeoutLoopDone chan struct{}

        // Read state.
        readMu         *mu
        readHeaderBuf  [8]byte
        readControlBuf [maxControlPayload]byte
        msgReader      *msgReader

        // Write state.
        msgWriter      *msgWriter
        writeFrameMu   *mu
        writeBuf       []byte
        writeHeaderBuf [8]byte
        writeHeader    header

        // Close handshake state.
        closeStateMu     sync.RWMutex
        closeReceivedErr error
        closeSentErr     error

        // CloseRead state.
        closeReadMu   sync.Mutex
        closeReadCtx  context.Context
        closeReadDone chan struct{}

        closing atomic.Bool
        closeMu sync.Mutex // Protects following.
        // closed is closed exactly once, by close, to signal teardown.
        closed  chan struct{}

        // pingCounter supplies the payload for each Ping; activePings maps an
        // outstanding ping payload to the channel ping waits on for its pong.
        pingCounter    atomic.Int64
        activePingsMu  sync.Mutex
        activePings    map[string]chan<- struct{}
        // Optional user callbacks copied from connConfig.
        onPingReceived func(context.Context, []byte) bool
        onPongReceived func(context.Context, []byte)
}

// connConfig carries the parameters that newConn copies into a new Conn.
type connConfig struct {
        subprotocol    string
        rwc            io.ReadWriteCloser
        client         bool
        copts          *compressionOptions
        // flateThreshold of 0 lets newConn choose a default when compression
        // is enabled.
        flateThreshold int
        onPingReceived func(context.Context, []byte) bool
        onPongReceived func(context.Context, []byte)

        // Buffered reader/writer for the connection (dial builds them over
        // rwc); newConn stores them as-is.
        br *bufio.Reader
        bw *bufio.Writer
}

// newConn builds a Conn from cfg, registers a finalizer that calls close,
// and starts the connection's timeoutLoop goroutine.
//
// For client connections writeBuf is obtained via extractBufioWriterBuf.
// When compression is enabled and cfg.flateThreshold is 0, the threshold
// defaults to 128 bytes with context takeover and 512 bytes without.
func newConn(cfg connConfig) *Conn {
        c := &Conn{
                subprotocol:    cfg.subprotocol,
                rwc:            cfg.rwc,
                client:         cfg.client,
                copts:          cfg.copts,
                flateThreshold: cfg.flateThreshold,

                br: cfg.br,
                bw: cfg.bw,

                readTimeout:     make(chan context.Context),
                writeTimeout:    make(chan context.Context),
                timeoutLoopDone: make(chan struct{}),

                closed:         make(chan struct{}),
                activePings:    make(map[string]chan<- struct{}),
                onPingReceived: cfg.onPingReceived,
                onPongReceived: cfg.onPongReceived,
        }

        c.readMu = newMu(c)
        c.writeFrameMu = newMu(c)

        c.msgReader = newMsgReader(c)

        c.msgWriter = newMsgWriter(c)
        if c.client {
                c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
        }

        // Smaller messages are not worth compressing; without context
        // takeover each message starts from an empty dictionary, so the
        // break-even size is higher.
        if c.flate() && c.flateThreshold == 0 {
                c.flateThreshold = 128
                if !c.msgWriter.flateContextTakeover() {
                        c.flateThreshold = 512
                }
        }

        runtime.SetFinalizer(c, func(c *Conn) {
                c.close()
        })

        // NOTE(review): this goroutine references c, so c stays reachable
        // while timeoutLoop runs and the finalizer above presumably cannot
        // fire until the connection is already closed. Confirm whether the
        // finalizer is effective.
        go c.timeoutLoop()

        return c
}

// Subprotocol reports the subprotocol agreed on during the handshake.
// The empty string means no subprotocol was selected, i.e. the default
// protocol is in use.
func (c *Conn) Subprotocol() string {
        proto := c.subprotocol
        return proto
}

// close tears the connection down once. If c.isClosed() already reports
// true, it returns net.ErrClosed. Otherwise it clears the finalizer, closes
// c.closed, closes rwc, then closes the message writer and reader, and
// returns the error from closing rwc.
func (c *Conn) close() error {
        c.closeMu.Lock()
        defer c.closeMu.Unlock()

        if c.isClosed() {
                return net.ErrClosed
        }
        runtime.SetFinalizer(c, nil)
        close(c.closed)

        // Have to close after c.closed is closed to ensure any goroutine that wakes up
        // from the connection being closed also sees that c.closed is closed and returns
        // closeErr.
        err := c.rwc.Close()
        // With the close of rwc, these become safe to close.
        c.msgWriter.close()
        c.msgReader.close()
        return err
}

// timeoutLoop runs for the lifetime of the connection (newConn starts it).
// It tracks the most recent contexts sent on c.readTimeout and c.writeTimeout
// and closes the connection as soon as the current one of either is done.
// It returns when c.closed is closed or right after closing the connection,
// and closes c.timeoutLoopDone on exit.
func (c *Conn) timeoutLoop() {
        defer close(c.timeoutLoopDone)

        // context.Background().Done() returns nil, and receiving from a nil
        // channel blocks forever, so these cases stay inert until a real
        // context is sent in.
        readCtx := context.Background()
        writeCtx := context.Background()

        for {
                select {
                case <-c.closed:
                        return

                // A newly received context replaces the previous one.
                case writeCtx = <-c.writeTimeout:
                case readCtx = <-c.readTimeout:

                case <-readCtx.Done():
                        c.close()
                        return
                case <-writeCtx.Done():
                        c.close()
                        return
                }
        }
}

// flate reports whether compression options were set for this connection,
// i.e. whether permessage-deflate is in use.
func (c *Conn) flate() bool {
        enabled := c.copts != nil
        return enabled
}

// Ping sends a ping frame to the peer and blocks until the matching pong
// arrives, ctx is done, or the connection closes. It is useful for
// measuring latency or checking that the peer is still responsive.
//
// Ping does not read from the connection itself. The pong is only observed
// by a concurrent Reader call, so Ping must run alongside one.
//
// For most use cases TCP keepalives are sufficient.
func (c *Conn) Ping(ctx context.Context) error {
        id := strconv.FormatInt(c.pingCounter.Add(1), 10)
        if err := c.ping(ctx, id); err != nil {
                return fmt.Errorf("failed to ping: %w", err)
        }
        return nil
}

// ping registers p as an outstanding ping, sends a ping frame carrying p,
// and waits for its pong to be signalled, for ctx to end, or for the
// connection to close. The registration is removed before returning.
func (c *Conn) ping(ctx context.Context, p string) error {
        received := make(chan struct{}, 1)

        c.activePingsMu.Lock()
        c.activePings[p] = received
        c.activePingsMu.Unlock()
        defer func() {
                c.activePingsMu.Lock()
                defer c.activePingsMu.Unlock()
                delete(c.activePings, p)
        }()

        if err := c.writeControl(ctx, opPing, []byte(p)); err != nil {
                return err
        }

        select {
        case <-received:
                return nil
        case <-ctx.Done():
                return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
        case <-c.closed:
                return net.ErrClosed
        }
}

// mu is a mutex built on a one-slot channel: the lock is held while a value
// sits in ch. Its lock method additionally watches c.closed.
type mu struct {
        c  *Conn
        ch chan struct{}
}

// newMu returns an unlocked mu bound to c.
func newMu(c *Conn) *mu {
        m := &mu{c: c}
        m.ch = make(chan struct{}, 1)
        return m
}

// forceLock acquires m, blocking for as long as necessary. It ignores
// contexts and connection closure.
func (m *mu) forceLock() {
        m.ch <- struct{}{}
}

// tryLock acquires m only if it is free right now and reports whether it
// did so.
func (m *mu) tryLock() bool {
        select {
        case m.ch <- struct{}{}:
        default:
                return false
        }
        return true
}

// lock acquires m, giving up when ctx is done or the connection is closed.
// It returns net.ErrClosed if the connection is, or turns out to be, closed,
// and a wrapped ctx.Err() if the context ends first.
func (m *mu) lock(ctx context.Context) error {
        select {
        case <-m.c.closed:
                return net.ErrClosed
        case <-ctx.Done():
                return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
        case m.ch <- struct{}{}:
                // To make sure the connection is certainly alive.
                // As it's possible the send on m.ch was selected
                // over the receive on closed.
                select {
                case <-m.c.closed:
                        // Make sure to release.
                        m.unlock()
                        return net.ErrClosed
                default:
                }
                return nil
        }
}

// unlock releases m. Unlocking an mu that is not held does nothing; it
// neither blocks nor panics.
func (m *mu) unlock() {
        select {
        default:
        case <-m.ch:
        }
}

// noCopy is a zero-size marker placed in structs that must not be copied
// after first use. go vet's copylocks analyzer only reports copies of types
// whose pointer implements sync.Locker (both Lock and Unlock). With Lock
// alone the check never fired, so both no-op methods are defined.
type noCopy struct{}

// Lock is a no-op used by the copylocks checker of go vet.
func (*noCopy) Lock() {}

// Unlock is a no-op used by the copylocks checker of go vet.
func (*noCopy) Unlock() {}
//go:build !js

package websocket

import (
        "bufio"
        "bytes"
        "context"
        "crypto/rand"
        "encoding/base64"
        "fmt"
        "io"
        "net/http"
        "net/url"
        "strings"
        "sync"
        "time"

        "github.com/coder/websocket/internal/errd"
)

// DialOptions represents Dial's options.
type DialOptions struct {
        // HTTPClient is used for the connection.
        // Its Transport must return writable bodies for WebSocket handshakes.
        // http.Transport does beginning with Go 1.12.
        //
        // Defaults to http.DefaultClient. The client is copied, never
        // modified: a positive Timeout is applied as a context deadline for
        // the dial instead, and CheckRedirect is wrapped to map ws/wss
        // redirect targets to http/https.
        HTTPClient *http.Client

        // HTTPHeader specifies the HTTP headers included in the handshake request.
        // It is cloned; the Connection, Upgrade and Sec-WebSocket-* handshake
        // headers are always set by Dial and override any values given here.
        HTTPHeader http.Header

        // Host optionally overrides the Host HTTP header to send. If empty, the value
        // of URL.Host will be used.
        Host string

        // Subprotocols lists the WebSocket subprotocols to negotiate with the server.
        // The server's choice, if any, must match one of these (compared
        // case-insensitively) or the dial fails.
        Subprotocols []string

        // CompressionMode controls the compression mode.
        // Defaults to CompressionDisabled.
        //
        // See docs on CompressionMode for details.
        CompressionMode CompressionMode

        // CompressionThreshold controls the minimum size of a message before compression is applied.
        //
        // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
        // for CompressionContextTakeover. A value of 0 selects the default.
        CompressionThreshold int

        // OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
        //
        // The payload contains the application data of the ping frame.
        // If the callback returns false, the subsequent pong frame will not be sent.
        // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
        OnPingReceived func(ctx context.Context, payload []byte) bool

        // OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
        //
        // The payload contains the application data of the pong frame.
        // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
        //
        // Unlike OnPingReceived, this callback does not return a value because a pong frame
        // is a response to a ping and does not trigger any further frame transmission.
        OnPongReceived func(ctx context.Context, payload []byte)
}

// cloneWithDefaults returns a copy of opts, with nil treated as the zero
// value and defaults filled in. It also returns the context to dial with
// and its cancel func.
//
// The HTTP client is always copied so the caller's client is never
// mutated. A positive Client.Timeout becomes a deadline on the returned
// context and is cleared on the copy. CheckRedirect is wrapped to translate
// ws/wss redirect targets to http/https before deferring to any existing
// hook. cancel is nil when no deadline was added.
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
        o := DialOptions{}
        if opts != nil {
                o = *opts
        }
        if o.HTTPHeader == nil {
                o.HTTPHeader = http.Header{}
        }

        base := o.HTTPClient
        if base == nil {
                base = http.DefaultClient
        }
        client := *base

        var cancel context.CancelFunc
        if client.Timeout > 0 {
                ctx, cancel = context.WithTimeout(ctx, client.Timeout)
                client.Timeout = 0
        }

        userCheckRedirect := client.CheckRedirect
        client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
                if req.URL.Scheme == "ws" {
                        req.URL.Scheme = "http"
                } else if req.URL.Scheme == "wss" {
                        req.URL.Scheme = "https"
                }
                if userCheckRedirect == nil {
                        return nil
                }
                return userCheckRedirect(req, via)
        }
        o.HTTPClient = &client

        return ctx, cancel, &o
}

// Dial performs a WebSocket handshake on url and returns the resulting
// connection along with the server's handshake response. You never need to
// close resp.Body yourself.
//
// On error the response may still be non-nil, but only the first 1024
// bytes of its body can be read.
//
// Dial needs Go 1.12 or newer because it relies on net/http support for
// WebSocket handshakes. See docs on the HTTPClient option and
// https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// URLs with http/https schemes work too and are treated as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
        // A nil source makes the Sec-WebSocket-Key come from crypto/rand.
        var keySource io.Reader
        return dial(ctx, u, opts, keySource)
}

// dial implements Dial. When rand is nil, secWebSocketKey falls back to
// crypto/rand for the Sec-WebSocket-Key.
//
// On success the returned response has a nil Body, because the body stream
// becomes the connection's rwc. When an error occurs after the handshake
// response arrived, resp is returned with Body replaced by an in-memory
// copy of at most the first 1024 bytes of the original body.
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
        defer errd.Wrap(&err, "failed to WebSocket dial")

        var cancel context.CancelFunc
        // cloneWithDefaults may derive a deadline from HTTPClient.Timeout;
        // cancel is nil when it did not.
        ctx, cancel, opts = opts.cloneWithDefaults(ctx)
        if cancel != nil {
                defer cancel()
        }

        secWebSocketKey, err := secWebSocketKey(rand)
        if err != nil {
                return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
        }

        var copts *compressionOptions
        if opts.CompressionMode != CompressionDisabled {
                copts = opts.CompressionMode.opts()
        }

        resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
        if err != nil {
                return nil, resp, err
        }
        respBody := resp.Body
        // Detach the body: on success it becomes the connection, on failure
        // the deferred func below swaps in a bounded copy for debugging.
        resp.Body = nil
        defer func() {
                if err != nil {
                        // We read a bit of the body for easier debugging.
                        r := io.LimitReader(respBody, 1024)

                        // Close the body after 3s so a stalled peer cannot
                        // block ReadAll indefinitely.
                        timer := time.AfterFunc(time.Second*3, func() {
                                respBody.Close()
                        })
                        defer timer.Stop()

                        b, _ := io.ReadAll(r)
                        respBody.Close()
                        resp.Body = io.NopCloser(bytes.NewReader(b))
                }
        }()

        // copts is replaced by the options the server actually accepted
        // (nil when it negotiated no extension).
        copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
        if err != nil {
                return nil, resp, err
        }

        // The upgraded connection is only usable if the transport returned
        // a writable body.
        rwc, ok := respBody.(io.ReadWriteCloser)
        if !ok {
                return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
        }

        return newConn(connConfig{
                subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
                rwc:            rwc,
                client:         true,
                copts:          copts,
                flateThreshold: opts.CompressionThreshold,
                onPingReceived: opts.OnPingReceived,
                onPongReceived: opts.OnPongReceived,
                br:             getBufioReader(rwc),
                bw:             getBufioWriter(rwc),
        }), resp, nil
}

// handshakeRequest sends the client's WebSocket upgrade request for urls
// and returns the raw response. ws/wss URLs are sent as http/https; any
// other scheme besides http/https is rejected. The caller's HTTPHeader is
// cloned, then the mandatory handshake headers are set on the clone.
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
        u, err := url.Parse(urls)
        if err != nil {
                return nil, fmt.Errorf("failed to parse url: %w", err)
        }

        switch u.Scheme {
        case "http", "https":
                // Already usable by net/http.
        case "ws", "wss":
                // "ws" -> "http", "wss" -> "https".
                u.Scheme = "http" + strings.TrimPrefix(u.Scheme, "ws")
        default:
                return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
        }

        req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
        if err != nil {
                return nil, fmt.Errorf("failed to create new http request: %w", err)
        }
        if opts.Host != "" {
                req.Host = opts.Host
        }

        hdr := opts.HTTPHeader.Clone()
        hdr.Set("Connection", "Upgrade")
        hdr.Set("Upgrade", "websocket")
        hdr.Set("Sec-WebSocket-Version", "13")
        hdr.Set("Sec-WebSocket-Key", secWebSocketKey)
        if len(opts.Subprotocols) != 0 {
                hdr.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
        }
        if copts != nil {
                hdr.Set("Sec-WebSocket-Extensions", copts.String())
        }
        req.Header = hdr

        resp, err := opts.HTTPClient.Do(req)
        if err != nil {
                return nil, fmt.Errorf("failed to send handshake request: %w", err)
        }
        return resp, nil
}

// secWebSocketKey returns a fresh Sec-WebSocket-Key: 16 bytes read from rr
// (crypto/rand.Reader when rr is nil), encoded with standard base64.
func secWebSocketKey(rr io.Reader) (string, error) {
        src := rr
        if src == nil {
                src = rand.Reader
        }
        var nonce [16]byte
        if _, err := io.ReadFull(src, nonce[:]); err != nil {
                return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
        }
        return base64.StdEncoding.EncodeToString(nonce[:]), nil
}

// verifyServerResponse validates the server's handshake response. It checks,
// in order: the 101 status, the Connection and Upgrade headers, the
// Sec-WebSocket-Accept value, and the subprotocol. It then returns the
// compression options the server agreed to, as resolved by
// verifyServerExtensions.
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
        if got, want := resp.StatusCode, http.StatusSwitchingProtocols; got != want {
                return nil, fmt.Errorf("expected handshake response status code %v but got %v", want, got)
        }

        hdr := resp.Header
        if !headerContainsTokenIgnoreCase(hdr, "Connection", "Upgrade") {
                return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", hdr.Get("Connection"))
        }
        if !headerContainsTokenIgnoreCase(hdr, "Upgrade", "WebSocket") {
                return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", hdr.Get("Upgrade"))
        }

        if accept := hdr.Get("Sec-WebSocket-Accept"); accept != secWebSocketAccept(secWebSocketKey) {
                return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
                        accept,
                        secWebSocketKey,
                )
        }

        if err := verifySubprotocol(opts.Subprotocols, resp); err != nil {
                return nil, err
        }
        return verifyServerExtensions(copts, hdr)
}

// verifySubprotocol accepts a response that selects no subprotocol, or one
// that selects a subprotocol from subprotos (compared case-insensitively).
// Any other choice by the server is a protocol violation.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
        proto := resp.Header.Get("Sec-WebSocket-Protocol")
        accepted := proto == ""
        for i := 0; !accepted && i < len(subprotos); i++ {
                accepted = strings.EqualFold(subprotos[i], proto)
        }
        if accepted {
                return nil
        }
        return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}

// verifyServerExtensions checks the Sec-WebSocket-Extensions the server
// accepted and returns the resulting compression options.
//
// It returns (nil, nil) when the server accepted no extension. Otherwise
// the only acceptable answer is a single permessage-deflate extension, and
// only if the client offered compression (copts != nil). The returned
// options are a copy of copts updated with the server's
// *_no_context_takeover parameters. A server_max_window_bits parameter is
// tolerated; any other parameter is an error.
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
        exts := websocketExtensions(h)
        if len(exts) == 0 {
                return nil, nil
        }

        ext := exts[0]
        if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
                // Report every extension: the offending one may be exts[0]
                // itself, which exts[1:] would have omitted.
                return nil, fmt.Errorf("WebSocket protocol violation: unsupported extensions from server: %+v", exts)
        }

        negotiated := *copts
        copts = &negotiated

        for _, p := range ext.params {
                switch p {
                case "client_no_context_takeover":
                        copts.clientNoContextTakeover = true
                        continue
                case "server_no_context_takeover":
                        copts.serverNoContextTakeover = true
                        continue
                }
                if strings.HasPrefix(p, "server_max_window_bits=") {
                        // We can't adjust the deflate window, but decoding with a larger window is acceptable.
                        continue
                }

                return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
        }

        return copts, nil
}

// bufioReaderPool recycles *bufio.Reader values across connections.
var bufioReaderPool sync.Pool

// getBufioReader returns a *bufio.Reader reading from r. A pooled reader is
// Reset onto r when available; otherwise a new one is allocated.
func getBufioReader(r io.Reader) *bufio.Reader {
        if br, ok := bufioReaderPool.Get().(*bufio.Reader); ok {
                br.Reset(r)
                return br
        }
        return bufio.NewReader(r)
}

// putBufioReader returns br to bufioReaderPool for reuse.
func putBufioReader(br *bufio.Reader) {
        bufioReaderPool.Put(br)
}

// bufioWriterPool recycles *bufio.Writer values across connections.
var bufioWriterPool sync.Pool

// getBufioWriter returns a *bufio.Writer writing to w. A pooled writer is
// Reset onto w when available; otherwise a new one is allocated.
func getBufioWriter(w io.Writer) *bufio.Writer {
        if bw, ok := bufioWriterPool.Get().(*bufio.Writer); ok {
                bw.Reset(w)
                return bw
        }
        return bufio.NewWriter(w)
}

// putBufioWriter returns bw to bufioWriterPool for reuse.
func putBufioWriter(bw *bufio.Writer) {
        bufioWriterPool.Put(bw)
}
//go:build !js

package websocket

import (
        "bufio"
        "encoding/binary"
        "fmt"
        "io"
        "math"

        "github.com/coder/websocket/internal/errd"
)

// opcode represents a WebSocket frame opcode, the low 4 bits of the first
// header byte.
type opcode int

// Opcodes defined by https://tools.ietf.org/html/rfc6455#section-11.8.
// 0x3-0x7 are reserved for further non-control frames and 0xB-0xF for
// further control frames.
const (
        opContinuation opcode = 0x0
        opText         opcode = 0x1
        opBinary       opcode = 0x2
        opClose        opcode = 0x8
        opPing         opcode = 0x9
        opPong         opcode = 0xA
)

// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
type header struct {
        // FIN flag and the three RSV bits, from the top of the first byte.
        fin    bool
        rsv1   bool
        rsv2   bool
        rsv3   bool
        opcode opcode

        // payloadLength is the decoded payload size; readFrameHeader rejects
        // negative values.
        payloadLength int64

        masked  bool
        // maskKey holds the 4 mask bytes interpreted little-endian, as both
        // readFrameHeader and writeFrameHeader do.
        maskKey uint32
}

// readFrameHeader decodes one frame header from r.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
//
// readBuf is scratch space for the extended payload length and the mask key
// and must be at least 8 bytes long. A 64-bit length with the top bit set
// is rejected as negative. On any error the zero header is returned.
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
        defer errd.Wrap(&err, "failed to read frame header")

        first, err := r.ReadByte()
        if err != nil {
                return header{}, err
        }
        h.fin = first&0x80 != 0
        h.rsv1 = first&0x40 != 0
        h.rsv2 = first&0x20 != 0
        h.rsv3 = first&0x10 != 0
        h.opcode = opcode(first & 0x0f)

        second, err := r.ReadByte()
        if err != nil {
                return header{}, err
        }
        h.masked = second&0x80 != 0

        switch n := second & 0x7f; n {
        case 126:
                if _, err = io.ReadFull(r, readBuf[:2]); err != nil {
                        return header{}, err
                }
                h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
        case 127:
                if _, err = io.ReadFull(r, readBuf); err != nil {
                        return header{}, err
                }
                h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
        default:
                h.payloadLength = int64(n)
        }

        if h.payloadLength < 0 {
                return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
        }

        if !h.masked {
                return h, nil
        }
        if _, err = io.ReadFull(r, readBuf[:4]); err != nil {
                return header{}, err
        }
        h.maskKey = binary.LittleEndian.Uint32(readBuf)
        return h, nil
}

// maxControlPayload is the maximum length of a control frame payload.
// RFC 6455 requires close, ping and pong payloads to be 125 bytes or less,
// so they always fit the 7-bit length field.
// See https://tools.ietf.org/html/rfc6455#section-5.5.
const maxControlPayload = 125

// writeFrameHeader writes the bytes of the header to w.
// See https://tools.ietf.org/html/rfc6455#section-5.2
//
// buf is scratch space for the extended length and mask key; it must hold
// 8 bytes because PutUint64 is used for 64-bit lengths. Lengths of 0-125
// fit in the length byte, 126-65535 use the 16-bit extended form, and
// larger lengths use the 64-bit form. Bytes are only buffered in w;
// flushing is left to the caller.
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
        defer errd.Wrap(&err, "failed to write frame header")

        var b byte
        if h.fin {
                b |= 1 << 7
        }
        if h.rsv1 {
                b |= 1 << 6
        }
        if h.rsv2 {
                b |= 1 << 5
        }
        if h.rsv3 {
                b |= 1 << 4
        }

        b |= byte(h.opcode)

        err = w.WriteByte(b)
        if err != nil {
                return err
        }

        lengthByte := byte(0)
        if h.masked {
                lengthByte |= 1 << 7
        }

        // NOTE(review): a negative payloadLength matches no case and is
        // encoded as length 0; callers presumably never pass one.
        switch {
        case h.payloadLength > math.MaxUint16:
                lengthByte |= 127
        case h.payloadLength > 125:
                lengthByte |= 126
        case h.payloadLength >= 0:
                lengthByte |= byte(h.payloadLength)
        }
        err = w.WriteByte(lengthByte)
        if err != nil {
                return err
        }

        // Extended payload length, big-endian, when the 7-bit field was 126/127.
        switch {
        case h.payloadLength > math.MaxUint16:
                binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
                _, err = w.Write(buf)
        case h.payloadLength > 125:
                binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
                _, err = w.Write(buf[:2])
        }
        if err != nil {
                return err
        }

        if h.masked {
                binary.LittleEndian.PutUint32(buf, h.maskKey)
                _, err = w.Write(buf[:4])
                if err != nil {
                        return err
                }
        }

        return nil
}
//go:build !js

package websocket

import (
        "net/http"
)

// rwUnwrapper is implemented by http.ResponseWriter wrappers that expose the
// writer they wrap. http.ResponseController follows the same convention.
type rwUnwrapper interface {
        Unwrap() http.ResponseWriter
}

// hijacker finds an http.Hijacker behind rw, following Unwrap chains the
// same way http.ResponseController does. It reports false when the chain
// ends without one.
//
// http.ResponseController does not exist in Go 1.19 and offers no way to
// test for hijack support without hijacking, hence this helper.
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
        for rw != nil {
                if hj, ok := rw.(http.Hijacker); ok {
                        return hj, true
                }
                u, ok := rw.(rwUnwrapper)
                if !ok {
                        break
                }
                rw = u.Unwrap()
        }
        return nil, false
}
package bpool

import (
        "bytes"
        "sync"
)

// bpool recycles *bytes.Buffer values. Its New func means Get always
// yields a buffer.
var bpool = sync.Pool{
        New: func() any { return new(bytes.Buffer) },
}

// Get returns an empty *bytes.Buffer: a pooled one when available, or a
// freshly allocated one otherwise.
func Get() *bytes.Buffer {
        return bpool.Get().(*bytes.Buffer)
}

// Put resets b and returns it to the pool. Callers should not use b
// afterwards, since a later Get may hand it out again.
func Put(b *bytes.Buffer) {
        b.Reset()
        bpool.Put(b)
}
package errd

import (
        "fmt"
)

// Wrap wraps err with fmt.Errorf if err is non nil.
// Intended for use with defer and a named error return.
// Inspired by https://github.com/golang/go/issues/32676.
//
// The new message is fmt.Sprintf(f, v...) followed by ": " and the original
// error. The original error is wrapped with %w, so errors.Is/errors.As
// still see it.
func Wrap(err *error, f string, v ...any) {
        if *err == nil {
                return
        }
        // Cap v at its length so appending *err always allocates a new
        // backing array. A caller passing args... with spare capacity would
        // otherwise have its slice written to.
        args := append(v[:len(v):len(v)], *err)
        *err = fmt.Errorf(f+": %w", args...)
}
package assert

import (
        "errors"
        "fmt"
        "reflect"
        "strings"
        "testing"
)

// Equal fails the test immediately unless exp and got are reflect.DeepEqual.
// name labels the compared value in the failure message.
func Equal(t testing.TB, name string, exp, got any) {
        t.Helper()
        if reflect.DeepEqual(exp, got) {
                return
        }
        t.Fatalf("unexpected %v: expected %#v but got %#v", name, exp, got)
}

// Success fails the test immediately, reporting err, when err is non-nil.
func Success(t testing.TB, err error) {
        t.Helper()
        if err == nil {
                return
        }
        t.Fatal(err)
}

// Error fails the test immediately when err is nil.
func Error(t testing.TB, err error) {
        t.Helper()
        if err != nil {
                return
        }
        t.Fatal("expected error")
}

// Contains fails the test immediately unless the fmt.Sprint rendering of v
// contains sub.
func Contains(t testing.TB, v any, sub string) {
        t.Helper()
        if s := fmt.Sprint(v); !strings.Contains(s, sub) {
                t.Fatalf("expected %q to contain %q", s, sub)
        }
}

// ErrorIs fails the test immediately unless errors.Is(got, exp) holds.
func ErrorIs(t testing.TB, exp, got error) {
        t.Helper()
        if errors.Is(got, exp) {
                return
        }
        t.Fatalf("expected %v but got %v", exp, got)
}
package wstest

import (
        "bytes"
        "context"
        "fmt"
        "io"
        "time"

        "github.com/coder/websocket"
        "github.com/coder/websocket/internal/test/xrand"
        "github.com/coder/websocket/internal/xsync"
)

// EchoLoop echoes every message read from c back to c until an error occurs
// or the context expires. The context is additionally capped at five
// minutes, and the read limit is raised to 1 << 30 bytes. c.Close is called
// with StatusInternalError when EchoLoop returns.
func EchoLoop(ctx context.Context, c *websocket.Conn) error {
        defer c.Close(websocket.StatusInternalError, "")

        c.SetReadLimit(1 << 30)

        ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
        defer cancel()

        copyBuf := make([]byte, 32<<10)
        for {
                if err := echoMessage(ctx, c, copyBuf); err != nil {
                        return err
                }
        }
}

// echoMessage copies the next message from c back to c as a message of the
// same type, using buf as the copy buffer.
func echoMessage(ctx context.Context, c *websocket.Conn, buf []byte) error {
        typ, r, err := c.Reader(ctx)
        if err != nil {
                return err
        }
        w, err := c.Writer(ctx, typ)
        if err != nil {
                return err
        }
        if _, err := io.CopyBuffer(w, r, buf); err != nil {
                return err
        }
        return w.Close()
}

// Echo writes a random message of random type to c and checks that the next
// message read from c has the same type and payload. The message length is
// drawn from [0, maxLen). The write runs concurrently with the read.
func Echo(ctx context.Context, c *websocket.Conn, maxLen int) error {
        want := websocket.MessageBinary
        if xrand.Bool() {
                want = websocket.MessageText
        }
        msg := randMessage(want, xrand.Int(maxLen))

        writeErr := xsync.Go(func() error {
                return c.Write(ctx, want, msg)
        })

        gotType, got, err := c.Read(ctx)
        if err != nil {
                return err
        }
        if err := <-writeErr; err != nil {
                return err
        }

        switch {
        case gotType != want:
                return fmt.Errorf("unexpected message typ (%v): %v", want, gotType)
        case !bytes.Equal(msg, got):
                return fmt.Errorf("unexpected msg read: %#v", got)
        }
        return nil
}

// randMessage returns n random bytes for binary messages. For any other
// type it returns the bytes of xrand.String(n).
func randMessage(typ websocket.MessageType, n int) []byte {
        if typ != websocket.MessageBinary {
                return []byte(xrand.String(n))
        }
        return xrand.Bytes(n)
}
//go:build !js

package wstest

import (
        "bufio"
        "context"
        "net"
        "net/http"
        "net/http/httptest"

        "github.com/coder/websocket"
)

// Pipe creates an in-memory pair of connected WebSockets, analogous to
// net.Pipe. The handshake is served by fakeTransport instead of the
// network. dialOpts is copied (never mutated) and its HTTPClient is
// replaced. Dial and Accept errors are ignored, so either result may be nil.
func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) {
        var opts websocket.DialOptions
        if dialOpts != nil {
                opts = *dialOpts
        }
        opts.HTTPClient = &http.Client{
                Transport: fakeTransport{
                        h: func(w http.ResponseWriter, r *http.Request) {
                                serverConn, _ = websocket.Accept(w, r, acceptOpts)
                        },
                },
        }

        clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", &opts)
        return clientConn, serverConn
}

// fakeTransport is an http.RoundTripper that serves each request in-process
// with h, over a net.Pipe, instead of using the network.
type fakeTransport struct {
        h http.HandlerFunc
}

// RoundTrip runs t.h against r with a recorder that can be hijacked to the
// server end of a fresh net.Pipe. When the handler answers 101 Switching
// Protocols, the client end of the pipe becomes the response body.
func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
        client, server := net.Pipe()

        rw := testHijacker{
                ResponseRecorder: httptest.NewRecorder(),
                serverConn:       server,
        }
        t.h.ServeHTTP(rw, r)

        resp := rw.Result()
        if resp.StatusCode != http.StatusSwitchingProtocols {
                return resp, nil
        }
        resp.Body = client
        return resp, nil
}

// testHijacker is an httptest.ResponseRecorder that also implements
// http.Hijacker by handing out serverConn.
type testHijacker struct {
        *httptest.ResponseRecorder
        serverConn net.Conn
}

var _ http.Hijacker = testHijacker{}

// Hijack returns hj.serverConn together with a fresh bufio.ReadWriter over
// it. It never fails.
func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
        rw := bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn))
        return hj.serverConn, rw, nil
}
package xrand

import (
        "crypto/rand"
        "encoding/base64"
        "fmt"
        "math/big"
        "strings"
)

// Bytes returns n cryptographically random bytes from crypto/rand.
//
// It panics if the system random source reports an error.
func Bytes(n int) []byte {
        b := make([]byte, n)
        // rand.Read fills b completely (it reads via io.ReadFull). A single
        // rand.Reader.Read call is only bound by the io.Reader contract, which
        // permits returning fewer than len(b) bytes with a nil error and would
        // leave the tail of b zeroed.
        if _, err := rand.Read(b); err != nil {
                panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
        }
        return b
}

// String returns a random string that is exactly n bytes long.
//
// The random bytes from Bytes are coerced to valid UTF-8 (each run of invalid
// bytes becomes a single "_") and NUL bytes are replaced by "_". The result is
// then truncated or right-padded with "=" to reach length n.
func String(n int) string {
        s := strings.ReplaceAll(strings.ToValidUTF8(string(Bytes(n)), "_"), "\x00", "_")
        switch {
        case len(s) > n:
                return s[:n]
        case len(s) < n:
                // Collapsing invalid runs can shorten the string; fill with "=".
                return s + strings.Repeat("=", n-len(s))
        default:
                return s
        }
}

// Bool returns true or false, each chosen via Int(2).
func Bool() bool {
        return Int(2) != 0
}

// Int returns a random integer in [0, max), drawn from crypto/rand.
//
// Like crypto/rand.Int it panics if max <= 0. Any other error from the
// random source also causes a panic.
func Int(max int) int {
        limit := big.NewInt(int64(max))
        v, err := rand.Int(rand.Reader, limit)
        if err != nil {
                panic(fmt.Sprintf("failed to get random int: %v", err))
        }
        return int(v.Int64())
}

// Base64 returns the standard, padded base64 encoding of n random bytes.
//
// n is the number of random bytes, not the length of the result: the
// returned string is 4*ceil(n/3) characters long.
func Base64(n int) string {
        raw := Bytes(n)
        return base64.StdEncoding.EncodeToString(raw)
}
package util

// WriterFunc adapts an ordinary function to the io.Writer interface so that
// one-off writers can be written inline.
type WriterFunc func(p []byte) (int, error)

// Write implements io.Writer by calling f(p) and returning its results unchanged.
func (f WriterFunc) Write(p []byte) (int, error) {
        n, err := f(p)
        return n, err
}

// ReaderFunc adapts an ordinary function to the io.Reader interface so that
// one-off readers can be written inline.
type ReaderFunc func(p []byte) (int, error)

// Read implements io.Reader by calling f(p) and returning its results unchanged.
func (f ReaderFunc) Read(p []byte) (int, error) {
        n, err := f(p)
        return n, err
}
package xsync

import (
        "fmt"
        "runtime/debug"
)

// Go runs fn in a new goroutine and returns a channel that receives its
// result. The channel has a buffer of one, so the goroutine can always
// complete its single send and exit even if nobody receives.
//
// If fn panics, the panic is recovered and delivered on the channel as an
// error that includes the panic value and the goroutine's stack trace.
func Go(fn func() error) <-chan error {
        result := make(chan error, 1)

        go func() {
                defer func() {
                        if r := recover(); r != nil {
                                // Non-blocking send: after a panic fn produced no result,
                                // so the buffer is free, but never risk blocking here.
                                select {
                                case result <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()):
                                default:
                                }
                        }
                }()
                result <- fn()
        }()

        return result
}
package websocket

import (
        "encoding/binary"
        "math/bits"
)

// maskGo applies the WebSocket masking algorithm to b in place, XORing it
// with the 4-byte masking key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is key rotated by len(b)%4 bytes. Passing it to the
// next call continues masking a payload that was split across buffers.
//
// It is optimized for little-endian access and expects the key in
// little-endian byte order, so the low byte of key masks the first byte of b.
//
// See https://github.com/golang/go/issues/31586
func maskGo(b []byte, key uint32) uint32 {
        if len(b) >= 8 {
                // Two copies of the 32-bit key so 8 bytes can be XORed at once.
                // Chunk sizes below are all multiples of 4, so the key never needs
                // rotating until the final tail loop.
                key64 := uint64(key)<<32 | uint64(key)

                // At some point in the future we can clean these unrolled loops up.
                // See https://github.com/golang/go/issues/31586#issuecomment-487436401

                // First we xor 128 bytes at a time until b is less than 128 bytes.
                for len(b) >= 128 {
                        v := binary.LittleEndian.Uint64(b)
                        binary.LittleEndian.PutUint64(b, v^key64)
                        v = binary.LittleEndian.Uint64(b[8:16])
                        binary.LittleEndian.PutUint64(b[8:16], v^key64)
                        v = binary.LittleEndian.Uint64(b[16:24])
                        binary.LittleEndian.PutUint64(b[16:24], v^key64)
                        v = binary.LittleEndian.Uint64(b[24:32])
                        binary.LittleEndian.PutUint64(b[24:32], v^key64)
                        v = binary.LittleEndian.Uint64(b[32:40])
                        binary.LittleEndian.PutUint64(b[32:40], v^key64)
                        v = binary.LittleEndian.Uint64(b[40:48])
                        binary.LittleEndian.PutUint64(b[40:48], v^key64)
                        v = binary.LittleEndian.Uint64(b[48:56])
                        binary.LittleEndian.PutUint64(b[48:56], v^key64)
                        v = binary.LittleEndian.Uint64(b[56:64])
                        binary.LittleEndian.PutUint64(b[56:64], v^key64)
                        v = binary.LittleEndian.Uint64(b[64:72])
                        binary.LittleEndian.PutUint64(b[64:72], v^key64)
                        v = binary.LittleEndian.Uint64(b[72:80])
                        binary.LittleEndian.PutUint64(b[72:80], v^key64)
                        v = binary.LittleEndian.Uint64(b[80:88])
                        binary.LittleEndian.PutUint64(b[80:88], v^key64)
                        v = binary.LittleEndian.Uint64(b[88:96])
                        binary.LittleEndian.PutUint64(b[88:96], v^key64)
                        v = binary.LittleEndian.Uint64(b[96:104])
                        binary.LittleEndian.PutUint64(b[96:104], v^key64)
                        v = binary.LittleEndian.Uint64(b[104:112])
                        binary.LittleEndian.PutUint64(b[104:112], v^key64)
                        v = binary.LittleEndian.Uint64(b[112:120])
                        binary.LittleEndian.PutUint64(b[112:120], v^key64)
                        v = binary.LittleEndian.Uint64(b[120:128])
                        binary.LittleEndian.PutUint64(b[120:128], v^key64)
                        b = b[128:]
                }

                // Then we xor until b is less than 64 bytes.
                for len(b) >= 64 {
                        v := binary.LittleEndian.Uint64(b)
                        binary.LittleEndian.PutUint64(b, v^key64)
                        v = binary.LittleEndian.Uint64(b[8:16])
                        binary.LittleEndian.PutUint64(b[8:16], v^key64)
                        v = binary.LittleEndian.Uint64(b[16:24])
                        binary.LittleEndian.PutUint64(b[16:24], v^key64)
                        v = binary.LittleEndian.Uint64(b[24:32])
                        binary.LittleEndian.PutUint64(b[24:32], v^key64)
                        v = binary.LittleEndian.Uint64(b[32:40])
                        binary.LittleEndian.PutUint64(b[32:40], v^key64)
                        v = binary.LittleEndian.Uint64(b[40:48])
                        binary.LittleEndian.PutUint64(b[40:48], v^key64)
                        v = binary.LittleEndian.Uint64(b[48:56])
                        binary.LittleEndian.PutUint64(b[48:56], v^key64)
                        v = binary.LittleEndian.Uint64(b[56:64])
                        binary.LittleEndian.PutUint64(b[56:64], v^key64)
                        b = b[64:]
                }

                // Then we xor until b is less than 32 bytes.
                for len(b) >= 32 {
                        v := binary.LittleEndian.Uint64(b)
                        binary.LittleEndian.PutUint64(b, v^key64)
                        v = binary.LittleEndian.Uint64(b[8:16])
                        binary.LittleEndian.PutUint64(b[8:16], v^key64)
                        v = binary.LittleEndian.Uint64(b[16:24])
                        binary.LittleEndian.PutUint64(b[16:24], v^key64)
                        v = binary.LittleEndian.Uint64(b[24:32])
                        binary.LittleEndian.PutUint64(b[24:32], v^key64)
                        b = b[32:]
                }

                // Then we xor until b is less than 16 bytes.
                for len(b) >= 16 {
                        v := binary.LittleEndian.Uint64(b)
                        binary.LittleEndian.PutUint64(b, v^key64)
                        v = binary.LittleEndian.Uint64(b[8:16])
                        binary.LittleEndian.PutUint64(b[8:16], v^key64)
                        b = b[16:]
                }

                // Then we xor until b is less than 8 bytes.
                for len(b) >= 8 {
                        v := binary.LittleEndian.Uint64(b)
                        binary.LittleEndian.PutUint64(b, v^key64)
                        b = b[8:]
                }
        }

        // Then we xor until b is less than 4 bytes.
        for len(b) >= 4 {
                v := binary.LittleEndian.Uint32(b)
                binary.LittleEndian.PutUint32(b, v^key)
                b = b[4:]
        }

        // xor remaining bytes, rotating the key by one byte per byte consumed so
        // the returned key lines up with the next byte of the payload.
        for i := range b {
                b[i] ^= byte(key)
                key = bits.RotateLeft32(key, -8)
        }

        return key
}
//go:build amd64 || arm64

package websocket

// mask XORs b in place with the WebSocket masking key and returns the key
// rotated for continuing on the next buffer.
//
// On amd64 and arm64 an assembly implementation (maskAsm) exists, but it is
// currently disabled and the portable maskGo is always used.
func mask(b []byte, key uint32) uint32 {
        // TODO: Will enable in v1.9.0.
        return maskGo(b, key)
        /*
                if len(b) > 0 {
                        return maskAsm(&b[0], len(b), key)
                }
                return key
        */
}

// maskAsm is declared without a Go body; its implementation is per-architecture
// assembly for amd64 and arm64. It masks len bytes starting at b and returns
// the rotated key, mirroring maskGo. It is currently unused (see mask).
//
// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this
// function are perfect. There are almost certainly missing optimizations or
// opportunities for simplification. I'm confident there are no bugs though.
// For example, the arm64 implementation doesn't align memory like the amd64.
// Or the amd64 implementation could use AVX512 instead of just AVX2.
// The AVX2 code I had to disable anyway as it wasn't performing as expected.
// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049
//
//go:noescape
//lint:ignore U1000 disabled till v1.9.0
func maskAsm(b *byte, len int, key uint32) uint32
package websocket

import (
        "context"
        "fmt"
        "io"
        "math"
        "net"
        "sync/atomic"
        "time"
)

// NetConn converts a *websocket.Conn into a net.Conn.
//
// It's for tunneling arbitrary protocols over WebSockets.
// Few users of the library will need this but it's tricky to implement
// correctly and so provided in the library.
// See https://github.com/nhooyr/websocket/issues/100.
//
// Every Write to the net.Conn will correspond to a message write of
// the given type on *websocket.Conn.
//
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
// all reads and writes on the net.Conn will be cancelled.
//
// If a message is read that is not of the correct type, the connection
// will be closed with StatusUnsupportedData and an error will be returned.
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return the real addresses for connections obtained
// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
// String(). This is because websocket.Dial only exposes an io.ReadWriteCloser instead of the
// full net.Conn to us.
//
// When running as WASM, the Addr methods will always return the mock address described above.
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// Furthermore, the ReadLimit is set to -1 to disable it.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
        c.SetReadLimit(-1)

        nc := &netConn{
                c:       c,
                msgType: msgType,
                readMu:  newMu(c),
                writeMu: newMu(c),
        }

        nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
        nc.readCtx, nc.readCancel = context.WithCancel(ctx)

        // The deadline timers are created with an effectively infinite duration
        // and disarmed immediately; SetWriteDeadline/SetReadDeadline arm them.
        //
        // Timers returned by time.AfterFunc have a nil C channel, so a failed
        // Stop must never be followed by a receive on C: that would block
        // forever. Stop alone is all that is needed here.
        nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
                if !nc.writeMu.tryLock() {
                        // If the lock cannot be acquired, then there is an
                        // active write goroutine and so we should cancel the context.
                        nc.writeCancel()
                        return
                }
                defer nc.writeMu.unlock()

                // Prevents future writes from writing until the deadline is reset.
                nc.writeExpired.Store(1)
        })
        nc.writeTimer.Stop()

        nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
                if !nc.readMu.tryLock() {
                        // If the lock cannot be acquired, then there is an
                        // active read goroutine and so we should cancel the context.
                        nc.readCancel()
                        return
                }
                defer nc.readMu.unlock()

                // Prevents future reads from reading until the deadline is reset.
                nc.readExpired.Store(1)
        })
        nc.readTimer.Stop()

        return nc
}

// netConn is the net.Conn adapter returned by NetConn. Reads and writes are
// serialized by readMu and writeMu. Deadlines are implemented with timers:
// when one fires it either cancels the in-flight operation's context or marks
// that direction as expired.
type netConn struct {
        c       *Conn
        msgType MessageType

        // Write side: deadline timer, lock, expiry flag (1 = expired) and the
        // context used for every write.
        writeTimer   *time.Timer
        writeMu      *mu
        writeExpired atomic.Int64
        writeCtx     context.Context
        writeCancel  context.CancelFunc

        // Read side, mirroring the write side.
        readTimer   *time.Timer
        readMu      *mu
        readExpired atomic.Int64
        readCtx     context.Context
        readCancel  context.CancelFunc
        // readEOFed is set once a normal/going-away close was reported as io.EOF.
        readEOFed   bool
        // reader is the current message's reader, or nil between messages.
        reader      io.Reader
}

var _ net.Conn = &netConn{}

// Close disarms both deadline timers, cancels the write and read contexts,
// and closes the underlying WebSocket with StatusNormalClosure.
func (nc *netConn) Close() error {
        // Write side first, then read side.
        nc.writeTimer.Stop()
        nc.writeCancel()

        nc.readTimer.Stop()
        nc.readCancel()

        err := nc.c.Close(StatusNormalClosure, "")
        return err
}

// Write sends p as one WebSocket message of type nc.msgType. It returns
// len(p) on success; on failure it returns 0 and the error.
//
// If the write deadline fired while no write was in progress, every Write
// fails with context.DeadlineExceeded until SetWriteDeadline is called again.
func (nc *netConn) Write(p []byte) (int, error) {
        nc.writeMu.forceLock()
        defer nc.writeMu.unlock()

        if expired := nc.writeExpired.Load() == 1; expired {
                return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
        }

        if err := nc.c.Write(nc.writeCtx, nc.msgType, p); err != nil {
                return 0, err
        }
        return len(p), nil
}

// Read reads message data into p. When the current message is exhausted, it
// moves on to the next message, so it never returns (0, nil) for a
// non-empty p.
//
// A zero-length p returns (0, nil) immediately. Without this early return,
// each inner read would also return (0, nil), and the retry loop would spin
// forever while holding readMu.
func (nc *netConn) Read(p []byte) (int, error) {
        if len(p) == 0 {
                return 0, nil
        }

        nc.readMu.forceLock()
        defer nc.readMu.unlock()

        for {
                n, err := nc.read(p)
                if err != nil {
                        return n, err
                }
                if n == 0 {
                        // Message boundary: try again with the next message.
                        continue
                }
                return n, nil
        }
}

// read performs one step of Read and may return (0, nil) at a message boundary.
//
// It fails immediately if the read deadline has expired. After the peer
// closed with StatusNormalClosure or StatusGoingAway, it reports io.EOF on
// every call. When no message is in progress, it opens the next one with
// nc.c.Reader. A message of the wrong type closes the connection with
// StatusUnsupportedData. The end of a message is not surfaced as an error:
// the reader is cleared so the next call opens the following message.
func (nc *netConn) read(p []byte) (int, error) {
        if nc.readExpired.Load() == 1 {
                return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
        }

        if nc.readEOFed {
                return 0, io.EOF
        }

        if nc.reader == nil {
                typ, r, err := nc.c.Reader(nc.readCtx)
                if err != nil {
                        switch CloseStatus(err) {
                        case StatusNormalClosure, StatusGoingAway:
                                // A clean close from the peer looks like EOF to net.Conn users.
                                nc.readEOFed = true
                                return 0, io.EOF
                        }
                        return 0, err
                }
                if typ != nc.msgType {
                        err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
                        nc.c.Close(StatusUnsupportedData, err.Error())
                        return 0, err
                }
                nc.reader = r
        }

        n, err := nc.reader.Read(p)
        if err == io.EOF {
                // End of this message only; the stream continues with the next one.
                nc.reader = nil
                err = nil
        }
        return n, err
}

// websocketAddr is the placeholder net.Addr reported when the underlying
// network address of a connection is not available.
type websocketAddr struct{}

// Network reports the pseudo network name "websocket".
func (websocketAddr) Network() string { return "websocket" }

// String reports the fixed placeholder "websocket/unknown-addr".
func (websocketAddr) String() string { return "websocket/unknown-addr" }

// SetDeadline sets the same deadline for writes and reads. It always returns nil.
func (nc *netConn) SetDeadline(t time.Time) error {
        _ = nc.SetWriteDeadline(t) // always nil
        _ = nc.SetReadDeadline(t)  // always nil
        return nil
}

// SetWriteDeadline clears any previous write expiry and re-arms the write
// timer for t. The zero time disables the deadline. A deadline in the past
// fires after the smallest possible delay. It always returns nil.
func (nc *netConn) SetWriteDeadline(t time.Time) error {
        nc.writeExpired.Store(0)
        if t.IsZero() {
                nc.writeTimer.Stop()
                return nil
        }

        wait := time.Until(t)
        if wait < 1 {
                wait = 1
        }
        nc.writeTimer.Reset(wait)
        return nil
}

// SetReadDeadline clears any previous read expiry and re-arms the read timer
// for t. The zero time disables the deadline. A deadline in the past fires
// after the smallest possible delay. It always returns nil.
func (nc *netConn) SetReadDeadline(t time.Time) error {
        nc.readExpired.Store(0)
        if t.IsZero() {
                nc.readTimer.Stop()
                return nil
        }

        wait := time.Until(t)
        if wait < 1 {
                wait = 1
        }
        nc.readTimer.Reset(wait)
        return nil
}
//go:build !js

package websocket

import "net"

// RemoteAddr returns the peer address of the underlying net.Conn when the
// WebSocket wraps one; otherwise it returns the websocketAddr placeholder.
func (nc *netConn) RemoteAddr() net.Addr {
        unc, ok := nc.c.rwc.(net.Conn)
        if !ok {
                return websocketAddr{}
        }
        return unc.RemoteAddr()
}

// LocalAddr returns the local address of the underlying net.Conn when the
// WebSocket wraps one; otherwise it returns the websocketAddr placeholder.
func (nc *netConn) LocalAddr() net.Addr {
        unc, ok := nc.c.rwc.(net.Conn)
        if !ok {
                return websocketAddr{}
        }
        return unc.LocalAddr()
}
//go:build !js

package websocket

import (
        "bufio"
        "context"
        "errors"
        "fmt"
        "io"
        "net"
        "strings"
        "sync/atomic"
        "time"

        "github.com/coder/websocket/internal/errd"
        "github.com/coder/websocket/internal/util"
)

// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and the Read itself,
// use time.AfterFunc to cancel the context passed in.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
        typ, r, err := c.reader(ctx)
        return typ, r, err
}

// Read is a convenience wrapper around Reader that reads one whole message
// into memory. If reading fails partway through, the bytes read so far are
// returned together with the error.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
        typ, r, err := c.Reader(ctx)
        if err != nil {
                return 0, nil, err
        }
        data, readErr := io.ReadAll(r)
        return typ, data, readErr
}

// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
//
// This function is idempotent: later calls return the context created by the first.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
        c.closeReadMu.Lock()
        ctx2 := c.closeReadCtx
        if ctx2 != nil {
                // Already started; hand back the existing context.
                c.closeReadMu.Unlock()
                return ctx2
        }
        ctx, cancel := context.WithCancel(ctx)
        c.closeReadCtx = ctx
        c.closeReadDone = make(chan struct{})
        c.closeReadMu.Unlock()

        go func() {
                // Deferred calls run in reverse: close the connection, then cancel
                // ctx, then signal closeReadDone.
                defer close(c.closeReadDone)
                defer cancel()
                defer c.close()
                // Reader only returns nil error when a data message arrives.
                _, _, err := c.Reader(ctx)
                if err == nil {
                        c.Close(StatusPolicyViolation, "unexpected data message")
                }
        }()
        return ctx
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable.
func (c *Conn) SetReadLimit(n int64) {
        limit := n
        if limit >= 0 {
                // One extra byte is allowed beyond the requested limit, in case
                // a fin frame still needs to be read.
                limit++
        }
        c.msgReader.limitReader.limit.Store(limit)
}

// defaultReadLimit is the default per-message read limit in bytes (see
// SetReadLimit). The stored limit is this value plus one, matching SetReadLimit.
const defaultReadLimit = 32768

// newMsgReader returns the message reader for c. It starts in the
// "previous message finished" state (fin is true) with the default read
// limit, stored with the same extra byte that SetReadLimit adds.
func newMsgReader(c *Conn) *msgReader {
        mr := &msgReader{c: c, fin: true}
        // Bind the method value once so it can be reused without reallocating.
        mr.readFunc = mr.read
        mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
        return mr
}

// resetFlate prepares mr to decompress a new compressed message.
//
// With context takeover, the sliding window (32768 bytes) is created on
// first use and reinitialized, and its buffer is passed to getFlateReader.
// Without context takeover, nil is passed instead. The limit reader is then
// pointed at the new flate reader, and flateTail is rewound to
// deflateMessageTail.
func (mr *msgReader) resetFlate() {
        takeover := mr.flateContextTakeover()

        if takeover {
                if mr.dict == nil {
                        mr.dict = &slidingWindow{}
                }
                mr.dict.init(32768)
        }
        if mr.flateBufio == nil {
                mr.flateBufio = getBufioReader(mr.readFunc)
        }

        if takeover {
                mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
        } else {
                mr.flateReader = getFlateReader(mr.flateBufio, nil)
        }
        mr.limitReader.r = mr.flateReader
        mr.flateTail.Reset(deflateMessageTail)
}

// putFlateReader hands the current flate reader, if any, back via the
// package-level putFlateReader and clears the field so it is not released twice.
func (mr *msgReader) putFlateReader() {
        if mr.flateReader == nil {
                return
        }
        putFlateReader(mr.flateReader)
        mr.flateReader = nil
}

// close releases the reader's resources: the flate reader, the sliding
// window, the flate bufio reader and, on the client side, the connection's
// bufio reader.
//
// readMu is acquired with forceLock and not released here, so no Read can
// run against the released buffers afterwards.
func (mr *msgReader) close() {
        mr.c.readMu.forceLock()
        mr.putFlateReader()
        if mr.dict != nil {
                mr.dict.close()
                mr.dict = nil
        }
        if mr.flateBufio != nil {
                putBufioReader(mr.flateBufio)
                // Drop the reference, like dict and br, so the pooled reader is
                // not still reachable from here once another owner takes it.
                mr.flateBufio = nil
        }

        if mr.c.client {
                putBufioReader(mr.c.br)
                mr.c.br = nil
        }
}

// flateContextTakeover reports whether decompression state carries over
// between messages received from the peer. A client reads what the server
// compresses, so it checks serverNoContextTakeover; a server checks
// clientNoContextTakeover.
func (mr *msgReader) flateContextTakeover() bool {
        peerNoTakeover := mr.c.copts.clientNoContextTakeover
        if mr.c.client {
                peerNoTakeover = mr.c.copts.serverNoContextTakeover
        }
        return !peerNoTakeover
}

// readRSV1Illegal reports whether RSV1 being set on h is a protocol error.
// RSV1 is only allowed when compression is enabled and the frame is a text
// or binary frame that begins a message.
func (c *Conn) readRSV1Illegal(h header) bool {
        startsDataMessage := h.opcode == opText || h.opcode == opBinary
        return !c.flate() || !startsDataMessage
}

// readLoop reads frame headers until it finds a data frame (text, binary or
// continuation) and returns that header with its payload still unread.
//
// Control frames (close, ping, pong) found along the way are processed via
// handleControl. Three conditions end the loop with an error: disallowed
// reserved bits, an unmasked frame received by a server, or an unknown
// opcode. For reserved bits and unknown opcodes, a StatusProtocolError close
// is also written via writeError.
func (c *Conn) readLoop(ctx context.Context) (header, error) {
        for {
                h, err := c.readFrameHeader(ctx)
                if err != nil {
                        return header{}, err
                }

                // && binds tighter than ||: rsv1 is an error only when it is illegal
                // for this frame, while rsv2 and rsv3 are always errors.
                if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
                        err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
                        c.writeError(StatusProtocolError, err)
                        return header{}, err
                }

                // RFC 6455 requires every client-to-server frame to be masked.
                if !c.client && !h.masked {
                        return header{}, errors.New("received unmasked frame from client")
                }

                switch h.opcode {
                case opClose, opPing, opPong:
                        err = c.handleControl(ctx, h)
                        if err != nil {
                                // Pass through CloseErrors when receiving a close frame.
                                if h.opcode == opClose && CloseStatus(err) != -1 {
                                        return header{}, err
                                }
                                return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
                        }
                case opContinuation, opText, opBinary:
                        return h, nil
                default:
                        err := fmt.Errorf("received unknown opcode %v", h.opcode)
                        c.writeError(StatusProtocolError, err)
                        return header{}, err
                }
        }
}

// prepareRead registers ctx on c.readTimeout for the read that is about to
// happen and returns a done function to call (typically deferred) when that
// read finishes. If the connection is already closed it returns
// net.ErrClosed. If a close frame was already received it runs done itself
// and returns the recorded closeReceivedErr.
//
// err points at the caller's error (usually a named return). done rewrites a
// non-nil *err: to net.ErrClosed if the connection closed meanwhile, and then
// to ctx.Err() if ctx is done, so context errors take precedence.
// NOTE(review): whatever consumes c.readTimeout is not visible here;
// presumably it aborts the blocking read when ctx is done. Sending
// context.Background() in done appears to clear that registration.
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
        select {
        case <-c.closed:
                return nil, net.ErrClosed
        case c.readTimeout <- ctx:
        }

        done := func() {
                select {
                case <-c.closed:
                        if *err != nil {
                                *err = net.ErrClosed
                        }
                case c.readTimeout <- context.Background():
                }
                if *err != nil && ctx.Err() != nil {
                        *err = ctx.Err()
                }
        }

        c.closeStateMu.Lock()
        closeReceivedErr := c.closeReceivedErr
        c.closeStateMu.Unlock()
        if closeReceivedErr != nil {
                defer done()
                return nil, closeReceivedErr
        }

        return done, nil
}

// readFrameHeader reads the next frame header from c.br, bounded by ctx.
//
// err is a named result so that the deferred readDone can rewrite a failure
// into net.ErrClosed or ctx.Err() (see prepareRead).
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
        readDone, err := c.prepareRead(ctx, &err)
        if err != nil {
                return header{}, err
        }
        defer readDone()

        h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
        if err != nil {
                return header{}, err
        }

        return h, nil
}

// readFramePayload reads exactly len(p) payload bytes from c.br into p,
// bounded by ctx. A short read is reported as an error with the byte count.
//
// err is a named result so that the deferred readDone can rewrite a failure
// into net.ErrClosed or ctx.Err() (see prepareRead).
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
        readDone, err := c.prepareRead(ctx, &err)
        if err != nil {
                return 0, err
        }
        defer readDone()

        n, err := io.ReadFull(c.br, p)
        if err != nil {
                return n, fmt.Errorf("failed to read frame payload: %w", err)
        }

        return n, err
}

// handleControl reads and processes the payload of the control frame h.
//
// Control frames must not be fragmented and may carry at most
// maxControlPayload bytes. Violations are answered with a StatusProtocolError
// close. The payload read and any reply are bounded by a 5 second timeout
// derived from ctx.
//
//   - ping: answered with a pong echoing the payload, unless onPingReceived
//     is set and returns false.
//   - pong: passed to onPongReceived (if set), then signalled without
//     blocking on the channel registered in activePings for that payload, if any.
//   - close: the payload is parsed and recorded in closeReceivedErr. A close
//     frame is sent back if none was sent yet, the connection is closed if it
//     was not already closing, and the returned error wraps the parsed close error.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
        if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
                err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
                c.writeError(StatusProtocolError, err)
                return err
        }

        if !h.fin {
                err := errors.New("received fragmented control frame")
                c.writeError(StatusProtocolError, err)
                return err
        }

        ctx, cancel := context.WithTimeout(ctx, time.Second*5)
        defer cancel()

        b := c.readControlBuf[:h.payloadLength]
        _, err = c.readFramePayload(ctx, b)
        if err != nil {
                return err
        }

        if h.masked {
                mask(b, h.maskKey)
        }

        switch h.opcode {
        case opPing:
                if c.onPingReceived != nil {
                        if !c.onPingReceived(ctx, b) {
                                return nil
                        }
                }
                return c.writeControl(ctx, opPong, b)
        case opPong:
                if c.onPongReceived != nil {
                        c.onPongReceived(ctx, b)
                }
                c.activePingsMu.Lock()
                pong, ok := c.activePings[string(b)]
                c.activePingsMu.Unlock()
                if ok {
                        // Non-blocking: a pending signal is enough.
                        select {
                        case pong <- struct{}{}:
                        default:
                        }
                }
                return nil
        }

        // opClose

        ce, err := parseClosePayload(b)
        if err != nil {
                err = fmt.Errorf("received invalid close payload: %w", err)
                c.writeError(StatusProtocolError, err)
                return err
        }

        err = fmt.Errorf("received close frame: %w", ce)
        c.closeStateMu.Lock()
        c.closeReceivedErr = err
        closeSent := c.closeSentErr != nil
        c.closeStateMu.Unlock()

        // Only unlock readMu if this connection is being closed because
        // c.close will try to acquire the readMu lock. We unlock for
        // writeClose as well because it may also call c.close.
        if !closeSent {
                c.readMu.unlock()
                _ = c.writeClose(ce.Code, ce.Reason)
        }
        if !c.casClosing() {
                c.readMu.unlock()
                _ = c.close()
        }
        return err
}

// reader implements Reader. It holds readMu for the duration of the call. It
// refuses to start while the previous message has not been read to
// completion, and skips control frames via readLoop. It then primes
// c.msgReader with the first frame of the new message. A continuation frame
// with no message in progress is a protocol error. errd.Wrap adds the
// "failed to get reader" context to any returned error.
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
        defer errd.Wrap(&err, "failed to get reader")

        err = c.readMu.lock(ctx)
        if err != nil {
                return 0, nil, err
        }
        defer c.readMu.unlock()

        if !c.msgReader.fin {
                return 0, nil, errors.New("previous message not read to completion")
        }

        h, err := c.readLoop(ctx)
        if err != nil {
                return 0, nil, err
        }

        if h.opcode == opContinuation {
                err := errors.New("received continuation frame without text or binary frame")
                c.writeError(StatusProtocolError, err)
                return 0, nil, err
        }

        c.msgReader.reset(ctx, h)

        return MessageType(h.opcode), c.msgReader, nil
}

// msgReader is the io.Reader returned by Reader for a single data message.
// It reads frame payloads, follows continuation frames, decompresses when
// the message has RSV1 set, and enforces the read limit.
type msgReader struct {
        c *Conn

        // Per-message state and the optional decompression pipeline.
        ctx         context.Context
        flate       bool
        flateReader io.Reader
        flateBufio  *bufio.Reader
        flateTail   strings.Reader
        limitReader *limitReader
        dict        *slidingWindow

        // State of the current frame: whether it is final, how many payload
        // bytes remain, and the masking key (rotated as bytes are unmasked).
        fin           bool
        payloadLength int64
        maskKey       uint32

        // util.ReaderFunc(mr.Read) to avoid continuous allocations.
        readFunc util.ReaderFunc
}

// reset prepares mr for a new message whose first frame header is h. The
// limit reader is re-pointed at the raw payload reader first, so that
// resetFlate (for compressed messages) can replace it with the flate reader.
func (mr *msgReader) reset(ctx context.Context, h header) {
        mr.ctx, mr.flate = ctx, h.rsv1
        mr.limitReader.reset(mr.readFunc)
        if h.rsv1 {
                mr.resetFlate()
        }
        mr.setFrame(h)
}

// setFrame loads the per-frame state (fin flag, remaining payload length and
// masking key) from h.
func (mr *msgReader) setFrame(h header) {
        mr.fin, mr.payloadLength, mr.maskKey = h.fin, h.payloadLength, h.maskKey
}

// Read reads message data into p through the limit reader, which also
// decompresses for compressed messages. It holds readMu for the call, bounded
// by the message's ctx.
//
// With compression and context takeover, the bytes returned are also written
// into the sliding-window dictionary for later messages. The end of the
// message is reported as io.EOF and releases the flate reader.
func (mr *msgReader) Read(p []byte) (n int, err error) {
        err = mr.c.readMu.lock(mr.ctx)
        if err != nil {
                return 0, fmt.Errorf("failed to read: %w", err)
        }
        defer mr.c.readMu.unlock()

        n, err = mr.limitReader.Read(p)
        if mr.flate && mr.flateContextTakeover() {
                p = p[:n]
                mr.dict.write(p)
        }
        // Parsed as: io.EOF || (io.ErrUnexpectedEOF && fin && flate).
        // NOTE(review): presumably the decompressor can report
        // ErrUnexpectedEOF at the end of a final compressed frame, and that
        // is treated as a normal end of message here — confirm.
        if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
                mr.putFlateReader()
                return n, io.EOF
        }
        if err != nil {
                return n, fmt.Errorf("failed to read: %w", err)
        }
        return n, nil
}

// read is the raw payload reader behind readFunc. It reads the current
// frame's payload into p, unmasking it on the server side. When a non-final
// frame is exhausted, it advances to the next continuation frame via
// readLoop, which also handles interleaved control frames.
//
// Once the final frame is exhausted, it returns io.EOF. For a compressed
// message it instead serves flateTail (deflateMessageTail); NOTE(review):
// presumably this is the RFC 7692 trailer that lets the decompressor
// terminate the stream.
func (mr *msgReader) read(p []byte) (int, error) {
        for {
                if mr.payloadLength == 0 {
                        if mr.fin {
                                if mr.flate {
                                        return mr.flateTail.Read(p)
                                }
                                return 0, io.EOF
                        }

                        h, err := mr.c.readLoop(mr.ctx)
                        if err != nil {
                                return 0, err
                        }
                        if h.opcode != opContinuation {
                                err := errors.New("received new data message without finishing the previous message")
                                mr.c.writeError(StatusProtocolError, err)
                                return 0, err
                        }
                        mr.setFrame(h)

                        continue
                }

                // Never read past the end of the current frame.
                if int64(len(p)) > mr.payloadLength {
                        p = p[:mr.payloadLength]
                }

                n, err := mr.c.readFramePayload(mr.ctx, p)
                if err != nil {
                        return n, err
                }

                mr.payloadLength -= int64(n)

                // Only client-to-server frames are masked; keep the rotated key so
                // the next chunk of this frame unmasks correctly.
                if !mr.c.client {
                        mr.maskKey = mask(p, mr.maskKey)
                }

                return n, nil
        }
}

// limitReader wraps a reader and caps how many bytes may be read from it.
// When the cap is exhausted it reports StatusMessageTooBig through c.
type limitReader struct {
        // c is used to send the StatusMessageTooBig error to the peer.
        c     *Conn
        // r is the underlying reader; replaced by reset.
        r     io.Reader
        // limit is the configured cap. reset copies it into n; a negative
        // value makes Read pass through without limiting.
        limit atomic.Int64
        // n is the number of bytes still allowed; negative means unlimited.
        n     int64
}

// newLimitReader builds a limitReader over r that allows at most limit bytes
// (a negative limit disables the cap). Violations are reported through c.
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
        reader := new(limitReader)
        reader.c = c
        reader.limit.Store(limit)
        reader.reset(r)
        return reader
}

// reset points lr at a new underlying reader and restores the full byte
// budget from the currently configured limit.
func (lr *limitReader) reset(r io.Reader) {
        budget := lr.limit.Load()
        lr.r = r
        lr.n = budget
}

// Read reads from the underlying reader without exceeding the remaining byte
// budget. With a negative budget it reads directly, unlimited. Once the budget
// reaches zero, it sends StatusMessageTooBig to the peer and returns an error
// instead of reading.
func (lr *limitReader) Read(p []byte) (int, error) {
        switch {
        case lr.n < 0:
                return lr.r.Read(p)
        case lr.n == 0:
                err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
                lr.c.writeError(StatusMessageTooBig, err)
                return 0, err
        }

        if remaining := lr.n; int64(len(p)) > remaining {
                p = p[:remaining]
        }

        got, err := lr.r.Read(p)
        // Clamp at zero in case the underlying reader over-reports.
        lr.n = max(lr.n-int64(got), 0)
        return got, err
}
//go:build !js

package websocket

import (
        "bufio"
        "compress/flate"
        "context"
        "crypto/rand"
        "encoding/binary"
        "errors"
        "fmt"
        "io"
        "net"
        "time"

        "github.com/coder/websocket/internal/errd"
        "github.com/coder/websocket/internal/util"
)

// Writer opens a writer, bound to ctx, for one WebSocket message of type typ.
//
// The writer must be closed once the whole message has been written.
//
// Only one message writer can be open at a time; further calls block until
// the current writer is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
        wc, err := c.writer(ctx, typ)
        if err == nil {
                return wc, nil
        }
        return nil, fmt.Errorf("failed to get writer: %w", err)
}

// Write sends p to the peer as one complete message of type typ.
//
// Use the Writer method instead to stream a message in pieces.
//
// When compression is disabled or the compression threshold is not met, the
// message is written in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
        if _, err := c.write(ctx, typ, p); err != nil {
                return fmt.Errorf("failed to write msg: %w", err)
        }
        return nil
}

// msgWriter writes a single WebSocket message as one or more frames,
// optionally compressing it with flate. A Conn reuses one msgWriter for all
// of its outgoing messages; reset prepares it for the next message.
type msgWriter struct {
        c *Conn

        // mu is held for the duration of a message: it is acquired in reset
        // and released by Close (or directly by Conn.write on its
        // single-frame path).
        mu      *mu
        // writeMu serializes Write and Close calls on this writer.
        writeMu *mu
        // closed reports whether Close has been called for the current message.
        closed  bool

        // ctx is the context of the current message, set by reset.
        ctx    context.Context
        // opcode is used for the next frame; write switches it to
        // opContinuation after the first frame is sent.
        opcode opcode
        // flate reports whether the current message is being compressed.
        flate  bool

        // trimWriter and flateWriter are created lazily by ensureFlate.
        trimWriter  *trimLastFourBytesWriter
        flateWriter *flate.Writer
}

// newMsgWriter returns a message writer for c with fresh message and write
// locks. Flate state is created on demand by ensureFlate.
func newMsgWriter(c *Conn) *msgWriter {
        return &msgWriter{
                c:       c,
                mu:      newMu(c),
                writeMu: newMu(c),
        }
}

// ensureFlate turns on compression for the current message, creating the
// trim writer and flate writer the first time they are needed.
//
// Compressed output flows flateWriter -> trimWriter -> mw.write, so the
// deflated bytes are sent as message frames. NOTE(review): trimWriter
// presumably strips the 4-byte sync-flush tail as RFC 7692 requires — confirm.
func (mw *msgWriter) ensureFlate() {
        if mw.trimWriter == nil {
                frameSink := util.WriterFunc(mw.write)
                mw.trimWriter = &trimLastFourBytesWriter{w: frameSink}
        }

        if mw.flateWriter == nil {
                mw.flateWriter = getFlateWriter(mw.trimWriter)
        }

        mw.flate = true
}

// flateContextTakeover reports whether this endpoint may keep its compression
// context across messages, based on the negotiated no_context_takeover option
// for its own role (client or server).
func (mw *msgWriter) flateContextTakeover() bool {
        noTakeover := mw.c.copts.serverNoContextTakeover
        if mw.c.client {
                noTakeover = mw.c.copts.clientNoContextTakeover
        }
        return !noTakeover
}

// writer prepares the connection's shared msgWriter for a new message of type
// typ. This acquires the message lock, and returns the msgWriter.
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
        if err := c.msgWriter.reset(ctx, typ); err != nil {
                return nil, err
        }
        return c.msgWriter, nil
}

// write sends p as a complete message of type typ and returns the number of
// payload bytes written.
//
// A message that will not be compressed is sent as a single fin frame. That
// covers two cases: compression is disabled for the connection, or p is
// shorter than the compression threshold. This matches what Conn.Write
// documents. Larger messages on a compressing connection go through the
// msgWriter so they are deflated.
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
        mw, err := c.writer(ctx, typ)
        if err != nil {
                return 0, err
        }

        // msgWriter.Write enables compression only when the first frame reaches
        // flateThreshold. Below the threshold the message would go out
        // uncompressed anyway, so send it in one frame rather than a data frame
        // followed by an empty fin frame.
        if !c.flate() || len(p) < c.flateThreshold {
                // Release the message lock acquired by c.writer.
                defer c.msgWriter.mu.unlock()
                return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
        }

        n, err := mw.Write(p)
        if err != nil {
                return n, err
        }

        err = mw.Close()
        return n, err
}

// reset acquires the message lock (waiting until the previous message is
// finished or ctx ends) and prepares mw to write a new message of type typ.
// Compression starts disabled and is decided by the first Write.
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
        if err := mw.mu.lock(ctx); err != nil {
                return err
        }

        mw.ctx, mw.opcode = ctx, opcode(typ)
        mw.flate, mw.closed = false, false

        // NOTE(review): trimWriter is still nil if flate was never enabled;
        // this relies on trimLastFourBytesWriter.reset tolerating a nil
        // receiver — confirm.
        mw.trimWriter.reset()

        return nil
}

// putFlateWriter hands the flate writer, if any, back via putFlateWriter and
// clears mw's reference so ensureFlate will obtain a new one next time.
func (mw *msgWriter) putFlateWriter() {
        if mw.flateWriter == nil {
                return
        }
        putFlateWriter(mw.flateWriter)
        mw.flateWriter = nil
}

// Write writes p as part of the current message.
//
// Compression can only be switched on by the first frame of a message: if
// the connection supports flate and that first write reaches the
// connection's flate threshold, the whole message is compressed. Writing
// after Close returns an error.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
        if err = mw.writeMu.lock(mw.ctx); err != nil {
                return 0, fmt.Errorf("failed to write: %w", err)
        }
        defer mw.writeMu.unlock()

        if mw.closed {
                return 0, errors.New("cannot use closed writer")
        }

        defer func() {
                if err != nil {
                        err = fmt.Errorf("failed to write: %w", err)
                }
        }()

        // Continuation frames never enable compression mid-message.
        if mw.c.flate() && mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
                mw.ensureFlate()
        }

        if !mw.flate {
                return mw.write(p)
        }
        return mw.flateWriter.Write(p)
}

// write sends p as a non-final frame of the current message. After the first
// successful frame, subsequent frames use opContinuation.
func (mw *msgWriter) write(p []byte) (int, error) {
        written, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
        if err != nil {
                return written, fmt.Errorf("failed to write data frame: %w", err)
        }

        mw.opcode = opContinuation
        return written, nil
}

// Close finishes the current message.
//
// It flushes any pending compressed data, writes the final (fin) frame, and
// releases the message lock so the next writer can start. When compression
// context takeover is not allowed, it also gives up the flate writer. Closing
// twice returns an error. If the flush or the fin frame fails, the message
// lock is not released.
func (mw *msgWriter) Close() (err error) {
        defer errd.Wrap(&err, "failed to close writer")

        if err = mw.writeMu.lock(mw.ctx); err != nil {
                return err
        }
        defer mw.writeMu.unlock()

        if mw.closed {
                return errors.New("writer already closed")
        }
        mw.closed = true

        if mw.flate {
                if err = mw.flateWriter.Flush(); err != nil {
                        return fmt.Errorf("failed to flush flate: %w", err)
                }
        }

        if _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil); err != nil {
                return fmt.Errorf("failed to write fin frame: %w", err)
        }

        if mw.flate && !mw.flateContextTakeover() {
                mw.putFlateWriter()
        }
        mw.mu.unlock()
        return nil
}

// close tears down the writing side of the connection.
//
// On clients it first force-acquires c.writeFrameMu and hands c.bw to
// putBufioWriter. NOTE(review): presumably only client connections own a
// pooled bufio.Writer, and the lock is never released so no frame can touch
// c.bw afterwards — confirm. It then force-acquires writeMu, so later
// Write/Close calls on this writer block, and releases any flate writer.
func (mw *msgWriter) close() {
        if mw.c.client {
                mw.c.writeFrameMu.forceLock()
                putBufioWriter(mw.c.bw)
        }

        mw.writeMu.forceLock()
        mw.putFlateWriter()
}

// writeControl writes a single control frame (fin set, uncompressed) with
// payload p. The write is bounded to at most five seconds even if ctx has a
// later deadline or none at all.
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
        boundedCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
        defer cancel()

        if _, err := c.writeFrame(boundedCtx, true, false, opcode, p); err != nil {
                return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
        }
        return nil
}

// writeFrame handles all writes to the connection.
//
// It writes one frame with the given fin bit, opcode and payload p while
// holding c.writeFrameMu, and returns the number of payload bytes written.
// flate sets the RSV1 bit, but only on text and binary frames. Client
// connections mask every frame with a freshly generated random key. The
// buffered writer is flushed only when fin is set.
//
// Once a close frame has been sent, every later call fails with
// net.ErrClosed. Sending a close frame after a close frame has already been
// received triggers c.close().
//
// After the frame lock is acquired, errors are rewritten to ctx.Err() if ctx
// is done, or to net.ErrClosed if the connection is closed, and then wrapped
// with "failed to write frame". A close frame write returns no error when the
// connection is closed by the time writeFrame returns.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
        err = c.writeFrameMu.lock(ctx)
        if err != nil {
                return 0, err
        }
        defer c.writeFrameMu.unlock()

        // Normalize the returned error (see the function comment).
        defer func() {
                if c.isClosed() && opcode == opClose {
                        err = nil
                }
                if err != nil {
                        if ctx.Err() != nil {
                                err = ctx.Err()
                        } else if c.isClosed() {
                                err = net.ErrClosed
                        }
                        err = fmt.Errorf("failed to write frame: %w", err)
                }
        }()

        // No frame may follow a close frame we have already sent.
        c.closeStateMu.Lock()
        closeSentErr := c.closeSentErr
        c.closeStateMu.Unlock()
        if closeSentErr != nil {
                return 0, net.ErrClosed
        }

        // Hand ctx to whoever receives on c.writeTimeout for the duration of
        // the write, and hand back context.Background() afterwards.
        // NOTE(review): presumably a goroutine that enforces ctx's deadline
        // on the underlying connection — confirm.
        select {
        case <-c.closed:
                return 0, net.ErrClosed
        case c.writeTimeout <- ctx:
        }
        defer func() {
                select {
                case <-c.closed:
                case c.writeTimeout <- context.Background():
                }
        }()

        c.writeHeader.fin = fin
        c.writeHeader.opcode = opcode
        c.writeHeader.payloadLength = int64(len(p))

        if c.client {
                // Client frames must be masked; draw a new random key per frame.
                c.writeHeader.masked = true
                _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
                if err != nil {
                        return 0, fmt.Errorf("failed to generate masking key: %w", err)
                }
                c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
        }

        // RSV1 marks a compressed message and only applies to data frames.
        c.writeHeader.rsv1 = false
        if flate && (opcode == opText || opcode == opBinary) {
                c.writeHeader.rsv1 = true
        }

        err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
        if err != nil {
                return 0, err
        }

        n, err := c.writeFramePayload(p)
        if err != nil {
                return n, err
        }

        if c.writeHeader.fin {
                err = c.bw.Flush()
                if err != nil {
                        return n, fmt.Errorf("failed to flush: %w", err)
                }
        }

        if opcode == opClose {
                c.closeStateMu.Lock()
                c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
                closeReceived := c.closeReceivedErr != nil
                c.closeStateMu.Unlock()

                // Both sides have now sent close frames. If this call wins
                // casClosing, release the frame lock early and close the
                // connection. Presumably the early release is needed because
                // c.close() may need writeFrameMu (msgWriter.close force-locks
                // it). NOTE(review): the deferred writeFrameMu.unlock() then
                // runs a second time; this relies on mu.unlock tolerating
                // that — confirm.
                if closeReceived && !c.casClosing() {
                        c.writeFrameMu.unlock()
                        _ = c.close()
                }
        }

        return n, nil
}

// writeFramePayload writes p to c.bw as the payload of the frame described by
// c.writeHeader and returns the number of payload bytes consumed.
//
// Unmasked payloads are written straight through. Masked payloads are copied
// into the bufio.Writer's buffer in chunks that fit its free space, then
// masked in place via c.writeBuf, so p itself is never modified.
// NOTE(review): c.writeBuf is presumably the bufio.Writer's backing array as
// obtained by extractBufioWriterBuf — confirm.
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
        defer errd.Wrap(&err, "failed to write frame payload")

        if !c.writeHeader.masked {
                return c.bw.Write(p)
        }

        maskKey := c.writeHeader.maskKey
        for len(p) > 0 {
                // Make room when the buffer is full. Every chunk must then be
                // copied into the buffer (never written around it), so it can be
                // masked there.
                if c.bw.Available() == 0 {
                        if err = c.bw.Flush(); err != nil {
                                return n, err
                        }
                }

                // Offset in the buffer where this chunk will start.
                start := c.bw.Buffered()
                chunk := min(len(p), c.bw.Available())

                // Assign to the named result rather than shadowing it.
                if _, err = c.bw.Write(p[:chunk]); err != nil {
                        return n, err
                }

                maskKey = mask(c.writeBuf[start:c.bw.Buffered()], maskKey)

                p = p[chunk:]
                n += chunk
        }

        return n, nil
}

// extractBufioWriterBuf returns the []byte backing bw's internal buffer, at
// full capacity, and leaves bw reset to write to w.
//
// It relies on bufio.Writer.Flush passing a slice of its internal buffer to
// the underlying writer, which is an implementation detail of bufio.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
        var backing []byte
        capture := util.WriterFunc(func(chunk []byte) (int, error) {
                backing = chunk[:cap(chunk)]
                return len(chunk), nil
        })

        bw.Reset(capture)

        // Force one flush so capture sees the buffer. capture never fails,
        // so these errors are safe to ignore.
        _ = bw.WriteByte(0)
        _ = bw.Flush()

        bw.Reset(w)

        return backing
}

// writeError starts the close handshake with the given status code, using
// err's message as the close reason.
func (c *Conn) writeError(code StatusCode, err error) {
        reason := err.Error()
        c.writeClose(code, reason)
}
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "github.com/coder/websocket/wsjson"

import (
        "context"
        "encoding/json"
        "fmt"

        "github.com/coder/websocket"
        "github.com/coder/websocket/internal/bpool"
        "github.com/coder/websocket/internal/errd"
        "github.com/coder/websocket/internal/util"
)

// Read reads the next message from c and decodes it as JSON into v.
// Buffers are reused between calls to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v any) error {
        err := read(ctx, c, v)
        return err
}

// read buffers one whole message from c in a pooled buffer and unmarshals
// it into v. If the payload cannot be unmarshaled, the connection is closed
// with StatusInvalidFramePayloadData.
func read(ctx context.Context, c *websocket.Conn, v any) (err error) {
        defer errd.Wrap(&err, "failed to read JSON message")

        _, r, err := c.Reader(ctx)
        if err != nil {
                return err
        }

        buf := bpool.Get()
        defer bpool.Put(buf)

        if _, err = buf.ReadFrom(r); err != nil {
                return err
        }

        if err = json.Unmarshal(buf.Bytes(), v); err != nil {
                // Best-effort close: report the decode failure rather than any
                // error from Close itself.
                _ = c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
                return fmt.Errorf("failed to unmarshal JSON: %w", err)
        }
        return nil
}

// Write encodes v as JSON and sends it to c as a single text message.
// Buffers are reused between calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v any) error {
        err := write(ctx, c, v)
        return err
}

// write encodes v as JSON and sends the result to c as a text message.
//
// A failure to send is reported as a write error, not as a marshaling
// error. The output of json.Encoder includes a trailing newline, which is
// sent as part of the message.
func write(ctx context.Context, c *websocket.Conn, v any) (err error) {
        defer errd.Wrap(&err, "failed to write JSON message")

        // json.Marshal cannot reuse buffers between calls as it has to return
        // a copy of the byte slice but Encoder does as it directly writes to w.
        // NOTE(review): this relies on Encoder.Encode issuing a single Write of
        // the whole encoding, so each call is exactly one message — true of the
        // current encoding/json implementation.
        var sendErr error
        err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) {
                sendErr = c.Write(ctx, websocket.MessageText, p)
                if sendErr != nil {
                        return 0, sendErr
                }
                return len(p), nil
        })).Encode(v)
        if sendErr != nil {
                // v encoded fine; the connection rejected the write.
                return sendErr
        }
        if err != nil {
                return fmt.Errorf("failed to marshal JSON: %w", err)
        }
        return nil
}

RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4