package clientip
import (
"context"
"errors"
)
// ResultKind is a coarse-grained classification for extraction and resolution
// results.
//
// ClassifyError returns ResultSuccess for nil and ResultUnknown for non-nil
// errors outside the package's standard extraction and resolution surface.
// String returns a stable lowercase label for each kind.
//
// The constants are assigned with iota, so their numeric values follow
// declaration order; reordering them changes those values. Add new kinds at
// the end.
type ResultKind uint8
const (
// ResultUnknown indicates a non-nil error outside the package's standard
// extraction and resolution categories. It is the zero value of ResultKind.
ResultUnknown ResultKind = iota
// ResultSuccess indicates the operation completed without error.
ResultSuccess
// ResultUnavailable indicates the selected source was not present.
ResultUnavailable
// ResultInvalid indicates invalid request input or an invalid client IP.
ResultInvalid
// ResultUntrusted indicates the request failed trusted-proxy validation.
ResultUntrusted
// ResultMalformed indicates malformed or conflicting proxy-header input.
ResultMalformed
// ResultCanceled indicates context cancellation or deadline expiry.
ResultCanceled
// ResultFallback indicates operational resolution succeeded via fallback.
// ClassifyError never produces this kind; it describes a resolver outcome
// rather than an error category.
ResultFallback
)
// ClassifyError maps the package's detailed error surface onto the smaller set
// of policy-oriented ResultKind values.
//
// Categories are tested in a fixed precedence order: cancellation, then
// unavailability, trust failures, malformed input, and finally invalid input.
// An error that wraps sentinels from several categories therefore reports the
// earliest matching category. nil maps to ResultSuccess, and anything
// unrecognized maps to ResultUnknown.
//
// This helper is additive: typed errors and errors.Is / errors.As remain the
// detailed interface when callers need source-specific diagnostics.
func ClassifyError(err error) ResultKind {
	if err == nil {
		return ResultSuccess
	}

	// matches reports whether err wraps any of targets, checking them in order
	// and stopping at the first hit.
	matches := func(targets ...error) bool {
		for _, target := range targets {
			if errors.Is(err, target) {
				return true
			}
		}
		return false
	}

	switch {
	case matches(context.Canceled, context.DeadlineExceeded):
		return ResultCanceled
	case matches(ErrSourceUnavailable):
		return ResultUnavailable
	case matches(ErrUntrustedProxy, ErrNoTrustedProxies, ErrTooFewTrustedProxies, ErrTooManyTrustedProxies):
		return ResultUntrusted
	case matches(ErrInvalidForwardedHeader, ErrChainTooLong, ErrMultipleSingleIPHeaders):
		return ResultMalformed
	case matches(ErrInvalidIP, ErrNilRequest):
		return ResultInvalid
	}
	return ResultUnknown
}
// String returns the stable label for k. Values outside the declared
// ResultKind constants, like ResultUnknown itself, render as "unknown".
func (k ResultKind) String() string {
	labels := [...]string{
		ResultUnknown:     "unknown",
		ResultSuccess:     "success",
		ResultUnavailable: "unavailable",
		ResultInvalid:     "invalid",
		ResultUntrusted:   "untrusted",
		ResultMalformed:   "malformed",
		ResultCanceled:    "canceled",
		ResultFallback:    "fallback",
	}
	if int(k) < len(labels) {
		return labels[k]
	}
	return "unknown"
}
package clientip
import (
"fmt"
"net/netip"
"reflect"
"slices"
)
// Option configures a Resolver.
//
// The interface method is unexported, so only this package can supply Option
// implementations (through the With... constructors).
type Option interface {
	applyOption(*options)
}

// optionFunc adapts a plain function to the Option interface.
type optionFunc func(*options)

// applyOption implements Option by invoking fn on o.
func (fn optionFunc) applyOption(o *options) {
	fn(o)
}

// DefaultMaxChainLength is the maximum number of IPs accepted in a proxy chain
// when no explicit limit is configured.
//
// The cap guards against DoS attempts that send extremely long header values
// to force excessive memory allocation or parsing work. 100 leaves ample room
// for complex multi-region, multi-CDN topologies while still bounding cost;
// typical proxy chains rarely exceed 5-10 entries.
const DefaultMaxChainLength = 100
// ChainSelection controls how the client candidate is picked from a parsed
// Forwarded or X-Forwarded-For proxy chain after trusted-proxy validation.
// RightmostUntrustedIP is the default.
type ChainSelection int

const (
	// RightmostUntrustedIP selects the rightmost untrusted address in front of
	// the trailing trusted-proxy suffix. It is the default and the recommended
	// mode for most deployments. Numbering starts at 1 so the zero value is
	// never a valid selection and unset values are easy to spot.
	RightmostUntrustedIP ChainSelection = iota + 1
	// LeftmostUntrustedIP selects the leftmost untrusted address in front of
	// the trailing trusted-proxy suffix. Use it only when trusted proxies are
	// configured and those proxies produce or sanitize the forwarded chain.
	LeftmostUntrustedIP
)

// String returns the canonical text form of s, or "unknown" for any value
// other than the declared selection modes.
func (s ChainSelection) String() string {
	if s == RightmostUntrustedIP {
		return "rightmost_untrusted"
	}
	if s == LeftmostUntrustedIP {
		return "leftmost_untrusted"
	}
	return "unknown"
}

// valid reports whether s names one of the supported selection modes. The
// zero value is deliberately not valid.
func (s ChainSelection) valid() bool {
	switch s {
	case RightmostUntrustedIP, LeftmostUntrustedIP:
		return true
	default:
		return false
	}
}
// options configures an extractor.
//
// Start from defaultOptions or one of the Preset... helpers unless you need a
// custom proxy topology. New normalizes prefixes and sources before
// validation.
//
// configFromPublic overlays these fields onto the runtime defaults. Nil
// slices, nil Logger/Observer, and zero MaxChainLength/ChainSelection keep the
// default. MinTrustedProxies, MaxTrustedProxies, AllowPrivateIPs, and
// DebugInfo are copied as-is.
type options struct {
// TrustedProxyPrefixes contains upstream proxy ranges that are allowed to
// supply header-based client IPs. Any header source in Sources requires
// this field to be non-empty, and the immediate RemoteAddr must match one
// of these prefixes when a header is present. Prefixes are masked and
// de-duplicated during configuration.
TrustedProxyPrefixes []netip.Prefix
// MinTrustedProxies rejects parsed proxy chains with fewer than this many
// trusted proxies. A value of 0 means no minimum. Negative values are
// rejected, and a positive value requires TrustedProxyPrefixes.
MinTrustedProxies int
// MaxTrustedProxies rejects parsed proxy chains with more than this many
// trusted proxies. A value of 0 means no maximum. Negative values are
// rejected, and a positive value must not be below MinTrustedProxies.
MaxTrustedProxies int
// AllowPrivateIPs allows RFC1918 and unique-local client addresses.
// Loopback, link-local, multicast, and unspecified addresses are still
// rejected.
AllowPrivateIPs bool
// AllowedReservedClientPrefixes allows selected reserved or special-use
// client ranges that are otherwise rejected, such as documentation
// prefixes in tests.
AllowedReservedClientPrefixes []netip.Prefix
// MaxChainLength limits the number of IPs accepted in Forwarded and
// X-Forwarded-For chains. A value of 0 uses DefaultMaxChainLength; negative
// values are rejected.
MaxChainLength int
// ChainSelection selects the client candidate from Forwarded and
// X-Forwarded-For chains. Leave zero for the default RightmostUntrustedIP.
// LeftmostUntrustedIP requires TrustedProxyPrefixes when a chain source is
// configured.
ChainSelection ChainSelection
// DebugInfo includes parsed chain details in successful chain-source
// extractions. It is intended for diagnostics rather than hot-path
// logging.
DebugInfo bool
// Sources is the strict extraction order. A nil slice uses the default
// RemoteAddr-only source; an explicit empty slice is invalid.
Sources []Source
// Logger receives security-significant extractor events. Nil disables
// logging. Typed nil implementations are rejected during validation.
Logger Logger
// Observer receives one event per resolver call on a valid Resolver. Nil
// disables observation. Typed nil implementations are rejected during
// validation.
Observer Observer
}
// WithTrustedProxies declares the upstream proxy ranges allowed to supply
// header-based client IPs.
//
// A header source is consulted only when the immediate RemoteAddr peer falls
// inside one of these prefixes. Chain sources such as SourceForwarded and
// SourceXForwardedFor also use them to recognize the trailing run of trusted
// hops before the client candidate is selected. List only ranges that can
// actually connect to this service. The slice is copied when the option is
// applied.
func WithTrustedProxies(prefixes ...netip.Prefix) Option {
	return optionFunc(func(o *options) {
		o.TrustedProxyPrefixes = clonePrefixes(prefixes)
	})
}

// WithMinTrustedProxies rejects parsed chains that contain fewer than n
// trusted proxies.
//
// The count is of CIDR-trusted hops found in the chain. This option does not
// make a header source trustworthy on its own: header sources still need
// WithTrustedProxies and a trusted immediate peer. A positive n requires
// trusted proxy prefixes, and a negative n is rejected during validation.
func WithMinTrustedProxies(n int) Option {
	return optionFunc(func(o *options) {
		o.MinTrustedProxies = n
	})
}

// WithMaxTrustedProxies rejects parsed chains that contain more than n trusted
// proxies. Zero means no maximum.
//
// The count is of CIDR-trusted hops found in the chain. This option does not
// make a header source trustworthy on its own: header sources still need
// WithTrustedProxies and a trusted immediate peer. A negative n is rejected
// during validation.
func WithMaxTrustedProxies(n int) Option {
	return optionFunc(func(o *options) {
		o.MaxTrustedProxies = n
	})
}

// WithAllowPrivateIPs accepts RFC1918 and unique-local client addresses.
//
// Enable it only when private clients are legitimate users, for example
// internal services or private-network applications. Loopback, link-local,
// multicast, unspecified, and disallowed reserved addresses are still
// rejected.
func WithAllowPrivateIPs() Option {
	return optionFunc(func(o *options) {
		o.AllowPrivateIPs = true
	})
}

// WithAllowedReservedClientPrefixes permits selected special-use client
// ranges.
//
// It is mainly useful for tests, documentation examples, and deployments that
// deliberately use a specific special-use range. It does not relax address
// validation in general: only client IPs inside the supplied prefixes bypass
// reserved-range rejection. The slice is copied when the option is applied.
func WithAllowedReservedClientPrefixes(prefixes ...netip.Prefix) Option {
	return optionFunc(func(o *options) {
		o.AllowedReservedClientPrefixes = clonePrefixes(prefixes)
	})
}

// WithMaxChainLength caps the length of Forwarded and X-Forwarded-For chains.
//
// Zero keeps DefaultMaxChainLength. Negative values are rejected during
// configuration validation.
func WithMaxChainLength(n int) Option {
	return optionFunc(func(o *options) {
		o.MaxChainLength = n
	})
}

// WithChainSelection sets the chain client-candidate selection algorithm.
//
// RightmostUntrustedIP, the default, picks the nearest untrusted hop in front
// of the trailing trusted-proxy suffix. LeftmostUntrustedIP picks the earliest
// untrusted hop and should be used only when trusted proxies fully produce or
// sanitize the forwarded chain. Zero keeps the default; any other undeclared
// value is rejected during validation.
func WithChainSelection(selection ChainSelection) Option {
	return optionFunc(func(o *options) {
		o.ChainSelection = selection
	})
}

// WithDebugInfo attaches parsed chain diagnostics to successful chain results.
//
// Debug information is meant for diagnostics and tests. Prefer Logger or
// Observer for routine operational visibility.
func WithDebugInfo() Option {
	return optionFunc(func(o *options) {
		o.DebugInfo = true
	})
}

// WithSources sets the strict extraction source order.
//
// Sources are tried in order. ErrSourceUnavailable lets the next source run,
// while malformed headers, proxy-trust failures, chain limits, invalid client
// IPs, and context errors are terminal. Header-based sources require
// WithTrustedProxies. SourceStaticFallback is result-only and is rejected
// here. The slice is copied when the option is applied.
func WithSources(sources ...Source) Option {
	return optionFunc(func(o *options) {
		o.Sources = cloneSources(sources)
	})
}

// WithLogger installs a receiver for security-significant warning logs.
//
// Implementations should be safe for concurrent use. The logger receives
// extractor security events such as malformed headers, untrusted proxies, and
// rejected client IPs; operational fallback itself is reported on Result. A
// nil logger leaves logging disabled, while a typed nil is rejected during
// validation.
func WithLogger(logger Logger) Option {
	return optionFunc(func(o *options) {
		o.Logger = logger
	})
}

// WithObserver installs result-level observation.
//
// Implementations should be safe for concurrent use. The observer sees strict
// successes, strict failures, and successful operational fallbacks after each
// resolver call. A nil observer leaves observation disabled, while a typed nil
// is rejected during validation.
func WithObserver(observer Observer) Option {
	return optionFunc(func(o *options) {
		o.Observer = observer
	})
}
// defaultOptions returns the baseline public configuration.
//
// The baseline is safe for direct client-to-app traffic: RemoteAddr as the
// only source, RightmostUntrustedIP chain selection, DefaultMaxChainLength,
// no trusted proxy prefixes, and nil Logger/Observer. The config layer
// substitutes no-op implementations for nil Logger/Observer.
func defaultOptions() options {
	var o options
	o.MaxChainLength = DefaultMaxChainLength
	o.ChainSelection = RightmostUntrustedIP
	o.Sources = []Source{builtinSource(sourceRemoteAddr)}
	return o
}
// LoopbackProxyPrefixes returns the IPv4 and IPv6 loopback CIDRs, suitable
// when the app sits behind a reverse proxy on the same host. Each call returns
// a fresh slice the caller may modify.
func LoopbackProxyPrefixes() []netip.Prefix {
	return clonePrefixes(loopbackProxyCIDRs)
}

// PrivateProxyPrefixes returns private-network CIDRs commonly used for trusted
// upstream proxies in VM and internal-network deployments. Each call returns a
// fresh slice the caller may modify.
func PrivateProxyPrefixes() []netip.Prefix {
	return clonePrefixes(privateProxyCIDRs)
}

// LocalProxyPrefixes returns the loopback CIDRs followed by the
// private-network CIDRs, with duplicates removed. The result is a freshly
// allocated slice; neither built-in list is aliased.
func LocalProxyPrefixes() []netip.Prefix {
	// mergeUniquePrefixes only reads its inputs and always builds a new slice,
	// so no defensive copy is needed first.
	return mergeUniquePrefixes(loopbackProxyCIDRs, privateProxyCIDRs...)
}

// ProxyPrefixesFromAddrs turns individual proxy addresses into host-sized
// trusted prefixes.
//
// IPv4 addresses become /32 prefixes and IPv6 addresses become /128 prefixes.
// IPv4-mapped IPv6 addresses are normalized to IPv4 first. Any invalid
// (zero-value) address aborts the conversion with an error.
func ProxyPrefixesFromAddrs(addrs ...netip.Addr) ([]netip.Prefix, error) {
	out := make([]netip.Prefix, len(addrs))
	for i, a := range addrs {
		if !a.IsValid() {
			return nil, fmt.Errorf("invalid proxy address %q", a)
		}
		host := normalizeIP(a)
		out[i] = netip.PrefixFrom(host, host.BitLen())
	}
	return out, nil
}
// config holds normalized runtime configuration state.
//
// It is built by configFromPublic (defaults, then public overrides, then
// derived policies) and checked by validate before an extractor uses it.
type config struct {
// trustedProxyCIDRs holds masked, de-duplicated trusted proxy prefixes.
trustedProxyCIDRs []netip.Prefix
// trustedProxyMatch is built from trustedProxyCIDRs by newPrefixMatcher.
trustedProxyMatch prefixMatcher
minTrustedProxies int
maxTrustedProxies int
allowPrivateIPs bool
// allowReservedClientPrefixes holds masked, de-duplicated special-use
// client ranges that bypass reserved-range rejection.
allowReservedClientPrefixes []netip.Prefix
maxChainLength int
chainSelection ChainSelection
// debugMode mirrors options.DebugInfo.
debugMode bool
// sourcePriority is the canonicalized source order.
sourcePriority []Source
// sourceHeaderKeys is derived from sourcePriority. The extractor takes a
// RemoteAddr-only fast path when it is empty.
sourceHeaderKeys []string
// clientIP and proxy are derived from the fields above and populated by
// configFromPublic after all other normalization is complete. They are
// kept here so source extractors can capture stable handles without
// round-tripping through config on every hot-path call.
clientIP clientIPPolicy
proxy proxyPolicy
// logger defaults to noopLogger{}; loggerNoop stays true until a
// caller-supplied Logger replaces it.
logger Logger
loggerNoop bool
// observer defaults to noopObserver{}.
observer Observer
}
// validate checks normalized runtime configuration after defaults and public
// option overrides have been applied.
//
// Checks run in a fixed order and the first failure is returned. Numeric
// bounds and the chain selection come first, then the source list, then the
// trust requirements implied by the configured sources, and finally typed-nil
// Logger/Observer detection.
func (c *config) validate() error {
	switch {
	case c.minTrustedProxies < 0:
		return fmt.Errorf("minTrustedProxies must be >= 0, got %d", c.minTrustedProxies)
	case c.maxTrustedProxies < 0:
		return fmt.Errorf("maxTrustedProxies must be >= 0, got %d", c.maxTrustedProxies)
	case c.maxTrustedProxies > 0 && c.minTrustedProxies > c.maxTrustedProxies:
		// A zero maximum means "unbounded", so only a positive maximum can
		// conflict with the minimum.
		return fmt.Errorf("minTrustedProxies (%d) cannot exceed maxTrustedProxies (%d)", c.minTrustedProxies, c.maxTrustedProxies)
	case c.minTrustedProxies > 0 && len(c.trustedProxyCIDRs) == 0:
		return fmt.Errorf("minTrustedProxies > 0 requires TrustedProxyPrefixes to be configured for security validation; to skip validation and trust all proxies, set TrustedProxyPrefixes to 0.0.0.0/0 and ::/0")
	case c.maxChainLength <= 0:
		return fmt.Errorf("maxChainLength must be > 0, got %d", c.maxChainLength)
	case !c.chainSelection.valid():
		return fmt.Errorf("invalid chain selection %d (must be RightmostUntrustedIP=1 or LeftmostUntrustedIP=2)", c.chainSelection)
	case len(c.sourcePriority) == 0:
		return fmt.Errorf("at least one source required in priority list")
	}

	hasHeaderSource, hasChainSource, err := c.validateSourcePriority()
	if err != nil {
		return err
	}

	noTrustedProxies := len(c.trustedProxyCIDRs) == 0
	switch {
	case hasChainSource && c.chainSelection == LeftmostUntrustedIP && noTrustedProxies:
		return fmt.Errorf("LeftmostUntrustedIP selection requires trusted proxy prefixes to be configured; without trusted-proxy validation, this selection provides no security benefit over RightmostUntrustedIP")
	case hasHeaderSource && noTrustedProxies:
		return fmt.Errorf("header-based sources require trusted proxy prefixes; configure TrustedProxyPrefixes directly or use LoopbackProxyPrefixes, PrivateProxyPrefixes, LocalProxyPrefixes, or ProxyPrefixesFromAddrs")
	case isNilValue(c.logger):
		return fmt.Errorf("logger cannot be nil")
	case isNilValue(c.observer):
		return fmt.Errorf("observer cannot be nil")
	}
	return nil
}
// validateSourcePriority rejects invalid, duplicate, and resolver-only sources
// and enforces the one-chain-header rule: Forwarded and X-Forwarded-For may
// not both appear, because two independent proxy chains would have unclear
// trust semantics.
//
// On success it reports whether any header-based source and any chain source
// (Forwarded or X-Forwarded-For) is configured. On failure both flags are
// false.
func (c *config) validateSourcePriority() (hasHeaderSource, hasChainSource bool, err error) {
	seen := make(map[Source]struct{}, len(c.sourcePriority))
	var sawForwarded, sawXFF bool

	for _, raw := range c.sourcePriority {
		src := canonicalSource(raw)
		if !src.valid() {
			return false, false, fmt.Errorf("source names cannot be empty")
		}
		if _, dup := seen[src]; dup {
			return false, false, fmt.Errorf("duplicate source %q in priority list", src)
		}
		seen[src] = struct{}{}

		switch src.kind {
		case sourceStaticFallback:
			return false, false, fmt.Errorf("source %q is resolver-only and cannot be configured with WithSources", src)
		case sourceForwarded, sourceXForwardedFor:
			hasHeaderSource, hasChainSource = true, true
			if src.kind == sourceForwarded {
				sawForwarded = true
			} else {
				sawXFF = true
			}
		case sourceXRealIP, sourceHeader:
			hasHeaderSource = true
		}
	}

	if sawForwarded && sawXFF {
		return false, false, fmt.Errorf("priority cannot include both %q and %q; choose one proxy chain header", builtinSource(sourceForwarded), builtinSource(sourceXForwardedFor))
	}
	return hasHeaderSource, hasChainSource, nil
}
var (
	// loopbackProxyCIDRs lists the loopback networks used when the app sits
	// behind a reverse proxy running on the same host.
	loopbackProxyCIDRs = []netip.Prefix{
		mustParsePrefix("127.0.0.0/8"),
		mustParsePrefix("::1/128"),
	}

	// privateProxyCIDRs lists private-network ranges commonly used by trusted
	// upstream proxies in VM and internal-network deployments.
	privateProxyCIDRs = []netip.Prefix{
		mustParsePrefix("10.0.0.0/8"),
		mustParsePrefix("172.16.0.0/12"),
		mustParsePrefix("192.168.0.0/16"),
		mustParsePrefix("fc00::/7"),
	}
)

// mustParsePrefix parses a compile-time CIDR literal and panics if it is
// malformed. It is meant only for package-level built-in tables.
func mustParsePrefix(cidr string) netip.Prefix {
	p, err := netip.ParsePrefix(cidr)
	if err == nil {
		return p
	}
	panic(fmt.Sprintf("invalid built-in CIDR %q: %v", cidr, err))
}

// clonePrefixes returns a shallow copy of prefixes. A nil input stays nil, so
// callers can still tell "unset" from "explicitly empty".
func clonePrefixes(prefixes []netip.Prefix) []netip.Prefix {
	return slices.Clone(prefixes)
}
// cloneSources returns a shallow copy of values.
//
// slices.Clone keeps a nil input nil and an empty non-nil input non-nil. That
// distinction matters: options.Sources treats nil as "use the default source"
// and an explicit empty slice as an invalid configuration.
func cloneSources(values []Source) []Source {
return slices.Clone(values)
}
// isNilValue reports whether v is nil, either as an untyped nil interface or
// as a typed nil such as (*myLogger)(nil). A typed nil compares non-nil as an
// interface but would panic when its methods dereference the receiver.
//
// Only nillable kinds (chan, func, interface, map, pointer, slice) can be
// typed nils; every other kind reports false.
func isNilValue(v any) bool {
	if v == nil {
		return true
	}
	switch rv := reflect.ValueOf(v); rv.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return rv.IsNil()
	}
	return false
}
// normalizePrefixes rejects invalid prefixes and masks off host bits, so that
// equivalent inputs (for example 10.1.2.3/8 and 10.0.0.0/8) compare equal,
// dedupe, and match deterministically. kind names the prefix category in
// error messages. The result is never nil on success.
func normalizePrefixes(prefixes []netip.Prefix, kind string) ([]netip.Prefix, error) {
	out := make([]netip.Prefix, len(prefixes))
	for i, p := range prefixes {
		if !p.IsValid() {
			return nil, fmt.Errorf("invalid %s %q", kind, p)
		}
		out[i] = p.Masked()
	}
	return out, nil
}

// normalizeTrustedProxyPrefixes applies normalizePrefixes to trusted proxy
// ranges.
func normalizeTrustedProxyPrefixes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
	return normalizePrefixes(prefixes, "trusted proxy prefix")
}

// normalizeReservedClientPrefixes applies normalizePrefixes to allowed
// special-use client ranges.
func normalizeReservedClientPrefixes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
	return normalizePrefixes(prefixes, "reserved client prefix")
}
// mergeUniquePrefixes returns existing followed by additions, keeping only the
// first occurrence of each prefix. Prefixes are compared exactly, so callers
// should mask them first to dedupe equivalent ranges. Neither input is
// modified. When both inputs are empty the result is nil; otherwise it is a
// newly allocated slice.
func mergeUniquePrefixes(existing []netip.Prefix, additions ...netip.Prefix) []netip.Prefix {
	total := len(existing) + len(additions)
	if total == 0 {
		return nil
	}
	merged := make([]netip.Prefix, 0, total)
	seen := make(map[netip.Prefix]struct{}, total)
	for _, group := range [2][]netip.Prefix{existing, additions} {
		for _, p := range group {
			if _, dup := seen[p]; dup {
				continue
			}
			seen[p] = struct{}{}
			merged = append(merged, p)
		}
	}
	return merged
}
// defaultConfig returns a fresh runtime config seeded from defaultOptions,
// with no-op Logger and Observer implementations installed. Derived policy
// fields are left for configFromPublic to fill in.
func defaultConfig() *config {
	base := defaultOptions()
	cfg := &config{
		sourcePriority:    cloneSources(base.Sources),
		chainSelection:    base.ChainSelection,
		maxChainLength:    base.MaxChainLength,
		minTrustedProxies: base.MinTrustedProxies,
		maxTrustedProxies: base.MaxTrustedProxies,
		allowPrivateIPs:   base.AllowPrivateIPs,
	}
	cfg.logger, cfg.loggerNoop = noopLogger{}, true
	cfg.observer = noopObserver{}
	return cfg
}
// configFromPublic builds immutable runtime config in stages. It starts from
// safe defaults, applies public option overrides, normalizes and
// canonicalizes values, derives hot-path policies, and finally validates the
// finished configuration. Prefix-normalization errors are returned before
// validation runs.
func configFromPublic(public options) (*config, error) {
	cfg := defaultConfig()

	// Prefix lists: a nil field keeps the default (none); a non-nil field is
	// validated, masked, and de-duplicated.
	if public.TrustedProxyPrefixes != nil {
		prefixes, err := normalizeTrustedProxyPrefixes(public.TrustedProxyPrefixes)
		if err != nil {
			return nil, err
		}
		cfg.trustedProxyCIDRs = mergeUniquePrefixes(nil, prefixes...)
	}
	if public.AllowedReservedClientPrefixes != nil {
		prefixes, err := normalizeReservedClientPrefixes(public.AllowedReservedClientPrefixes)
		if err != nil {
			return nil, err
		}
		cfg.allowReservedClientPrefixes = mergeUniquePrefixes(nil, prefixes...)
	}

	// Fields whose zero value means "keep the default".
	if n := public.MaxChainLength; n != 0 {
		cfg.maxChainLength = n
	}
	if sel := public.ChainSelection; sel != 0 {
		cfg.chainSelection = sel
	}
	if public.Sources != nil {
		cfg.sourcePriority = canonicalizeSources(cloneSources(public.Sources))
	}
	if logger := public.Logger; logger != nil {
		cfg.logger, cfg.loggerNoop = logger, false
	}
	if observer := public.Observer; observer != nil {
		cfg.observer = observer
	}

	// Fields copied verbatim; their zero values are meaningful settings.
	cfg.minTrustedProxies = public.MinTrustedProxies
	cfg.maxTrustedProxies = public.MaxTrustedProxies
	cfg.allowPrivateIPs = public.AllowPrivateIPs
	cfg.debugMode = public.DebugInfo

	// Derived hot-path state, computed only after every override is in place.
	cfg.sourceHeaderKeys = sourceHeaderKeys(cfg.sourcePriority)
	cfg.trustedProxyMatch = newPrefixMatcher(cfg.trustedProxyCIDRs)
	cfg.clientIP = clientIPPolicy{
		AllowPrivateIPs:             cfg.allowPrivateIPs,
		AllowReservedClientPrefixes: cfg.allowReservedClientPrefixes,
	}
	cfg.proxy = proxyPolicy{
		TrustedProxyCIDRs: cfg.trustedProxyCIDRs,
		TrustedProxyMatch: cfg.trustedProxyMatch,
		MinTrustedProxies: cfg.minTrustedProxies,
		MaxTrustedProxies: cfg.maxTrustedProxies,
	}

	if err := cfg.validate(); err != nil {
		return nil, err
	}
	return cfg, nil
}
package clientip
import (
"errors"
"fmt"
"net/http"
"net/textproto"
)
// extractor resolves client IP information from HTTP requests and
// framework-agnostic request inputs.
//
// extractor instances are safe for concurrent reuse.
type extractor struct {
// config is the validated runtime configuration produced by
// configFromPublic.
config *config
// sources holds one precomputed entry per configured source, in priority
// order (see buildConfiguredSources).
sources []configuredSource
}
// configuredSource is the precomputed per-source state used on the hot path.
// Only the extractor field matching source.kind is populated: chain for
// Forwarded and X-Forwarded-For, remote for RemoteAddr, and single for all
// other kinds.
type configuredSource struct {
source Source
// name is source.String(), computed once.
name string
// unavailableErr is a preallocated ErrSourceUnavailable wrapper for this
// source. NOTE(review): presumably reused by the per-source extractors to
// avoid allocating on the unavailable path; confirm at the use sites.
unavailableErr *ExtractionError
chain chainExtractor
single singleHeaderExtractor
remote remoteAddrExtractor
}
// newExtractor builds an extractor from public options.
//
// It applies default values, normalizes prefixes and sources, validates the
// resulting configuration, and precomputes per-source extractors. The returned
// extractor is safe for concurrent reuse. Configuration errors are wrapped
// with an "invalid configuration" prefix.
func newExtractor(public options) (*extractor, error) {
	cfg, err := configFromPublic(public)
	if err != nil {
		return nil, fmt.Errorf("invalid configuration: %w", err)
	}
	ex := &extractor{config: cfg}
	// buildConfiguredSources reads ex.config, so config must be set first.
	ex.sources = ex.buildConfiguredSources(cfg.sourcePriority)
	return ex, nil
}
// Extract resolves the client IP and associated metadata for r.
//
// Configured sources are tried in order. ErrSourceUnavailable lets the next
// source run; malformed headers, proxy-trust failures, chain limits, invalid
// client IPs, and context errors are terminal. A nil request yields
// ErrNilRequest.
func (e *extractor) Extract(r *http.Request) (Extraction, error) {
	if r == nil {
		return Extraction{}, ErrNilRequest
	}
	if len(e.config.sourceHeaderKeys) > 0 {
		return e.extractRequestView(requestViewFromRequest(r))
	}
	// No header-based source is configured. Validation guarantees at least one
	// source, so the only one left is SourceRemoteAddr and the request-view
	// machinery can be skipped.
	if err := r.Context().Err(); err != nil {
		return Extraction{}, err
	}
	return e.extractFromRemoteAddr(r.RemoteAddr)
}
// ExtractInput resolves the client IP and associated metadata from
// framework-agnostic request input.
//
// It follows the same source ordering and terminal-error rules as Extract. A
// nil input.Context is treated as context.Background().
func (e *extractor) ExtractInput(input Input) (Extraction, error) {
	if err := requestInputContext(input).Err(); err != nil {
		return Extraction{}, err
	}
	if len(e.config.sourceHeaderKeys) > 0 {
		return e.extractRequestView(requestViewFromInput(input))
	}
	// Without header-based sources, SourceRemoteAddr is the only configured
	// source (see Extract).
	return e.extractFromRemoteAddr(input.RemoteAddr)
}
// extractRequestView tries each configured source in priority order and
// returns the first successful extraction.
//
// The request context is checked before every attempt. A source that fails
// with ErrSourceUnavailable hands off to the next source; any other error is
// terminal. On failure, the returned Extraction's Source is set to the failing
// source if the extractor left it unset. If every source is unavailable, the
// last source's unavailability error is returned.
func (e *extractor) extractRequestView(r requestView) (Extraction, error) {
	if err := r.context().Err(); err != nil {
		return Extraction{}, err
	}
	last := len(e.sources) - 1
	for i := range e.sources {
		cs := &e.sources[i]
		// The first attempt is covered by the check above the loop.
		if i > 0 {
			if err := r.context().Err(); err != nil {
				return Extraction{}, err
			}
		}

		var (
			result Extraction
			err    error
		)
		switch cs.source.kind {
		case sourceForwarded:
			result, err = e.extractChainSource(
				r,
				cs,
				"Forwarded chain exceeds configured maximum length",
				"request received from untrusted proxy while Forwarded is present",
				func(parseErr error) {
					// Only Forwarded syntax errors produce this security event.
					if errors.Is(parseErr, ErrInvalidForwardedHeader) {
						e.logSecurityWarning(r, cs.source, SecurityEventMalformedForwarded, "malformed Forwarded header received", "parse_error", parseErr.Error())
					}
				},
			)
		case sourceXForwardedFor:
			result, err = e.extractChainSource(
				r,
				cs,
				"X-Forwarded-For chain exceeds configured maximum length",
				"request received from untrusted proxy while X-Forwarded-For is present",
				nil,
			)
		case sourceRemoteAddr:
			result, err = e.extractRemoteAddrSource(r, cs)
		default:
			result, err = e.extractSingleHeaderSource(r, cs)
		}

		if err == nil {
			return result, nil
		}
		// Unavailability falls through to the next source unless this was the
		// last one; every other failure stops the search immediately.
		if !errors.Is(err, ErrSourceUnavailable) || i == last {
			if !result.Source.valid() {
				result.Source = cs.source
			}
			return result, err
		}
	}
	return Extraction{}, ErrSourceUnavailable
}
// extractFromRemoteAddr runs the RemoteAddr extractor directly, bypassing the
// configured-source loop. It is the fast path for RemoteAddr-only
// configurations. Failures are adapted via adaptRemoteAddrFailure, and the
// result is tagged with the RemoteAddr source.
func (e *extractor) extractFromRemoteAddr(remoteAddr string) (Extraction, error) {
	src := builtinSource(sourceRemoteAddr)
	rex := remoteAddrExtractor{clientIPPolicy: e.config.clientIP}
	result, failure := rex.extract(remoteAddr, src)
	if failure == nil {
		return result, nil
	}
	result.Source = src
	return result, adaptRemoteAddrFailure(failure, src)
}
// buildConfiguredSources precomputes one configuredSource per entry in
// sources, preserving order.
//
// Each entry carries the canonical MIME form of its header name (if any), its
// display name, a preallocated unavailability error, and exactly one populated
// extractor chosen by source kind: chain for Forwarded / X-Forwarded-For,
// remote for RemoteAddr, and single for every other kind. The chain parsers
// close over e, so e.config must already be set.
func (e *extractor) buildConfiguredSources(sources []Source) []configuredSource {
	out := make([]configuredSource, len(sources))
	for i := range sources {
		// src is a fresh variable per iteration, so each parse closure below
		// captures its own source.
		src := sources[i]
		// Only the header name is needed here; an empty name is skipped.
		header, _ := sourceHeaderKey(src)
		if header != "" {
			header = textproto.CanonicalMIMEHeaderKey(header)
		}

		cs := configuredSource{
			source:         src,
			name:           src.String(),
			unavailableErr: &ExtractionError{Err: ErrSourceUnavailable, Source: src},
		}

		switch src.kind {
		case sourceForwarded, sourceXForwardedFor:
			policy := chainPolicy{
				headerName:        header,
				clientIP:          e.config.clientIP,
				trustedProxy:      e.config.proxy,
				selection:         e.config.chainSelection,
				collectDebugInfo:  e.config.debugMode,
				untrustedChainSep: ", ",
			}
			if src.kind == sourceForwarded {
				policy.parseValues = func(values []string) ([]string, error) {
					parts, err := parseForwardedValues(values, e.config.maxChainLength)
					if err != nil {
						return nil, adaptForwardedParseError(err, src, e)
					}
					return parts, nil
				}
				policy.parseClientIP = parseChainIP
			} else {
				policy.parseValues = func(values []string) ([]string, error) {
					parts, err := parseXFFValues(values, e.config.maxChainLength)
					if err != nil {
						return nil, adaptXFFParseError(err, src, e)
					}
					return parts, nil
				}
				policy.parseClientIP = parseIP
			}
			cs.chain = chainExtractor{policy: policy}
		case sourceRemoteAddr:
			cs.remote = remoteAddrExtractor{clientIPPolicy: e.config.clientIP}
		default:
			cs.single = singleHeaderExtractor{policy: singleHeaderPolicy{
				headerName:   header,
				clientIP:     e.config.clientIP,
				trustedProxy: e.config.proxy,
			}}
		}

		out[i] = cs
	}
	return out
}
package clientip
import "context"
// HeaderValues provides access to request header values by name.
//
// Implementations should return one slice entry per received header line.
// Single-IP sources rely on per-line values to detect duplicates, and chain
// sources preserve wire order across repeated lines, so never merge duplicate
// header lines into one comma-joined value.
//
// Names are requested in canonical MIME form (for example "X-Forwarded-For").
//
// net/http's http.Header satisfies this interface directly.
type HeaderValues interface {
	Values(name string) []string
}

// HeaderValuesFunc adapts an ordinary function to the HeaderValues interface.
type HeaderValuesFunc func(name string) []string

// Values implements HeaderValues. A nil HeaderValuesFunc reports no values.
func (fn HeaderValuesFunc) Values(name string) []string {
	if fn != nil {
		return fn(name)
	}
	return nil
}

// Input carries framework-agnostic request data for extraction.
//
// A nil Context is treated as context.Background().
//
// Headers must keep repeated header lines as separate values per name: two
// X-Forwarded-For lines should yield a slice of length 2, and so should two
// X-Real-IP lines. A nil Headers provider makes header-based sources
// unavailable; SourceRemoteAddr still runs if it is included with
// WithSources.
type Input struct {
	Context    context.Context
	RemoteAddr string
	Headers    HeaderValues
}

// requestInputContext returns input.Context, or context.Background() when it
// is nil.
func requestInputContext(input Input) context.Context {
	if ctx := input.Context; ctx != nil {
		return ctx
	}
	return context.Background()
}
package clientip
import "context"
// SecurityEvent... constants are stable public labels for extractor security
// events. Log consumers can depend on these names, so changing a value is a
// breaking change.
const (
SecurityEventMultipleHeaders = "multiple_headers"
SecurityEventChainTooLong = "chain_too_long"
SecurityEventUntrustedProxy = "untrusted_proxy"
SecurityEventNoTrustedProxies = "no_trusted_proxies"
SecurityEventTooFewTrustedProxies = "too_few_trusted_proxies"
SecurityEventTooManyTrustedProxies = "too_many_trusted_proxies"
SecurityEventInvalidIP = "invalid_ip"
SecurityEventReservedIP = "reserved_ip"
SecurityEventPrivateIP = "private_ip"
// SecurityEventMalformedForwarded is logged by the Forwarded source when
// parsing fails with ErrInvalidForwardedHeader.
SecurityEventMalformedForwarded = "malformed_forwarded"
)
// Logger records security-significant events emitted by the extractor.
//
// Implementations should be safe for concurrent use, because a single
// extractor instance is typically shared across many goroutines.
//
// The context passed in comes from the inbound HTTP request and may carry
// tracing metadata such as trace or span IDs.
//
// Operational fallback does not emit separate log events. Inspect
// Result.FallbackUsed when that distinction matters.
//
// The method set deliberately mirrors slog's WarnContext signature, so a
// *slog.Logger can be used directly without an adapter.
type Logger interface {
	WarnContext(ctx context.Context, msg string, args ...any)
}

// noopLogger is the Logger installed when logging is not explicitly
// configured. It discards every event.
type noopLogger struct{}

// WarnContext implements Logger by doing nothing.
func (noopLogger) WarnContext(_ context.Context, _ string, _ ...any) {}
// Observer receives one event per resolver call on a valid Resolver.
//
// Implementations should be safe for concurrent use. Observation happens at
// the result level, so an Observer sees strict successes, strict failures, and
// operational fallbacks alike.
type Observer interface {
	OnResolved(ctx context.Context, result Result)
}

// noopObserver is the Observer installed when observation is not explicitly
// configured. It discards every event.
type noopObserver struct{}

// OnResolved implements Observer by doing nothing.
func (noopObserver) OnResolved(_ context.Context, _ Result) {}
package clientip
import "strings"
// typicalChainCapacity is the allocation hint used when a chain does not match
// one of the short shapes recognized by chainPartsCapacity.
const typicalChainCapacity = 8

// chainPartsCapacity returns an allocation hint for parsed chain parts.
//
// It cheaply recognizes a few common short shapes: a single header line with
// at most two commas, or two header lines with no commas. Anything else falls
// back to typicalChainCapacity. The hint never exceeds maxLength, and a
// non-positive maxLength is treated as 1. Real validation and max-length
// enforcement happen in the parsers, not here.
func chainPartsCapacity(values []string, maxLength int) int {
	limit := max(maxLength, 1)
	hint := typicalChainCapacity

	switch len(values) {
	case 1:
		// Count commas, but stop at three: beyond that the exact number does
		// not change the hint.
		commas := 0
		for rest := values[0]; commas < 3; commas++ {
			idx := strings.IndexByte(rest, ',')
			if idx < 0 {
				break
			}
			rest = rest[idx+1:]
		}
		if commas < 3 {
			hint = commas + 1
		}
	case 2:
		if !strings.Contains(values[0], ",") && !strings.Contains(values[1], ",") {
			hint = 2
		}
	}

	return min(hint, limit)
}
// trimHTTPWhitespace removes leading and trailing HTTP optional whitespace
// (OWS), which is limited to SP and HTAB. Other whitespace such as CR, LF, or
// Unicode spaces is left untouched. The result is a substring of value.
func trimHTTPWhitespace(value string) string {
	return strings.Trim(value, " \t")
}
package clientip
import "fmt"
// chainTooLongParseError reports that a header parser stopped because the
// chain exceeded its configured maximum length.
type chainTooLongParseError struct {
	// ChainLength is the number of entries counted when the limit was hit,
	// including the entry that did not fit.
	ChainLength int
	// MaxLength is the configured maximum chain length.
	MaxLength int
}

// Error formats both lengths as key=value pairs for diagnostics.
func (p *chainTooLongParseError) Error() string {
	const format = "proxy chain too long (chain_length=%d, max_length=%d)"
	return fmt.Sprintf(format, p.ChainLength, p.MaxLength)
}
package clientip
import (
"errors"
"fmt"
"strings"
)
// parseForwardedValues extracts RFC 7239 for= values from repeated Forwarded
// header lines in arrival order. Only for= parameters become chain entries;
// malformed syntax fails closed because a sabotaged Forwarded header can hide
// or reorder client attribution.
//
// Each header line is split into comma-separated elements (quote-aware), and
// each element contributes at most one for= value. Elements without for= are
// skipped and do not count toward maxChainLength.
//
// Appending a for= value beyond maxChainLength returns a
// *chainTooLongParseError whose ChainLength includes the rejected entry. With
// maxChainLength <= 0, the first for= value already triggers it. Other syntax
// errors are returned as-is (plain fmt errors, not wrapped). Empty input
// returns (nil, nil).
func parseForwardedValues(values []string, maxChainLength int) ([]string, error) {
if len(values) == 0 {
return nil, nil
}
// The capacity is only a hint; the limit is enforced on each append below.
parts := make([]string, 0, chainPartsCapacity(values, maxChainLength))
for _, value := range values {
err := scanForwardedSegments(value, ',', "element", func(element string) error {
forwardedFor, hasFor, parseErr := parseForwardedElement(element)
if parseErr != nil {
return parseErr
}
if !hasFor {
return nil
}
if len(parts) >= maxChainLength {
return &chainTooLongParseError{
ChainLength: len(parts) + 1,
MaxLength: maxChainLength,
}
}
parts = append(parts, forwardedFor)
return nil
})
if err != nil {
// NOTE(review): both branches below return the same values, so this
// errors.As check has no effect on behavior. It is currently the only
// use of the "errors" import in this file, so removing it also requires
// removing that import.
var chainErr *chainTooLongParseError
if errors.As(err, &chainErr) {
return nil, err
}
return nil, err
}
}
return parts, nil
}
// parseForwardedElement extracts at most one for= parameter from a single
// Forwarded element (the text between top-level commas).
//
// Parameters are split on ';' outside quoted strings. Each parameter must
// look like key=value with a non-empty key and value after trimming, and keys
// are compared case-insensitively. Parameters other than for= are accepted
// without further validation of their values. A duplicate for= parameter is
// rejected as ambiguous instead of choosing one.
//
// hasFor reports whether the element carried a for= parameter. On error both
// other results are zero.
func parseForwardedElement(element string) (forwardedFor string, hasFor bool, err error) {
	var (
		value string
		found bool
	)
	scanErr := scanForwardedSegments(element, ';', "parameter", func(param string) error {
		rawKey, rawValue, hasEq := strings.Cut(param, "=")
		if !hasEq || rawKey == "" {
			return fmt.Errorf("invalid forwarded parameter %q", param)
		}
		key, val := strings.TrimSpace(rawKey), strings.TrimSpace(rawValue)
		switch {
		case key == "":
			return fmt.Errorf("empty parameter key in %q", param)
		case val == "":
			return fmt.Errorf("empty parameter value for %q", key)
		case !strings.EqualFold(key, "for"):
			return nil
		case found:
			return fmt.Errorf("duplicate for parameter in element %q", element)
		}
		parsed, parseErr := parseForwardedForValue(val)
		if parseErr != nil {
			return parseErr
		}
		value, found = parsed, true
		return nil
	})
	if scanErr != nil {
		return "", false, scanErr
	}
	return value, found, nil
}
// scanForwardedSegments splits value on delimiter while respecting quoted
// strings and quoted-pair escapes, so commas or semicolons inside quoted
// values cannot change the element/parameter structure we validate.
//
// Each segment is trimmed with strings.TrimSpace and passed to onSegment in
// order. The first error from onSegment is returned immediately. An empty
// segment (for example from "a,,b" or a trailing delimiter) is an error, as
// is an unterminated quoted string; segmentKind names the segment type in
// that error message.
func scanForwardedSegments(value string, delimiter byte, segmentKind string, onSegment func(string) error) error {
	emit := func(raw string) error {
		seg := strings.TrimSpace(raw)
		if seg == "" {
			return fmt.Errorf("empty forwarded %s in %q", segmentKind, value)
		}
		return onSegment(seg)
	}

	segStart := 0
	quoted, pendingEscape := false, false
	for idx := 0; idx < len(value); idx++ {
		c := value[idx]
		switch {
		case pendingEscape:
			// The byte after a backslash inside quotes is taken literally.
			pendingEscape = false
		case c == '\\' && quoted:
			pendingEscape = true
		case c == '"':
			quoted = !quoted
		case c == delimiter && !quoted:
			if err := emit(value[segStart:idx]); err != nil {
				return err
			}
			segStart = idx + 1
		}
	}

	if quoted {
		return fmt.Errorf("unterminated quoted string in %q", value)
	}
	if pendingEscape {
		// Defensive: escapes only start inside quotes, so the check above
		// normally fires first.
		return fmt.Errorf("unterminated escape in %q", value)
	}
	return emit(value[segStart:])
}
// parseForwardedForValue normalizes the value side of for=.
//
// A value that starts with a double quote must be a complete, valid
// quoted-string; it is unquoted and trimmed. An unquoted value is an RFC 7239
// token and therefore may not contain any double quote. Values such as
// `1.2.3.4""` or `a"b,for=1.2.3.4"` are rejected as malformed rather than
// passed on as an (invalid) IP candidate, because such quotes can hide
// delimiters from the element/parameter scanner. Empty values, before or after
// unquoting, are malformed as well.
func parseForwardedForValue(value string) (string, error) {
	value = strings.TrimSpace(value)
	if value == "" {
		return "", fmt.Errorf("empty for value")
	}
	if value[0] == '"' {
		unquoted, err := unquoteForwardedValue(value)
		if err != nil {
			return "", err
		}
		value = strings.TrimSpace(unquoted)
	} else if strings.IndexByte(value, '"') != -1 {
		// A quote that does not open the value means it is only partially
		// quoted, which the token grammar does not allow.
		return "", fmt.Errorf("partially quoted for value %q", value)
	}
	if value == "" {
		return "", fmt.Errorf("empty for value")
	}
	return value, nil
}

// unquoteForwardedValue decodes a Forwarded quoted-string (including its
// surrounding quotes) and returns the unescaped content.
//
// A backslash escapes the following byte. A raw (unescaped) quote inside the
// string, a missing opening or closing quote, or a dangling trailing escape
// is rejected, so malformed header input remains terminal.
func unquoteForwardedValue(value string) (string, error) {
	if len(value) < 2 || value[0] != '"' || value[len(value)-1] != '"' {
		return "", fmt.Errorf("invalid quoted string %q", value)
	}
	inner := value[1 : len(value)-1]
	// Fast path: without escapes the content is used as-is, provided it
	// contains no stray quote.
	if strings.IndexByte(inner, '\\') == -1 {
		if strings.IndexByte(inner, '"') != -1 {
			return "", fmt.Errorf("unexpected quote in %q", value)
		}
		return inner, nil
	}
	var b strings.Builder
	b.Grow(len(inner))
	escaped := false
	for i := 0; i < len(inner); i++ {
		ch := inner[i]
		if escaped {
			b.WriteByte(ch)
			escaped = false
			continue
		}
		if ch == '\\' {
			escaped = true
			continue
		}
		if ch == '"' {
			return "", fmt.Errorf("unexpected quote in %q", value)
		}
		b.WriteByte(ch)
	}
	if escaped {
		return "", fmt.Errorf("unterminated escape in %q", value)
	}
	return b.String(), nil
}
package clientip
import (
"net"
"net/netip"
"strings"
)
// normalizeIP converts an IPv4-mapped IPv6 address (::ffff:a.b.c.d) to its
// plain IPv4 form. Every other address, including the zero Addr, is returned
// unchanged.
func normalizeIP(ip netip.Addr) netip.Addr {
	return ip.Unmap()
}
// parseChainIP parses an IP from a chain value that has already been
// extracted and trimmed by a header parser.
//
// This is intentionally stricter than parseIP. It accepts only bare IPs,
// bracketed IPs, and bracketed IPs followed by ':' and one or more ASCII
// digits. The port is not range-checked. The returned address is not
// normalized, and the zero Addr signals failure.
func parseChainIP(s string) netip.Addr {
	if ip, err := netip.ParseAddr(s); err == nil {
		return ip
	}
	if len(s) < 2 || s[0] != '[' {
		return netip.Addr{}
	}
	host, suffix, closed := strings.Cut(s[1:], "]")
	if !closed || host == "" {
		return netip.Addr{}
	}
	if suffix != "" {
		port, hasColon := strings.CutPrefix(suffix, ":")
		if !hasColon || port == "" {
			return netip.Addr{}
		}
		for i := 0; i < len(port); i++ {
			if c := port[i]; c < '0' || c > '9' {
				return netip.Addr{}
			}
		}
	}
	ip, err := netip.ParseAddr(host)
	if err != nil {
		return netip.Addr{}
	}
	return ip
}
// parseIP extracts an IP address from the loose formats commonly found in
// proxy headers. It accepts:
//   - bare IPv4/IPv6 literals and bracketed IPv6;
//   - host:port and [IPv6]:port forms;
//   - any of the above wrapped in matching double quotes, single quotes, or
//     single quotes nested inside double quotes.
//
// Surrounding whitespace is ignored. The result is not normalized (an
// IPv4-mapped IPv6 address stays mapped), and the zero Addr signals failure.
func parseIP(s string) netip.Addr {
	candidate := trimMatchedChar(trimMatchedChar(strings.TrimSpace(s), '"'), '\'')
	if candidate == "" {
		return netip.Addr{}
	}
	// Try a bare literal first unless the text clearly carries a port suffix;
	// anything that is not a bare literal is retried as host:port.
	if !looksLikeHostPort(candidate) {
		if ip, ok := parseNormalizedIP(candidate); ok {
			return ip
		}
	}
	host, ok := splitHostPortHost(candidate)
	if !ok {
		return netip.Addr{}
	}
	if ip, valid := parseHostIP(host); valid {
		return ip
	}
	return netip.Addr{}
}
// parseRemoteAddr extracts an IP address from Request.RemoteAddr-like input.
//
// Input that net.SplitHostPort accepts has its host parsed as an IP. Anything
// else is handed to the lenient parseIP. The result is not normalized, and
// the zero Addr signals failure.
func parseRemoteAddr(s string) netip.Addr {
	if host, ok := splitHostPortHost(s); ok {
		if ip, valid := parseHostIP(host); valid {
			return ip
		}
		return netip.Addr{}
	}
	return parseIP(s)
}
// parseHostIP parses host (as produced by net.SplitHostPort) as an IP
// literal, retrying once after stripping a surrounding pair of square
// brackets. ok is false when neither attempt succeeds.
func parseHostIP(host string) (netip.Addr, bool) {
	if addr, err := netip.ParseAddr(host); err == nil {
		return addr, true
	}
	return parseNormalizedIP(host)
}
// looksLikeHostPort is a cheap shape check for host:port text. It accepts two
// shapes:
//   - "[...]:..." — a ':' right after the last ']';
//   - exactly one ':' that is neither the first nor the last byte.
//
// Bare IPv6 literals (multiple colons) are therefore not mistaken for
// host:port. The port text itself is not validated.
func looksLikeHostPort(s string) bool {
	switch {
	case len(s) < 3:
		return false
	case s[0] == '[':
		closing := strings.LastIndexByte(s, ']')
		return closing > 0 && strings.HasPrefix(s[closing+1:], ":")
	default:
		return strings.Count(s, ":") == 1 && s[0] != ':' && s[len(s)-1] != ':'
	}
}
// splitHostPortHost returns the host part of a host:port string as parsed by
// net.SplitHostPort, with brackets around IPv6 hosts removed. ok is false (and
// host empty) when s is not a valid host:port.
func splitHostPortHost(s string) (string, bool) {
	if host, _, err := net.SplitHostPort(s); err == nil {
		return host, true
	}
	return "", false
}
// parseNormalizedIP parses s as an IP literal after removing one surrounding
// pair of square brackets, if present. Despite the name, it does not unmap
// IPv4-mapped addresses. ok is false for empty or unparsable input.
func parseNormalizedIP(s string) (netip.Addr, bool) {
	inner := trimMatchedPair(s, '[', ']')
	if inner == "" {
		return netip.Addr{}, false
	}
	addr, err := netip.ParseAddr(inner)
	if err != nil {
		return netip.Addr{}, false
	}
	return addr, true
}
// trimMatchedPair strips exactly one leading start byte and one trailing end
// byte when s is at least two bytes long and both are present. Otherwise s is
// returned unchanged.
func trimMatchedPair(s string, start, end byte) string {
	if n := len(s); n >= 2 && s[0] == start && s[n-1] == end {
		return s[1 : n-1]
	}
	return s
}

// trimMatchedChar strips one surrounding pair of ch, such as matching quotes.
func trimMatchedChar(s string, ch byte) string {
	return trimMatchedPair(s, ch, ch)
}
package clientip
import "net/netip"
// ParseRemoteAddr parses and normalizes Request.RemoteAddr-style input without
// applying extractor plausibility policy.
//
// It accepts host:port values, bracketed IPv6 host:port values, and bare IP
// literals. IPv4-mapped IPv6 addresses are normalized to IPv4. Return values:
//   - empty input: an *ExtractionError whose Err is ErrSourceUnavailable;
//   - unparsable input: a *RemoteAddrError whose embedded ExtractionError
//     carries ErrInvalidIP and whose RemoteAddr holds the original input.
func ParseRemoteAddr(remoteAddr string) (netip.Addr, error) {
	if remoteAddr == "" {
		return netip.Addr{}, &ExtractionError{Err: ErrSourceUnavailable, Source: SourceRemoteAddr}
	}
	if ip := parseRemoteAddr(remoteAddr); ip.IsValid() {
		return normalizeIP(ip), nil
	}
	return netip.Addr{}, &RemoteAddrError{
		ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: SourceRemoteAddr},
		RemoteAddr:      remoteAddr,
	}
}
package clientip
import "strings"
// parseXFFValues parses X-Forwarded-For header lines into a logical chain.
//
// XFF is intentionally simpler and more permissive than RFC 7239 Forwarded.
// Each line is split on commas, and elements are trimmed of SP/HTAB only.
// Empty elements (for example from ",," or a trailing comma) are ignored,
// while every non-empty element counts toward maxChainLength. Repeated header
// lines are processed in the order the header provider returns them.
//
// The result is an empty (possibly nil) slice when no non-empty element is
// present. When the chain would exceed maxChainLength, a
// *chainTooLongParseError is returned whose ChainLength includes the first
// rejected element. A non-positive maxChainLength rejects any element.
//
// On the fast path the returned slice may alias values; callers must treat it
// as read-only.
func parseXFFValues(values []string, maxChainLength int) ([]string, error) {
	if len(values) == 0 {
		return nil, nil
	}
	// Fast path: a single header line without commas yields at most one
	// element, so no separate parts slice is needed.
	if len(values) == 1 {
		v := values[0]
		if strings.IndexByte(v, ',') == -1 {
			trimmed := trimHTTPWhitespace(v)
			if trimmed == "" {
				return nil, nil
			}
			if maxChainLength <= 0 {
				return nil, &chainTooLongParseError{ChainLength: 1, MaxLength: maxChainLength}
			}
			if trimmed == v {
				// Reuse the caller's slice to avoid an allocation, but cap its
				// capacity at its length. A later append by any consumer then
				// reallocates instead of writing into the header's shared
				// backing array.
				return values[:1:1], nil
			}
			return []string{trimmed}, nil
		}
	}
	parts := make([]string, 0, chainPartsCapacity(values, maxChainLength))
	for _, v := range values {
		start := 0
		// i == len(v) acts as a virtual trailing comma that flushes the
		// final element of the line.
		for i := 0; i <= len(v); i++ {
			if i != len(v) && v[i] != ',' {
				continue
			}
			part := trimHTTPWhitespace(v[start:i])
			if part != "" {
				if len(parts) >= maxChainLength {
					return nil, &chainTooLongParseError{
						ChainLength: len(parts) + 1,
						MaxLength:   maxChainLength,
					}
				}
				parts = append(parts, part)
			}
			start = i + 1
		}
	}
	return parts, nil
}
package clientip
// PresetDirectConnection configures strict extraction for direct client-to-app
// traffic.
//
// This preset sets the source list to RemoteAddr only. It does not modify
// trusted proxy prefixes.
func PresetDirectConnection() Option {
	apply := func(cfg *options) {
		cfg.Sources = []Source{builtinSource(sourceRemoteAddr)}
	}
	return optionFunc(apply)
}
// PresetLoopbackReverseProxy configures extraction for apps behind a reverse
// proxy on the same host (for example NGINX on localhost).
//
// It trusts loopback proxy CIDRs (LoopbackProxyPrefixes) and prioritizes
// X-Forwarded-For before RemoteAddr within the extractor's strict source
// order.
func PresetLoopbackReverseProxy() Option {
	return optionFunc(func(cfg *options) {
		prefixes := LoopbackProxyPrefixes()
		order := []Source{builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)}
		cfg.TrustedProxyPrefixes = prefixes
		cfg.Sources = order
	})
}
// PresetVMReverseProxy configures extraction for apps behind a reverse proxy
// in a typical VM or private-network setup.
//
// It trusts loopback and private proxy CIDRs (LocalProxyPrefixes) and
// prioritizes X-Forwarded-For before RemoteAddr within the extractor's strict
// source order.
func PresetVMReverseProxy() Option {
	return optionFunc(func(cfg *options) {
		prefixes := LocalProxyPrefixes()
		order := []Source{builtinSource(sourceXForwardedFor), builtinSource(sourceRemoteAddr)}
		cfg.TrustedProxyPrefixes = prefixes
		cfg.Sources = order
	})
}
package clientip
import (
"context"
"errors"
"net/http"
"net/netip"
)
// errNilResolverExtractor is the Result.Err returned by Resolver methods
// invoked on a nil *Resolver or on a Resolver without an extractor (for
// example a zero Resolver not built by New).
var errNilResolverExtractor = errors.New("resolver extractor cannot be nil")
// resultContextKey is the unexported context key under which Middleware
// stores the Result that FromContext reads back.
type resultContextKey struct{}
// Fallback controls per-call operational fallback behavior for
// ResolveOperational and ResolveInputOperational. Build values with
// NoFallback, RemoteAddrFallback, or StaticFallback. The zero Fallback behaves
// like NoFallback.
type Fallback struct {
	mode     fallbackMode
	staticIP netip.Addr
}

// fallbackMode selects which strategy a Fallback applies.
type fallbackMode uint8

const (
	fallbackNone       fallbackMode = iota // never fall back
	fallbackRemoteAddr                     // use the connecting peer address
	fallbackStaticIP                       // use Fallback.staticIP
)

// NoFallback disables operational fallback, so the strict result is returned
// unchanged.
func NoFallback() Fallback { return Fallback{} }

// RemoteAddrFallback falls back to the connecting peer address.
func RemoteAddrFallback() Fallback {
	var fb Fallback
	fb.mode = fallbackRemoteAddr
	return fb
}
// StaticFallback falls back to a configured static IP.
//
// Static fallback is a caller-supplied operational value. ip is normalized (an
// IPv4-mapped IPv6 address becomes plain IPv4) but is not checked against the
// package's client-IP plausibility policy. Callers that need a routable or
// policy-valid fallback should validate it before passing it here.
//
// A zero (invalid) ip yields a Fallback that is never applied, so the strict
// result is returned unchanged.
func StaticFallback(ip netip.Addr) Fallback {
	fb := Fallback{mode: fallbackStaticIP}
	fb.staticIP = normalizeIP(ip)
	return fb
}
// FallbackReason describes why ResolveOperational used fallback.
type FallbackReason uint8

const (
	// FallbackReasonNone indicates no fallback was used.
	FallbackReasonNone FallbackReason = iota
	// FallbackReasonUntrustedProxy indicates fallback was used because proxy validation failed.
	FallbackReasonUntrustedProxy
	// FallbackReasonMalformedHeader indicates fallback was used because a header could not be parsed.
	FallbackReasonMalformedHeader
	// FallbackReasonSourceUnavailable indicates fallback was used because no configured source yielded an IP.
	FallbackReasonSourceUnavailable
	// FallbackReasonInvalidIP indicates fallback was used because the extracted IP failed client-IP policy.
	FallbackReasonInvalidIP
	// FallbackReasonUnknown indicates fallback was used because strict extraction returned an unclassified error.
	FallbackReasonUnknown
)

// String returns the stable label for r. Values outside the defined set share
// the "unknown" label with FallbackReasonUnknown.
func (r FallbackReason) String() string {
	labels := [...]string{
		FallbackReasonNone:              "none",
		FallbackReasonUntrustedProxy:    "untrusted_proxy",
		FallbackReasonMalformedHeader:   "malformed_header",
		FallbackReasonSourceUnavailable: "source_unavailable",
		FallbackReasonInvalidIP:         "invalid_ip",
		FallbackReasonUnknown:           "unknown",
	}
	if int(r) < len(labels) {
		return labels[r]
	}
	return "unknown"
}
// Result captures a resolver result, including fallback metadata.
//
// On strict success Source identifies the extraction source that produced IP.
// On fallback success Source is SourceRemoteAddr for RemoteAddrFallback or
// SourceStaticFallback for StaticFallback, and FallbackReason carries the
// strict failure category that triggered the fallback. On error Source may
// still identify the source that failed.
//
// Result is also the value Middleware stores in the request context and the
// value Observer implementations receive.
type Result struct {
// Extraction contains the IP and source metadata. It may still contain a
// Source when Err is non-nil. Fallback results populate only IP and Source.
Extraction
// Err is the strict extraction error, or nil when strict extraction or
// operational fallback produced a usable IP.
Err error
// FallbackUsed reports whether ResolveOperational or
// ResolveInputOperational returned a configured fallback result instead of
// the strict extraction result.
FallbackUsed bool
// FallbackReason reports why operational fallback was used. It is
// FallbackReasonNone whenever FallbackUsed is false.
FallbackReason FallbackReason
}
// OK reports whether the resolution produced a usable IP without error. This
// holds for strict successes and for successful operational fallbacks alike.
func (r Result) OK() bool {
	hasIP := r.IP.IsValid()
	return hasIP && r.Err == nil
}
// Classify returns a coarse result kind suitable for policy and metrics
// labels. Fallback results always report ResultFallback. Otherwise the kind
// is derived from Err via ClassifyError.
func (r Result) Classify() ResultKind {
	if !r.FallbackUsed {
		return ClassifyError(r.Err)
	}
	return ResultFallback
}
// Resolver resolves client IP information using the configured trust policy.
//
// Resolver instances are safe for concurrent reuse. Construct them with New.
// Methods called on a nil *Resolver, or on a Resolver without an extractor,
// return a Result with a non-nil Err and do not notify the Observer.
type Resolver struct {
// extractor holds the configuration validated by New (via newExtractor).
extractor *extractor
}
// New constructs a Resolver from options. With no options, the resolver uses a
// safe direct-connection configuration that only consults RemoteAddr.
//
// Options are applied in order, so later options override earlier ones. Nil
// options are skipped. Any configuration error reported by validation is
// returned with a nil Resolver.
func New(opts ...Option) (*Resolver, error) {
	var cfg options
	for _, opt := range opts {
		if opt != nil {
			opt.applyOption(&cfg)
		}
	}
	ex, err := newExtractor(cfg)
	if err != nil {
		return nil, err
	}
	return &Resolver{extractor: ex}, nil
}
// Resolve resolves client IP information without fallback.
//
// Use Resolve for security-sensitive decisions. A nil request returns a Result
// with ErrNilRequest. Request context cancellation and deadline errors are
// terminal. Valid resolvers notify Observer once per call, using the request
// context (or context.Background for a nil request).
func (r *Resolver) Resolve(req *http.Request) Result {
	if r == nil || r.extractor == nil {
		return Result{Err: errNilResolverExtractor}
	}
	ctx := context.Background()
	result := Result{Err: ErrNilRequest}
	if req != nil {
		result = r.resolveStrictRequest(req)
		ctx = req.Context()
	}
	r.observe(ctx, result)
	return result
}
// ResolveOperational resolves client IP information with per-call best-effort
// fallback. When fallback succeeds, Err is nil and fallback metadata is set.
//
// Use ResolveOperational only for analytics, logging, and other paths where a
// best-effort answer is acceptable. Context cancellation and deadline errors do
// not fall back. If the chosen fallback cannot produce an IP, the strict
// result is returned. Valid resolvers notify Observer once with the final
// Result.
func (r *Resolver) ResolveOperational(req *http.Request, fallback Fallback) Result {
	if r == nil || r.extractor == nil {
		return Result{Err: errNilResolverExtractor}
	}
	if req == nil {
		failed := Result{Err: ErrNilRequest}
		r.observe(context.Background(), failed)
		return failed
	}
	result := r.resolveStrictRequest(req)
	switch {
	case result.Err == nil, isResolverTerminalContextError(result.Err):
		// Strict success, or cancellation/deadline: never fall back.
	default:
		if fb, ok := r.applyFallback(req.RemoteAddr, fallback, result.Err); ok {
			result = fb
		}
	}
	r.observe(req.Context(), result)
	return result
}
// ResolveInput resolves client IP information from framework-agnostic input.
//
// It has the same strict semantics as Resolve. Input.Context defaults to
// context.Background when nil, and a nil header provider makes header-based
// sources unavailable.
func (r *Resolver) ResolveInput(input Input) Result {
	if r == nil || r.extractor == nil {
		return Result{Err: errNilResolverExtractor}
	}
	res := r.resolveStrictInput(input)
	ctx := requestInputContext(input)
	r.observe(ctx, res)
	return res
}
// ResolveInputOperational resolves framework-agnostic input with per-call
// best-effort fallback.
//
// It has the same operational semantics as ResolveOperational but reads from
// Input instead of *http.Request. RemoteAddrFallback uses Input.RemoteAddr.
func (r *Resolver) ResolveInputOperational(input Input, fallback Fallback) Result {
	if r == nil || r.extractor == nil {
		return Result{Err: errNilResolverExtractor}
	}
	result := r.resolveStrictInput(input)
	switch {
	case result.Err == nil, isResolverTerminalContextError(result.Err):
		// Strict success, or cancellation/deadline: never fall back.
	default:
		if fb, ok := r.applyFallback(input.RemoteAddr, fallback, result.Err); ok {
			result = fb
		}
	}
	r.observe(requestInputContext(input), result)
	return result
}
// ResolveHeaders resolves from plain http.Header and RemoteAddr values.
//
// This is a convenience wrapper around ResolveInput for frameworks and adapters
// that already expose net/http-style headers; it has ResolveInput's strict
// semantics.
func (r *Resolver) ResolveHeaders(ctx context.Context, remoteAddr string, headers http.Header) Result {
	in := Input{
		Context:    ctx,
		RemoteAddr: remoteAddr,
		Headers:    headers,
	}
	return r.ResolveInput(in)
}
// Middleware returns pass-through net/http middleware that stores Result in
// the request context. It never rejects; downstream handlers decide policy and
// can read the Result with FromContext. Resolution uses the strict Resolve
// path.
func (r *Resolver) Middleware() func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, req *http.Request) {
			resolved := r.Resolve(req)
			enriched := context.WithValue(req.Context(), resultContextKey{}, resolved)
			next.ServeHTTP(w, req.WithContext(enriched))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// FromContext returns the Result attached by Middleware. ok is false when ctx
// is nil or carries no Result.
func FromContext(ctx context.Context) (Result, bool) {
	if ctx != nil {
		if stored, ok := ctx.Value(resultContextKey{}).(Result); ok {
			return stored, true
		}
	}
	return Result{}, false
}
// resolveStrictRequest runs strict extraction for req and packages the
// outcome as a Result, without fallback or observer notification.
func (r *Resolver) resolveStrictRequest(req *http.Request) Result {
	var res Result
	res.Extraction, res.Err = r.extractor.Extract(req)
	return res
}
// resolveStrictInput runs strict extraction for input and packages the
// outcome as a Result, without fallback or observer notification.
func (r *Resolver) resolveStrictInput(input Input) Result {
	var res Result
	res.Extraction, res.Err = r.extractor.ExtractInput(input)
	return res
}
// observe forwards the final Result to the configured Observer. It
// defensively substitutes context.Background for a nil ctx so observers
// never receive a nil context.
func (r *Resolver) observe(ctx context.Context, result Result) {
	target := ctx
	if target == nil {
		target = context.Background()
	}
	r.extractor.config.observer.OnResolved(target, result)
}
// applyFallback builds the fallback Result selected by fallback.
//
// ok is false when fallback is NoFallback (or an unknown mode), when
// remoteAddr cannot be parsed for RemoteAddrFallback, or when the static IP
// is invalid. In those cases the caller keeps the strict result. The strict
// error is only used to derive FallbackReason.
func (r *Resolver) applyFallback(remoteAddr string, fallback Fallback, strictErr error) (Result, bool) {
	var extraction Extraction
	switch fallback.mode {
	case fallbackRemoteAddr:
		ip, err := ParseRemoteAddr(remoteAddr)
		if err != nil {
			return Result{}, false
		}
		extraction = Extraction{IP: ip, Source: SourceRemoteAddr}
	case fallbackStaticIP:
		if !fallback.staticIP.IsValid() {
			return Result{}, false
		}
		extraction = Extraction{IP: fallback.staticIP, Source: SourceStaticFallback}
	default:
		return Result{}, false
	}
	return Result{
		Extraction:     extraction,
		FallbackUsed:   true,
		FallbackReason: fallbackReasonFromError(strictErr),
	}, true
}
// fallbackReasonFromError maps a strict extraction error to the
// FallbackReason reported on fallback results, via ClassifyError. Categories
// that never trigger fallback (success, cancellation, fallback itself) map to
// FallbackReasonNone. Anything unrecognized maps to FallbackReasonUnknown.
func fallbackReasonFromError(err error) FallbackReason {
	switch ClassifyError(err) {
	case ResultUntrusted:
		return FallbackReasonUntrustedProxy
	case ResultMalformed:
		return FallbackReasonMalformedHeader
	case ResultUnavailable:
		return FallbackReasonSourceUnavailable
	case ResultInvalid:
		return FallbackReasonInvalidIP
	case ResultSuccess, ResultCanceled, ResultFallback:
		return FallbackReasonNone
	default:
		return FallbackReasonUnknown
	}
}
// isResolverTerminalContextError reports whether err is (or wraps)
// context.Canceled or context.DeadlineExceeded. Such errors must never
// trigger operational fallback.
func isResolverTerminalContextError(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
package clientip
import (
"encoding/json"
"errors"
"net/textproto"
"strings"
)
// sourceKind enumerates the internal source categories. The zero value,
// sourceInvalid, marks an unset or unusable Source.
type sourceKind uint8
const (
sourceInvalid sourceKind = iota
sourceForwarded
sourceXForwardedFor
sourceXRealIP
sourceRemoteAddr
// sourceStaticFallback only appears in results produced by StaticFallback.
sourceStaticFallback
// sourceHeader is a custom header identified by Source.headerName.
sourceHeader
)
// Stable snake_case identifiers for built-in sources. Source.String returns
// them, and source-name parsing accepts them (also after '-' to '_' and
// lower-case normalization).
const (
builtinSourceNameForwarded = "forwarded"
builtinSourceNameXForwardedFor = "x_forwarded_for"
builtinSourceNameXRealIP = "x_real_ip"
builtinSourceNameRemoteAddr = "remote_addr"
builtinSourceNameStaticFallback = "static_fallback"
)
// Exported source identifiers for comparison and display.
//
// These are vars because Go does not support const structs. Do not reassign
// them; internal code uses builtinSource() so reassignment would only affect
// caller-side comparisons, not extraction behavior.
var (
// SourceForwarded resolves from the RFC 7239 Forwarded header.
SourceForwarded = Source{kind: sourceForwarded}
// SourceXForwardedFor resolves from the X-Forwarded-For header.
SourceXForwardedFor = Source{kind: sourceXForwardedFor}
// SourceXRealIP resolves from the X-Real-IP header.
SourceXRealIP = Source{kind: sourceXRealIP}
// SourceRemoteAddr resolves from Request.RemoteAddr. It is also the Source
// reported for RemoteAddrFallback results.
SourceRemoteAddr = Source{kind: sourceRemoteAddr}
// SourceStaticFallback is a result-only sentinel. It appears in Result.Source
// when ResolveOperational returns a StaticFallback IP. Passing it to
// WithSources is rejected by New.
SourceStaticFallback = Source{kind: sourceStaticFallback}
)
// Source identifies one extraction source in priority order.
//
// Construct Source values with the built-in variables (SourceForwarded,
// SourceXForwardedFor, ...) or HeaderSource for custom headers. Sources stored
// by the resolver are canonicalized at construction time, so == comparison
// against a built-in or HeaderSource-produced value is reliable. Use Equal when
// comparing values that may not yet be canonical (for example, raw user input).
// The zero Source is invalid.
type Source struct {
// kind selects a built-in source or sourceHeader.
kind sourceKind
// headerName is the canonical MIME header key for sourceHeader. Built-in
// sources produced by this package leave it empty.
headerName string
}
// builtinSource returns the Source for a built-in kind. Internal code uses it
// instead of the exported Source* variables, so reassigning those variables
// cannot change extraction behavior.
func builtinSource(kind sourceKind) Source {
	var s Source
	s.kind = kind
	return s
}
// HeaderSource returns a source backed by a custom HTTP header name.
//
// The name is trimmed and canonicalized as a MIME header key. Names that match
// a built-in source, either exactly or after normalization (lower-casing and
// mapping '-' to '_'), resolve to that built-in Source instead of a custom
// header. For example:
//   - "Forwarded", "x-forwarded-for", and "X-Real-IP" map to the built-in
//     header sources;
//   - "Remote-Addr" maps to SourceRemoteAddr;
//   - "Static-Fallback" maps to SourceStaticFallback, which New rejects in
//     WithSources.
//
// Consequently, a request header literally named Remote-Addr or
// Static-Fallback cannot be selected with HeaderSource. An empty or
// whitespace-only name produces an invalid Source that New rejects when used
// in WithSources.
func HeaderSource(name string) Source {
	return sourceFromString(name)
}
// canonicalSource returns the canonical form of source. Custom header sources
// are re-derived from their name, so aliases and non-canonical spellings
// collapse. Built-in sources are returned as-is, and anything else becomes
// the zero (invalid) Source.
func canonicalSource(source Source) Source {
	switch {
	case source.kind == sourceHeader:
		return sourceFromString(source.headerName)
	case source.valid():
		return source
	default:
		return Source{}
	}
}
// sourceFromString parses a source name or header name into a canonical
// Source. It resolves names in this order:
//   - exact built-in spellings are matched first;
//   - otherwise the name is trimmed and normalized, and matched against the
//     built-in identifiers;
//   - any other name becomes a custom header source with a canonical MIME
//     key.
//
// Empty or whitespace-only input yields the zero (invalid) Source.
func sourceFromString(name string) Source {
	// Fast path: internal round-trips always use already-normalized names
	// without whitespace.
	if src, ok := sourceFromExact(name); ok {
		return src
	}
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return Source{}
	}
	var kind sourceKind
	switch normalizeSourceName(trimmed) {
	case builtinSourceNameForwarded:
		kind = sourceForwarded
	case builtinSourceNameXForwardedFor:
		kind = sourceXForwardedFor
	case builtinSourceNameXRealIP:
		kind = sourceXRealIP
	case builtinSourceNameRemoteAddr:
		kind = sourceRemoteAddr
	case builtinSourceNameStaticFallback:
		kind = sourceStaticFallback
	default:
		return Source{kind: sourceHeader, headerName: textproto.CanonicalMIMEHeaderKey(trimmed)}
	}
	return builtinSource(kind)
}
// sourceFromExact matches name byte-for-byte against the built-in identifiers
// and common header spellings. ok is false for anything else.
func sourceFromExact(name string) (Source, bool) {
	var kind sourceKind
	switch name {
	case builtinSourceNameForwarded, "Forwarded":
		kind = sourceForwarded
	case builtinSourceNameXForwardedFor, "X-Forwarded-For":
		kind = sourceXForwardedFor
	case builtinSourceNameXRealIP, "X-Real-Ip", "X-Real-IP":
		kind = sourceXRealIP
	case builtinSourceNameRemoteAddr:
		kind = sourceRemoteAddr
	case builtinSourceNameStaticFallback:
		kind = sourceStaticFallback
	default:
		return Source{}, false
	}
	return builtinSource(kind), true
}
// canonicalizeSources returns a new slice (never nil) holding the canonical
// form of every source; the input is not modified.
//
// Sources stored in config.sourcePriority are always canonical; callers must
// not rely on name()/valid()/headerKey() re-canonicalizing on each call.
func canonicalizeSources(sources []Source) []Source {
	out := make([]Source, 0, len(sources))
	for _, src := range sources {
		out = append(out, canonicalSource(src))
	}
	return out
}
// String returns the canonical source identifier.
//
// Built-in sources use stable snake_case identifiers. Custom header sources use
// a lower-case, underscore-separated form of the canonical header name. An
// invalid Source returns "".
func (s Source) String() string {
	return s.name()
}

// Equal reports whether two sources represent the same canonical source.
//
// Equal canonicalizes custom-header names and built-in aliases before
// comparison.
func (s Source) Equal(other Source) bool {
	a, b := canonicalSource(s), canonicalSource(other)
	return a.kind == b.kind && a.headerName == b.headerName
}
// name returns the snake_case identifier for s. Built-ins use their fixed
// identifiers, custom headers are normalized from headerName, and invalid
// kinds yield "".
func (s Source) name() string {
	if s.kind == sourceHeader {
		return normalizeSourceName(s.headerName)
	}
	names := [...]string{
		sourceInvalid:        "",
		sourceForwarded:      builtinSourceNameForwarded,
		sourceXForwardedFor:  builtinSourceNameXForwardedFor,
		sourceXRealIP:        builtinSourceNameXRealIP,
		sourceRemoteAddr:     builtinSourceNameRemoteAddr,
		sourceStaticFallback: builtinSourceNameStaticFallback,
	}
	if int(s.kind) < len(names) {
		return names[s.kind]
	}
	return ""
}
// valid reports whether s is a usable source. A custom header needs a
// non-empty name, and every built-in kind (including the result-only static
// fallback) is valid.
func (s Source) valid() bool {
	switch s.kind {
	case sourceHeader:
		return s.headerName != ""
	case sourceForwarded, sourceXForwardedFor, sourceXRealIP, sourceRemoteAddr, sourceStaticFallback:
		return true
	default:
		return false
	}
}
// headerKey returns the HTTP header name backing s. ok is false for
// non-header sources (RemoteAddr, static fallback, invalid). Any other kind
// reports headerName, so callers should check valid() first.
//
// NOTE(review): "X-Real-IP" is not in CanonicalMIMEHeaderKey form
// ("X-Real-Ip"). Confirm that consumers of this key canonicalize it (as
// http.Header.Values does) rather than indexing a header map directly.
func (s Source) headerKey() (string, bool) {
	var key string
	switch s.kind {
	case sourceForwarded:
		key = "Forwarded"
	case sourceXForwardedFor:
		key = "X-Forwarded-For"
	case sourceXRealIP:
		key = "X-Real-IP"
	case sourceRemoteAddr, sourceStaticFallback, sourceInvalid:
		return "", false
	default:
		key = s.headerName
	}
	return key, true
}
// marshalValue returns the serialized form of s. Custom headers serialize as
// their canonical MIME header name (so parsing round-trips), and built-ins as
// their snake_case identifier.
func (s Source) marshalValue() string {
	if s.kind != sourceHeader {
		return s.String()
	}
	return s.headerName
}
// MarshalText returns a stable text form for the source.
//
// Built-in sources serialize as canonical identifiers. Custom header sources
// serialize as canonical MIME header names so they can be losslessly parsed.
// It never returns an error.
func (s Source) MarshalText() ([]byte, error) {
	text := s.marshalValue()
	return []byte(text), nil
}
// UnmarshalText parses a source from a built-in alias or header name.
//
// Empty input produces an invalid Source; New rejects invalid sources in
// WithSources. The only error is for a nil receiver.
func (s *Source) UnmarshalText(text []byte) error {
	if s != nil {
		*s = sourceFromString(string(text))
		return nil
	}
	return errors.New("clientip.Source: UnmarshalText on nil pointer")
}
// MarshalJSON returns the source as a JSON string, using the same value as
// MarshalText.
func (s Source) MarshalJSON() ([]byte, error) {
	value := s.marshalValue()
	return json.Marshal(value)
}
// UnmarshalJSON parses a source from a JSON string.
//
// Non-string JSON values fail with the json.Unmarshal error. JSON null leaves
// the decoded string empty and therefore produces an invalid Source.
func (s *Source) UnmarshalJSON(data []byte) error {
	if s == nil {
		return errors.New("clientip.Source: UnmarshalJSON on nil pointer")
	}
	var text string
	err := json.Unmarshal(data, &text)
	if err == nil {
		*s = sourceFromString(text)
	}
	return err
}
// normalizeSourceName lower-cases headerName and maps '-' to '_', producing
// the snake_case form used for source identifiers.
func normalizeSourceName(headerName string) string {
	lowered := strings.ToLower(headerName)
	return strings.ReplaceAll(lowered, "-", "_")
}
// sourceHeaderKeys returns the distinct HTTP header names referenced by
// sourcePriority, in first-seen order. Non-header and invalid sources are
// skipped.
func sourceHeaderKeys(sourcePriority []Source) []string {
	keys := make([]string, 0, len(sourcePriority))
	seen := make(map[string]struct{}, len(sourcePriority))
	for _, src := range sourcePriority {
		if key, ok := sourceHeaderKey(src); ok {
			if _, dup := seen[key]; !dup {
				seen[key] = struct{}{}
				keys = append(keys, key)
			}
		}
	}
	return keys
}
// sourceHeaderKey canonicalizes source and returns its HTTP header name. ok
// is false for invalid or non-header sources.
func sourceHeaderKey(source Source) (string, bool) {
	canonical := canonicalSource(source)
	if canonical.valid() {
		return canonical.headerKey()
	}
	return "", false
}
package clientip
import (
"net/netip"
"slices"
"strings"
)
// chainPolicy is the per-source configuration consumed by chainExtractor for
// a proxy-chain header (presumably Forwarded and X-Forwarded-For; confirm
// against where policies are built).
type chainPolicy struct {
// headerName is looked up via requestView.valuesCanonical.
headerName string
// parseValues turns the raw header lines into chain entries; its errors are
// returned unchanged by extract.
parseValues func([]string) ([]string, error)
// parseClientIP parses one chain entry; nil means parseIP.
parseClientIP func(string) netip.Addr
// clientIP is the plausibility policy passed to evaluateClientIP.
clientIP clientIPPolicy
// trustedProxy gates header inspection on the immediate peer and drives
// chain analysis.
trustedProxy proxyPolicy
// selection picks leftmost-untrusted analysis when LeftmostUntrustedIP,
// rightmost analysis otherwise.
selection ChainSelection
// collectDebugInfo attaches ChainDebugInfo to successful extractions.
collectDebugInfo bool
// untrustedChainSep joins raw header lines in untrusted-proxy failures;
// empty means ", ".
untrustedChainSep string
}
// chainExtractor extracts a client IP from a proxy-chain header according to
// its policy.
type chainExtractor struct {
policy chainPolicy
}
// extract resolves a chain header source. It returns parser errors from the
// configured policy parser unchanged, and returns extractionFailure for policy
// failures that need source-specific public errors.
//
// On success both the failure and the error are nil; otherwise exactly one
// of them is non-nil. The steps run in a security-relevant order:
//  1. a missing header yields errSourceUnavailable;
//  2. when trusted proxy CIDRs are configured, the immediate peer must be
//     trusted before any header content is parsed;
//  3. the header lines are parsed into chain parts;
//  4. the chain is analyzed to select the client entry;
//  5. the client IP is checked against the client-IP policy.
//
// DebugInfo is attached only on success and only when collectDebugInfo is set.
func (e chainExtractor) extract(req requestView, source Source) (Extraction, *extractionFailure, error) {
headerValues := req.valuesCanonical(e.policy.headerName)
if len(headerValues) == 0 {
return Extraction{}, errSourceUnavailable, nil
}
if len(e.policy.trustedProxy.TrustedProxyCIDRs) > 0 {
// Do not inspect spoofable header content until the immediate peer is
// a configured trusted proxy.
remoteIP := parseRemoteAddr(req.remoteAddr())
if !isTrustedProxy(remoteIP, e.policy.trustedProxy.TrustedProxyMatch, e.policy.trustedProxy.TrustedProxyCIDRs) {
// The raw, unparsed header lines are reported, so even malformed
// input from an untrusted peer surfaces as an untrusted-proxy failure.
return Extraction{}, &extractionFailure{
kind: failureUntrustedProxy,
source: source,
chain: strings.Join(headerValues, e.chainSeparator()),
trustedProxyCount: 0,
minTrustedProxies: e.policy.trustedProxy.MinTrustedProxies,
maxTrustedProxies: e.policy.trustedProxy.MaxTrustedProxies,
}, nil
}
}
parts, err := e.policy.parseValues(headerValues)
if err != nil {
return Extraction{}, nil, err
}
if len(parts) == 0 {
// The header is present but yielded no entries (for example only empty
// XFF elements, or Forwarded elements without for=).
return Extraction{}, &extractionFailure{kind: failureEmptyChain, source: source}, nil
}
analysis, clientIP, err := e.analyzeChain(parts)
if err != nil {
// NOTE(review): presumably analysis is partially populated on error;
// its TrustedCount is reported as-is for diagnostics.
return Extraction{}, &extractionFailure{
kind: failureProxyValidation,
source: source,
chain: strings.Join(parts, ", "),
trustedProxyCount: analysis.TrustedCount,
minTrustedProxies: e.policy.trustedProxy.MinTrustedProxies,
maxTrustedProxies: e.policy.trustedProxy.MaxTrustedProxies,
}, nil
}
clientIPStr := parts[analysis.ClientIndex]
// NOTE(review): clientIP is evaluated before normalizeIP is applied below,
// and neither parseIP nor parseChainIP unmaps. An IPv4-mapped IPv6 client
// (::ffff:a.b.c.d) therefore reaches evaluateClientIP in mapped form.
// Confirm that evaluateClientIP unmaps before any IPv4 prefix checks;
// otherwise mapped addresses could bypass IPv4 reserved/private ranges.
disposition := evaluateClientIP(clientIP, e.policy.clientIP)
if disposition != clientIPValid {
return Extraction{}, &extractionFailure{
kind: failureInvalidClientIP,
source: source,
chain: strings.Join(parts, ", "),
index: analysis.ClientIndex,
extractedIP: clientIPStr,
trustedProxyCount: analysis.TrustedCount,
clientIPDisposition: disposition,
}, nil
}
result := Extraction{
IP: normalizeIP(clientIP),
TrustedProxyCount: analysis.TrustedCount,
Source: source,
}
if e.policy.collectDebugInfo {
// DebugInfo is success-only so failed requests do not carry extra
// parsed attacker-controlled chain details through Result by default.
// Slices are cloned so callers cannot alias parser or analysis storage.
result.DebugInfo = &ChainDebugInfo{
FullChain: slices.Clone(parts),
ClientIndex: analysis.ClientIndex,
TrustedIndices: slices.Clone(analysis.TrustedIndices),
}
}
return result, nil, nil
}
// analyzeChain runs the configured selection strategy over the parsed chain.
// It defaults to parseIP when no custom client-IP parser is configured, and
// to rightmost-untrusted selection unless LeftmostUntrustedIP is selected.
func (e chainExtractor) analyzeChain(parts []string) (chainAnalysis, netip.Addr, error) {
	parse := e.policy.parseClientIP
	if parse == nil {
		parse = parseIP
	}
	strategy := analyzeChainRightmost
	if e.policy.selection == LeftmostUntrustedIP {
		strategy = analyzeChainLeftmost
	}
	return strategy(parts, e.policy.trustedProxy, e.policy.collectDebugInfo, parse)
}
// chainSeparator returns the separator used to join raw header lines when
// reporting an untrusted-peer failure. It defaults to ", ".
func (e chainExtractor) chainSeparator() string {
	if sep := e.policy.untrustedChainSep; sep != "" {
		return sep
	}
	return ", "
}
package clientip
import (
"errors"
"fmt"
)
// extractChainSource runs a chain extractor and turns its three-way result
// into a single error for the orchestration layer:
//   - parser errors (already wrapped by the source-specific adapters) go to
//     handleChainError for logging and the optional callback, then are
//     returned unchanged;
//   - a source-unavailable failure becomes the source's configured
//     unavailable error;
//   - any other policy failure becomes a public typed error through
//     adaptChainFailure.
func (e *extractor) extractChainSource(
	r requestView,
	source *configuredSource,
	chainTooLongMessage string,
	untrustedProxyMessage string,
	handleParseError func(error),
) (Extraction, error) {
	extraction, failure, err := source.chain.extract(r, source.source)
	switch {
	case err != nil:
		e.handleChainError(r, source.source, err, chainTooLongMessage, handleParseError)
		return Extraction{}, err
	case failure == nil:
		return extraction, nil
	case failure.kind == failureSourceUnavailable:
		return Extraction{}, source.unavailableErr
	default:
		return Extraction{}, e.adaptChainFailure(r, source.source, failure, untrustedProxyMessage)
	}
}
// extractSingleHeaderSource runs a single-IP header extractor. A missing
// header yields the source's configured unavailable error; other failures
// are adapted into public typed errors.
func (e *extractor) extractSingleHeaderSource(r requestView, source *configuredSource) (Extraction, error) {
	extraction, failure := source.single.extract(r, source.source)
	switch {
	case failure == nil:
		return extraction, nil
	case failure.kind == failureSourceUnavailable:
		return Extraction{}, source.unavailableErr
	default:
		return Extraction{}, e.adaptSingleHeaderFailure(r, source.source, failure)
	}
}
// extractRemoteAddrSource resolves the client from the request's RemoteAddr.
// An empty RemoteAddr yields the source's configured unavailable error.
func (e *extractor) extractRemoteAddrSource(r requestView, source *configuredSource) (Extraction, error) {
	extraction, failure := source.remote.extract(r.remoteAddr(), source.source)
	switch {
	case failure == nil:
		return extraction, nil
	case failure.kind == failureSourceUnavailable:
		return Extraction{}, source.unavailableErr
	default:
		return Extraction{}, adaptRemoteAddrFailure(failure, source.source)
	}
}
// logSecurityWarning emits a warning through the configured logger. Every
// entry starts with the same base keys (event, source, path, remote_addr),
// followed by any caller-supplied attrs. The request context is passed along
// so caller-provided loggers can attach trace/span metadata. Nothing is
// logged when the logger is configured as a no-op.
func (e *extractor) logSecurityWarning(r requestView, source Source, event, msg string, attrs ...any) {
	if e.config.loggerNoop {
		return
	}
	fields := make([]any, 0, 8+len(attrs))
	fields = append(fields,
		"event", event,
		"source", source.String(),
		"path", r.path(),
		"remote_addr", r.remoteAddr(),
	)
	fields = append(fields, attrs...)
	e.config.logger.WarnContext(r.context(), msg, fields...)
}
// proxyValidationWarningDetails maps a proxy-count error to its security event
// name and log message. Sentinels are checked in order: no trusted proxies,
// too few, too many. ok is false for any other error.
func proxyValidationWarningDetails(err error) (event, msg string, ok bool) {
	checks := [...]struct {
		target error
		event  string
		msg    string
	}{
		{ErrNoTrustedProxies, SecurityEventNoTrustedProxies, "no trusted proxies found in request chain"},
		{ErrTooFewTrustedProxies, SecurityEventTooFewTrustedProxies, "trusted proxy count below configured minimum"},
		{ErrTooManyTrustedProxies, SecurityEventTooManyTrustedProxies, "trusted proxy count exceeds configured maximum"},
	}
	for _, c := range checks {
		if errors.Is(err, c.target) {
			return c.event, c.msg, true
		}
	}
	return "", "", false
}
// logProxyValidationWarning logs a security warning for proxy-count failures.
// When err carries a *ProxyValidationError, its trusted count and configured
// bounds are included; errors that are not count failures are ignored.
func (e *extractor) logProxyValidationWarning(r requestView, source Source, err error) {
	if e.config.loggerNoop {
		return
	}
	event, msg, ok := proxyValidationWarningDetails(err)
	if !ok {
		return
	}
	var validation *ProxyValidationError
	if !errors.As(err, &validation) {
		e.logSecurityWarning(r, source, event, msg)
		return
	}
	e.logSecurityWarning(
		r, source, event, msg,
		"trusted_proxy_count", validation.TrustedProxyCount,
		"min_trusted_proxies", validation.MinTrustedProxies,
		"max_trusted_proxies", validation.MaxTrustedProxies,
	)
}
// handleChainError reacts to a parser error from a chain source. Chain-length
// violations are logged as a security warning, with length details when a
// *ChainTooLongError is present. The optional handleParseError callback is
// then invoked with every error, whether or not it was logged.
func (e *extractor) handleChainError(
	r requestView,
	source Source,
	err error,
	chainTooLongMessage string,
	handleParseError func(error),
) {
	if !e.config.loggerNoop && errors.Is(err, ErrChainTooLong) {
		var details []any
		var tooLong *ChainTooLongError
		if errors.As(err, &tooLong) {
			details = []any{
				"chain_length", tooLong.ChainLength,
				"max_length", tooLong.MaxLength,
			}
		}
		e.logSecurityWarning(r, source, SecurityEventChainTooLong, chainTooLongMessage, details...)
	}
	if handleParseError != nil {
		handleParseError(err)
	}
}
// adaptChainFailure converts chain-source policy failures into public errors.
// Keep new chain failure kinds here so logging and typed errors stay
// centralized.
//
// Untrusted-peer and proxy-count failures also emit a security warning. A nil
// failure, an empty parsed chain, or an unrecognized kind all become a bare
// ErrInvalidIP ExtractionError.
func (e *extractor) adaptChainFailure(r requestView, source Source, failure *extractionFailure, untrustedProxyMessage string) error {
	if failure == nil {
		return &ExtractionError{Err: ErrInvalidIP, Source: source}
	}
	switch failure.kind {
	case failureSourceUnavailable:
		return &ExtractionError{Err: ErrSourceUnavailable, Source: source}

	case failureUntrustedProxy:
		e.logSecurityWarning(r, source, SecurityEventUntrustedProxy, untrustedProxyMessage)
		return &ProxyValidationError{
			ExtractionError:   ExtractionError{Err: ErrUntrustedProxy, Source: source},
			Chain:             failure.chain,
			TrustedProxyCount: failure.trustedProxyCount,
			MinTrustedProxies: failure.minTrustedProxies,
			MaxTrustedProxies: failure.maxTrustedProxies,
		}

	case failureProxyValidation:
		// NOTE(review): the sentinel is re-derived from e.config.proxy, while
		// the chain was validated against the extractor's own policy. If the
		// two ever diverge, proxyCountError may return nil and leave Err
		// unset. Confirm they are always the same policy.
		countErr := proxyCountError(failure.trustedProxyCount, e.config.proxy)
		validation := &ProxyValidationError{
			ExtractionError:   ExtractionError{Err: countErr, Source: source},
			Chain:             failure.chain,
			TrustedProxyCount: failure.trustedProxyCount,
			MinTrustedProxies: failure.minTrustedProxies,
			MaxTrustedProxies: failure.maxTrustedProxies,
		}
		e.logProxyValidationWarning(r, source, validation)
		return validation

	case failureInvalidClientIP:
		return &InvalidIPError{
			ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: source},
			Chain:           failure.chain,
			ExtractedIP:     failure.extractedIP,
			Index:           failure.index,
			TrustedProxies:  failure.trustedProxyCount,
		}

	default:
		// Also covers failureEmptyChain: an empty parsed chain has no
		// candidate to report.
		return &ExtractionError{Err: ErrInvalidIP, Source: source}
	}
}
// adaptSingleHeaderFailure converts single-header policy failures into public
// errors. Duplicate header lines and untrusted peers are logged as possible
// spoofing before the typed error is returned. A nil failure or an
// unrecognized kind becomes a bare ErrInvalidIP ExtractionError.
func (e *extractor) adaptSingleHeaderFailure(r requestView, sourceName Source, failure *extractionFailure) error {
	if failure == nil {
		return &ExtractionError{Err: ErrInvalidIP, Source: sourceName}
	}
	switch failure.kind {
	case failureSourceUnavailable:
		return &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}

	case failureMultipleHeaders:
		e.logSecurityWarning(
			r, sourceName, SecurityEventMultipleHeaders, "multiple single-IP headers received - possible spoofing attempt",
			"header", failure.headerName,
			"header_count", failure.headerCount,
		)
		return &MultipleHeadersError{
			ExtractionError: ExtractionError{Err: ErrMultipleSingleIPHeaders, Source: sourceName},
			HeaderCount:     failure.headerCount,
			HeaderName:      failure.headerName,
			RemoteAddr:      failure.remoteAddr,
		}

	case failureUntrustedProxy:
		e.logSecurityWarning(
			r, sourceName, SecurityEventUntrustedProxy, "request received from untrusted proxy while single-header source is present",
			"header", failure.headerName,
		)
		return &ProxyValidationError{
			ExtractionError:   ExtractionError{Err: ErrUntrustedProxy, Source: sourceName},
			Chain:             failure.chain,
			TrustedProxyCount: failure.trustedProxyCount,
			MinTrustedProxies: failure.minTrustedProxies,
			MaxTrustedProxies: failure.maxTrustedProxies,
		}

	case failureInvalidClientIP:
		// Single-header sources have no chain, so only the raw value is kept.
		return &InvalidIPError{
			ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: sourceName},
			ExtractedIP:     failure.extractedIP,
		}

	default:
		return &ExtractionError{Err: ErrInvalidIP, Source: sourceName}
	}
}
// adaptRemoteAddrFailure converts RemoteAddr parsing and policy failures into
// public errors. Invalid client IPs become *RemoteAddrError with the original
// string. A missing RemoteAddr maps to ErrSourceUnavailable; anything else
// (including a nil failure) maps to a bare ErrInvalidIP.
func adaptRemoteAddrFailure(failure *extractionFailure, sourceName Source) error {
	if failure != nil {
		switch failure.kind {
		case failureSourceUnavailable:
			return &ExtractionError{Err: ErrSourceUnavailable, Source: sourceName}
		case failureInvalidClientIP:
			return &RemoteAddrError{
				ExtractionError: ExtractionError{Err: ErrInvalidIP, Source: sourceName},
				RemoteAddr:      failure.remoteAddr,
			}
		}
	}
	return &ExtractionError{Err: ErrInvalidIP, Source: sourceName}
}
// adaptForwardedParseError classifies a Forwarded (RFC 7239) parser error.
// Chain-length violations keep their own ChainTooLongError category. Every
// other parser error is wrapped so that it matches both
// ErrInvalidForwardedHeader and the original cause under errors.Is.
func adaptForwardedParseError(err error, source Source, extractor *extractor) error {
	if tooLong := adaptChainLengthError(err, source, extractor); tooLong != nil {
		return tooLong
	}
	malformed := fmt.Errorf("%w: %w", ErrInvalidForwardedHeader, err)
	return &ExtractionError{Err: malformed, Source: source}
}
// adaptXFFParseError maps only chain-limit failures from the X-Forwarded-For
// parser. Other errors are returned untouched because XFF parsing is
// intentionally permissive.
func adaptXFFParseError(err error, source Source, extractor *extractor) error {
	tooLong := adaptChainLengthError(err, source, extractor)
	if tooLong == nil {
		return err
	}
	return tooLong
}
// adaptChainLengthError converts an internal *chainTooLongParseError anywhere
// in err's chain into a public *ChainTooLongError. It returns nil when err
// does not carry one. The extractor parameter is currently unused.
func adaptChainLengthError(err error, source Source, _ *extractor) error {
	var parseErr *chainTooLongParseError
	if errors.As(err, &parseErr) {
		return &ChainTooLongError{
			ExtractionError: ExtractionError{Err: ErrChainTooLong, Source: source},
			ChainLength:     parseErr.ChainLength,
			MaxLength:       parseErr.MaxLength,
		}
	}
	return nil
}
// proxyCountError re-runs count validation for trustedCount against proxy. It
// collapses any result onto the matching public sentinel (used by errors.Is
// and Result.Classify). It returns nil when the count is acceptable, and an
// unrecognized validation error unchanged.
func proxyCountError(trustedCount int, proxy proxyPolicy) error {
	err := validateProxyCountPolicy(trustedCount, proxy)
	if err == nil {
		return nil
	}
	for _, sentinel := range []error{ErrNoTrustedProxies, ErrTooFewTrustedProxies, ErrTooManyTrustedProxies} {
		if errors.Is(err, sentinel) {
			return sentinel
		}
	}
	return err
}
package clientip
// remoteAddrExtractor resolves the client IP from the connection's RemoteAddr.
type remoteAddrExtractor struct {
	// clientIPPolicy decides which parsed addresses are acceptable clients.
	clientIPPolicy clientIPPolicy
}

// extract resolves the immediate connecting peer. This is the only source that
// does not depend on trusted proxy configuration. An empty remoteAddr is
// reported as source-unavailable; a rejected address keeps the raw string and
// the policy disposition for diagnostics.
func (e remoteAddrExtractor) extract(remoteAddr string, source Source) (Extraction, *extractionFailure) {
	if len(remoteAddr) == 0 {
		return Extraction{}, errSourceUnavailable
	}
	peer := parseRemoteAddr(remoteAddr)
	if verdict := evaluateClientIP(peer, e.clientIPPolicy); verdict != clientIPValid {
		return Extraction{}, &extractionFailure{
			kind:                failureInvalidClientIP,
			source:              source,
			remoteAddr:          remoteAddr,
			clientIPDisposition: verdict,
		}
	}
	return Extraction{Source: source, IP: normalizeIP(peer)}, nil
}
package clientip
import (
"context"
"net/http"
)
// headerValuesFunc returns every value recorded for a canonical header name.
type headerValuesFunc func(name string) []string

// requestView is the internal request adapter shared by *http.Request and
// framework-agnostic Input resolution. Source extractors use it to read
// canonical headers, context, path, and RemoteAddr without knowing the
// caller's shape. Header lookups prefer headerMap when it is non-nil and only
// fall back to headerFunc otherwise.
type requestView struct {
	ctx             context.Context
	remoteAddrValue string
	pathValue       string
	headerMap       map[string][]string
	headerFunc      headerValuesFunc
}

// context returns the recorded context, or context.Background when none was set.
func (r requestView) context() context.Context {
	if ctx := r.ctx; ctx != nil {
		return ctx
	}
	return context.Background()
}

// remoteAddr returns the raw remote address string.
func (r requestView) remoteAddr() string { return r.remoteAddrValue }

// path returns the request path, or "" when unknown.
func (r requestView) path() string { return r.pathValue }

// valuesCanonical returns the values for name, which must already be a
// canonical MIME header key (e.g. "X-Forwarded-For"). No canonicalization is
// done here. It returns nil when no header source is configured.
func (r requestView) valuesCanonical(name string) []string {
	switch {
	case r.headerMap != nil:
		return r.headerMap[name]
	case r.headerFunc != nil:
		return r.headerFunc(name)
	default:
		return nil
	}
}
// requestViewFromRequest adapts an *http.Request, using its http.Header map
// directly. Configured source header keys are canonicalized at construction
// time, so lookups skip repeated canonicalization on hot paths. A nil request
// yields an empty view; a nil URL leaves the path empty.
func requestViewFromRequest(r *http.Request) requestView {
	var view requestView
	if r == nil {
		return view
	}
	view.ctx = r.Context()
	view.remoteAddrValue = r.RemoteAddr
	view.headerMap = r.Header
	if u := r.URL; u != nil {
		view.pathValue = u.Path
	}
	return view
}
// requestViewFromInput adapts a framework-agnostic Input. Header access stays
// a provider call, so frameworks can preserve repeated header-line semantics
// without materializing an http.Header. A nil provider, a nil
// HeaderValuesFunc, or a typed-nil provider value all leave the view without
// headers.
func requestViewFromInput(input Input) requestView {
	view := requestView{
		ctx:             requestInputContext(input),
		remoteAddrValue: input.RemoteAddr,
	}
	switch h := input.Headers.(type) {
	case nil:
		// No header provider configured.
	case HeaderValuesFunc:
		if h != nil {
			view.headerFunc = headerValuesFunc(h)
		}
	default:
		// Typed nils (e.g. (*myHeaders)(nil)) are treated like an unset
		// provider rather than panicking at lookup time.
		if !isNilValue(h) {
			view.headerFunc = func(name string) []string {
				return h.Values(name)
			}
		}
	}
	return view
}
package clientip
// singleHeaderPolicy configures extraction from a header expected to carry a
// single client IP.
type singleHeaderPolicy struct {
	// headerName is the canonical header key to read.
	headerName string
	// clientIP decides which parsed addresses are acceptable clients.
	clientIP clientIPPolicy
	// trustedProxy gates the header on the immediate peer when CIDRs are set.
	trustedProxy proxyPolicy
}

// singleHeaderExtractor applies a singleHeaderPolicy to requests.
type singleHeaderExtractor struct {
	policy singleHeaderPolicy
}

// extract resolves a single-IP header source. Unlike chain headers, duplicate
// header lines are terminal, because no ordering rule can safely choose
// between multiple asserted client IPs. A missing header or a single empty
// value counts as source-unavailable.
func (e singleHeaderExtractor) extract(req requestView, source Source) (Extraction, *extractionFailure) {
	name := e.policy.headerName
	lines := req.valuesCanonical(name)
	switch {
	case len(lines) == 0:
		return Extraction{}, errSourceUnavailable
	case len(lines) > 1:
		return Extraction{}, &extractionFailure{
			kind:        failureMultipleHeaders,
			source:      source,
			headerName:  name,
			headerCount: len(lines),
			remoteAddr:  req.remoteAddr(),
		}
	}

	value := lines[0]
	if value == "" {
		return Extraction{}, errSourceUnavailable
	}

	// A single-IP header only means something when the immediate peer is
	// trusted to set or sanitize it.
	trust := e.policy.trustedProxy
	if len(trust.TrustedProxyCIDRs) > 0 &&
		!isTrustedProxy(parseRemoteAddr(req.remoteAddr()), trust.TrustedProxyMatch, trust.TrustedProxyCIDRs) {
		return Extraction{}, &extractionFailure{
			kind:              failureUntrustedProxy,
			source:            source,
			headerName:        name,
			chain:             value,
			trustedProxyCount: 0,
			minTrustedProxies: trust.MinTrustedProxies,
			maxTrustedProxies: trust.MaxTrustedProxies,
		}
	}

	client := parseIP(value)
	verdict := evaluateClientIP(client, e.policy.clientIP)
	if verdict != clientIPValid {
		return Extraction{}, &extractionFailure{
			kind:                failureInvalidClientIP,
			source:              source,
			extractedIP:         value,
			clientIPDisposition: verdict,
		}
	}
	return Extraction{Source: source, IP: normalizeIP(client)}, nil
}
package clientip
import (
"net/netip"
)
// proxyPolicy describes which immediate peers and chain hops are trusted, and
// the accepted number of trusted hops.
type proxyPolicy struct {
	// TrustedProxyCIDRs lists trusted proxy networks. When empty, no hop is
	// trusted and header sources skip the immediate-peer gate.
	TrustedProxyCIDRs []netip.Prefix
	// TrustedProxyMatch is a precomputed trie over TrustedProxyCIDRs. When
	// initialized, isTrustedProxy consults it instead of scanning the CIDRs.
	TrustedProxyMatch prefixMatcher
	// MinTrustedProxies is the minimum trusted-suffix length; 0 disables it.
	MinTrustedProxies int
	// MaxTrustedProxies is the maximum trusted-suffix length; 0 disables it.
	MaxTrustedProxies int
}

// chainAnalysis describes the selected client candidate and trusted suffix.
// Indexes refer to parsed chain parts, not byte offsets or raw header lines.
type chainAnalysis struct {
	// ClientIndex is the index of the selected client candidate.
	ClientIndex int
	// TrustedCount is the length of the trailing trusted suffix.
	TrustedCount int
	// TrustedIndices lists suffix indexes, nearest hop first. The analyzers
	// only populate it when trusted-index collection is requested.
	TrustedIndices []int
}
// isTrustedProxy reports whether ip falls inside the configured trusted proxy
// set. An invalid address is never trusted. The precomputed matcher is the
// hot path; cidrs is kept as a linear fallback for zero-value or manually
// assembled policy values in tests.
func isTrustedProxy(ip netip.Addr, matcher prefixMatcher, cidrs []netip.Prefix) bool {
	switch {
	case !ip.IsValid():
		return false
	case matcher.initialized:
		return matcher.contains(ip)
	}
	for i := range cidrs {
		if cidrs[i].Contains(ip) {
			return true
		}
	}
	return false
}
// validateProxyCountPolicy checks the number of CIDR-trusted hops against the
// configured bounds. It does not implement count-only trust and cannot make a
// header source trustworthy. Checks run in this order:
//   - ErrNoTrustedProxies when CIDRs are configured, a minimum is required,
//     and no hop was trusted;
//   - ErrTooFewTrustedProxies when below a positive minimum;
//   - ErrTooManyTrustedProxies when above a positive maximum.
func validateProxyCountPolicy(trustedCount int, policy proxyPolicy) error {
	minRequired := policy.MinTrustedProxies > 0
	switch {
	case minRequired && trustedCount == 0 && len(policy.TrustedProxyCIDRs) > 0:
		return ErrNoTrustedProxies
	case minRequired && trustedCount < policy.MinTrustedProxies:
		return ErrTooFewTrustedProxies
	case policy.MaxTrustedProxies > 0 && trustedCount > policy.MaxTrustedProxies:
		return ErrTooManyTrustedProxies
	default:
		return nil
	}
}
// analyzeChainRightmost walks from the nearest hop to the oldest and treats
// the trailing trusted suffix as proxy infrastructure. The first non-trusted
// hop before that suffix is the client candidate. If every hop is trusted,
// the oldest hop is selected and still validated as a client IP. On a count
// policy error the analysis is still returned, with a zero address.
func analyzeChainRightmost(parts []string, policy proxyPolicy, collectTrustedIndices bool, parseClientIP func(string) netip.Addr) (chainAnalysis, netip.Addr, error) {
	var (
		analysis chainAnalysis
		selected netip.Addr
	)
	if collectTrustedIndices {
		analysis.TrustedIndices = make([]int, 0, len(parts))
	}

	idx := len(parts) - 1
	for ; idx >= 0; idx-- {
		selected = parseClientIP(parts[idx])
		if !isTrustedProxy(selected, policy.TrustedProxyMatch, policy.TrustedProxyCIDRs) {
			break
		}
		if collectTrustedIndices {
			analysis.TrustedIndices = append(analysis.TrustedIndices, idx)
		}
		analysis.TrustedCount++
	}
	// idx is -1 only when every hop was trusted (or parts was empty). In that
	// case ClientIndex stays 0, and selected already holds parts[0]'s address
	// when parts is non-empty.
	if idx >= 0 {
		analysis.ClientIndex = idx
	}

	if err := validateProxyCountPolicy(analysis.TrustedCount, policy); err != nil {
		return analysis, netip.Addr{}, err
	}
	return analysis, selected, nil
}
// analyzeChainLeftmost still validates the trailing trusted suffix, then
// selects the earliest untrusted hop, or the oldest hop if every hop is
// trusted. This mode assumes trusted proxies produced or sanitized the full
// chain; otherwise leftmost values are client-controlled.
//
// parts must be non-empty. With no trusted CIDRs configured, parts[0] is
// returned directly and no count validation runs. On a count policy error the
// analysis is still returned, with a zero address.
func analyzeChainLeftmost(parts []string, policy proxyPolicy, collectTrustedIndices bool, parseClientIP func(string) netip.Addr) (chainAnalysis, netip.Addr, error) {
	if len(policy.TrustedProxyCIDRs) == 0 {
		// NOTE(review): unlike analyzeChainRightmost, min/max counts are not
		// checked on this path; presumably option validation rejects count
		// limits without CIDRs — confirm.
		analysis := chainAnalysis{ClientIndex: 0, TrustedCount: 0}
		return analysis, parseClientIP(parts[0]), nil
	}

	trustedCount := 0
	leftmostUntrustedIndex := -1
	leftmostUntrustedIP := netip.Addr{}
	var trustedIndices []int
	if collectTrustedIndices {
		trustedIndices = make([]int, 0, len(parts))
	}

	// Walk nearest to oldest. Hops count as trusted proxies only while they
	// form an unbroken trailing suffix. After that, keep overwriting the
	// untrusted candidate so the smallest untrusted index wins.
	inTrustedSuffix := true
	for i := len(parts) - 1; i >= 0; i-- {
		ip := parseClientIP(parts[i])
		trusted := isTrustedProxy(ip, policy.TrustedProxyMatch, policy.TrustedProxyCIDRs)
		if inTrustedSuffix && trusted {
			if collectTrustedIndices {
				trustedIndices = append(trustedIndices, i)
			}
			trustedCount++
			continue
		}
		inTrustedSuffix = false
		if !trusted {
			leftmostUntrustedIndex = i
			leftmostUntrustedIP = ip
		}
	}

	// trustedIndices is nil unless collection was requested.
	analysis := chainAnalysis{TrustedCount: trustedCount, TrustedIndices: trustedIndices}
	if err := validateProxyCountPolicy(trustedCount, policy); err != nil {
		return analysis, netip.Addr{}, err
	}
	if leftmostUntrustedIndex >= 0 {
		analysis.ClientIndex = leftmostUntrustedIndex
		return analysis, leftmostUntrustedIP, nil
	}

	// Every hop is trusted: fall back to the oldest hop, which the caller
	// still validates as a client IP.
	analysis.ClientIndex = 0
	return analysis, parseClientIP(parts[0]), nil
}
package clientip
import "net/netip"
// clientIPPolicy controls which parsed addresses evaluateClientIP accepts as
// client IPs.
type clientIPPolicy struct {
	// AllowPrivateIPs accepts addresses for which netip.Addr.IsPrivate is true.
	AllowPrivateIPs bool
	// AllowReservedClientPrefixes exempts matching addresses from the
	// reserved-range rejection.
	AllowReservedClientPrefixes []netip.Prefix
}

// clientIPDisposition is the outcome of evaluateClientIP.
type clientIPDisposition int

const (
	// clientIPInvalid: unparseable, loopback, link-local unicast, multicast,
	// or unspecified.
	clientIPInvalid clientIPDisposition = iota
	// clientIPValid: acceptable as a client IP.
	clientIPValid
	// clientIPReserved: inside a reserved range and not allowlisted.
	clientIPReserved
	// clientIPPrivate: private while AllowPrivateIPs is false.
	clientIPPrivate
)
// Reserved ranges that are rejected as client IPs (clientIPReserved) unless
// allowlisted through clientIPPolicy.AllowReservedClientPrefixes. isReservedIP
// checks the IPv4 list for IPv4 addresses and the IPv6 list for the rest.
var (
	reservedClientIPv4Prefixes = []netip.Prefix{
		mustParsePrefix("0.0.0.0/8"),
		mustParsePrefix("100.64.0.0/10"),
		mustParsePrefix("192.0.0.0/24"),
		mustParsePrefix("192.0.2.0/24"),
		mustParsePrefix("198.18.0.0/15"),
		mustParsePrefix("198.51.100.0/24"),
		mustParsePrefix("203.0.113.0/24"),
		mustParsePrefix("240.0.0.0/4"),
	}
	reservedClientIPv6Prefixes = []netip.Prefix{
		mustParsePrefix("64:ff9b::/96"),
		mustParsePrefix("64:ff9b:1::/48"),
		mustParsePrefix("100::/64"),
		mustParsePrefix("2001:2::/48"),
		mustParsePrefix("2001:db8::/32"),
		mustParsePrefix("2001:20::/28"),
	}
)

// ipv4SpecialFirstOctet marks first octets that appear in any special IPv4 range
// (private, reserved, loopback, link-local, multicast). If the first octet is not
// marked, the address is guaranteed to be a valid public IPv4 — allowing us to
// skip all individual checks in evaluateClientIP.
//
// The table is filled once by this file's init function through
// markIPv4SpecialOctets and is read-only afterwards.
var ipv4SpecialFirstOctet [256]bool
// init fills ipv4SpecialFirstOctet. It must cover every IPv4 range
// evaluateClientIP can treat as non-public: IsUnspecified, IsPrivate,
// IsLoopback, IsLinkLocalUnicast, IsMulticast, plus all of
// reservedClientIPv4Prefixes.
func init() {
	builtin := []netip.Prefix{
		mustParsePrefix("0.0.0.0/8"),      // IsUnspecified
		mustParsePrefix("10.0.0.0/8"),     // IsPrivate
		mustParsePrefix("127.0.0.0/8"),    // IsLoopback
		mustParsePrefix("169.254.0.0/16"), // IsLinkLocalUnicast
		mustParsePrefix("172.16.0.0/12"),  // IsPrivate
		mustParsePrefix("192.168.0.0/16"), // IsPrivate
		mustParsePrefix("224.0.0.0/3"),    // IsMulticast plus the future-use block (224.0.0.0–255.255.255.255)
	}
	for _, group := range [][]netip.Prefix{builtin, reservedClientIPv4Prefixes} {
		for _, prefix := range group {
			markIPv4SpecialOctets(prefix)
		}
	}
}
// markIPv4SpecialOctets marks every first octet covered by prefix in
// ipv4SpecialFirstOctet.
//
// The prefix is masked first. Without that, an un-normalized prefix wider than
// /8 (for example 225.0.0.0/3) would start counting from its host octet and
// index past the end of the 256-entry table. prefix must be an IPv4 prefix:
// As4 panics on the zero value or an IPv6 address, which is acceptable for an
// init-time programmer error.
func markIPv4SpecialOctets(prefix netip.Prefix) {
	prefix = prefix.Masked()
	first := int(prefix.Addr().As4()[0])
	bits := prefix.Bits()
	if bits >= 8 {
		ipv4SpecialFirstOctet[first] = true
		return
	}
	// A prefix wider than /8 spans 2^(8-bits) consecutive first octets,
	// starting at the (now aligned) network octet.
	span := 1 << (8 - bits)
	for octet := first; octet < first+span; octet++ {
		ipv4SpecialFirstOctet[octet] = true
	}
}
// evaluateClientIP classifies ip as a client address under policy. Checks run
// in order: invalid input; the IPv4 fast path; loopback, link-local,
// multicast, or unspecified (invalid); non-allowlisted reserved ranges;
// private addresses when they are not allowed.
func evaluateClientIP(ip netip.Addr, policy clientIPPolicy) clientIPDisposition {
	switch {
	case !ip.IsValid():
		return clientIPInvalid
	case ip.Is4() && !ipv4SpecialFirstOctet[ip.As4()[0]]:
		// Fast path: an IPv4 address whose first octet is outside every
		// special range is public, so the individual checks can be skipped.
		return clientIPValid
	case ip.IsLoopback(), ip.IsLinkLocalUnicast(), ip.IsMulticast(), ip.IsUnspecified():
		return clientIPInvalid
	case isReservedIP(ip) && !isAllowlistedReservedClientIP(ip, policy.AllowReservedClientPrefixes):
		return clientIPReserved
	case !policy.AllowPrivateIPs && ip.IsPrivate():
		return clientIPPrivate
	default:
		return clientIPValid
	}
}
// isReservedIP reports whether ip, after normalizeIP, falls in one of the
// reserved client ranges. IPv4 results are checked against the IPv4 table and
// everything else against the IPv6 table. Invalid addresses are never
// reserved.
func isReservedIP(ip netip.Addr) bool {
	if !ip.IsValid() {
		return false
	}
	canonical := normalizeIP(ip)
	table := reservedClientIPv6Prefixes
	if canonical.Is4() {
		table = reservedClientIPv4Prefixes
	}
	for i := range table {
		if table[i].Contains(canonical) {
			return true
		}
	}
	return false
}
// isAllowlistedReservedClientIP reports whether ip, after normalizeIP, is
// covered by any prefix in allowlist. An empty allowlist or an invalid
// address never matches.
func isAllowlistedReservedClientIP(ip netip.Addr, allowlist []netip.Prefix) bool {
	if !ip.IsValid() || len(allowlist) == 0 {
		return false
	}
	canonical := normalizeIP(ip)
	for i := range allowlist {
		if allowlist[i].Contains(canonical) {
			return true
		}
	}
	return false
}
package clientip
import "net/netip"
// prefixMatcher is a precomputed trusted-prefix set with one binary trie per
// address family. The zero value matches nothing.
type prefixMatcher struct {
	// initialized is set by newPrefixMatcher when it received at least one
	// prefix; isTrustedProxy uses it to choose the trie over a linear scan.
	initialized bool
	// ipv4Root holds prefixes whose address is IPv4; nil when there are none.
	ipv4Root *prefixTrieNode
	// ipv6Root holds all other prefixes; nil when there are none.
	ipv6Root *prefixTrieNode
}

// prefixTrieNode is a binary prefix trie node. A terminal node means every
// address below that node matches a configured prefix; terminal on the root
// represents /0.
type prefixTrieNode struct {
	// children is indexed by the next address bit (0 or 1).
	children [2]*prefixTrieNode
	terminal bool
}
// newPrefixMatcher builds separate IPv4 and IPv6 tries so hot-path trust checks
// do not scan every configured CIDR. Invalid prefixes are skipped. The
// matcher is marked initialized whenever prefixes is non-empty, even if every
// entry was skipped.
func newPrefixMatcher(prefixes []netip.Prefix) prefixMatcher {
	var m prefixMatcher
	if len(prefixes) == 0 {
		return m
	}
	m.initialized = true
	for _, p := range prefixes {
		addr, bits := p.Addr(), p.Bits()
		if !addr.IsValid() || bits < 0 {
			continue
		}
		if limit := addr.BitLen(); bits > limit {
			bits = limit
		}

		root := &m.ipv6Root
		var key []byte
		if addr.Is4() {
			root = &m.ipv4Root
			v4 := addr.As4()
			key = v4[:]
		} else {
			v6 := addr.As16()
			key = v6[:]
		}
		if *root == nil {
			*root = &prefixTrieNode{}
		}
		insertPrefix(*root, key, bits)
	}
	return m
}
// contains reports whether ip matches a configured prefix of its own address
// family. Uninitialized matchers and invalid addresses never match.
func (m prefixMatcher) contains(ip netip.Addr) bool {
	if !m.initialized || !ip.IsValid() {
		return false
	}
	if ip.Is4() {
		v4 := ip.As4()
		return m.ipv4Root != nil && trieContains(m.ipv4Root, v4[:])
	}
	v6 := ip.As16()
	return m.ipv6Root != nil && trieContains(m.ipv6Root, v6[:])
}
// insertPrefix records the leading bits of addr as a trusted prefix,
// creating nodes as needed and marking the final node terminal. With bits of
// 0 the root itself becomes terminal (a /0 prefix).
func insertPrefix(root *prefixTrieNode, addr []byte, bits int) {
	node := root
	for i := 0; i < bits; i++ {
		b := addrBit(addr, i)
		if node.children[b] == nil {
			node.children[b] = &prefixTrieNode{}
		}
		node = node.children[b]
	}
	node.terminal = true
}
// trieContains reports whether addr falls under any terminal prefix node. It
// descends one bit at a time, most significant first, and stops at the first
// terminal node or missing child.
func trieContains(root *prefixTrieNode, addr []byte) bool {
	node := root
	if node == nil {
		return false
	}
	for _, b := range addr {
		for shift := 7; shift >= 0; shift-- {
			if node.terminal {
				return true
			}
			node = node.children[(b>>shift)&1]
			if node == nil {
				return false
			}
		}
	}
	return node.terminal
}
// addrBit returns bit bitIndex of addr (0 or 1). Bits are numbered in network
// byte order, most significant bit of byte 0 first.
func addrBit(addr []byte, bitIndex int) int {
	octet := addr[bitIndex/8]
	return int(octet>>(7-bitIndex%8)) & 1
}
package clientip
import (
"errors"
"fmt"
"net/netip"
)
// Sentinel errors for errors.Is. ClassifyError groups them into ResultKind
// categories; the typed errors below wrap them with source-specific details.
var (
	// ErrNoTrustedProxies indicates no trusted proxies were found in a parsed
	// chain when at least one is required. ClassifyError: ResultUntrusted.
	ErrNoTrustedProxies = errors.New("no trusted proxies found in proxy chain")
	// ErrSourceUnavailable indicates the selected source is not present on the
	// request. ClassifyError: ResultUnavailable.
	ErrSourceUnavailable = errors.New("source unavailable")
	// ErrNilRequest indicates a nil *http.Request was passed to Resolve.
	ErrNilRequest = errors.New("request cannot be nil")
	// ErrMultipleSingleIPHeaders indicates multiple values were provided for a
	// single-IP header source. ClassifyError: ResultMalformed.
	ErrMultipleSingleIPHeaders = errors.New("multiple single-IP headers received")
	// ErrUntrustedProxy indicates a header source was provided by an untrusted
	// immediate proxy. ClassifyError: ResultUntrusted.
	ErrUntrustedProxy = errors.New("request from untrusted proxy")
	// ErrTooFewTrustedProxies indicates trusted proxies in the chain are below
	// the configured minimum. ClassifyError: ResultUntrusted.
	ErrTooFewTrustedProxies = errors.New("too few trusted proxies in proxy chain")
	// ErrTooManyTrustedProxies indicates trusted proxies in the chain exceed the
	// configured maximum. ClassifyError: ResultUntrusted.
	ErrTooManyTrustedProxies = errors.New("too many trusted proxies in proxy chain")
	// ErrInvalidIP indicates the extracted client IP is invalid or implausible.
	ErrInvalidIP = errors.New("invalid or implausible IP address")
	// ErrChainTooLong indicates a Forwarded/X-Forwarded-For chain exceeded the
	// configured maximum length. ClassifyError: ResultMalformed.
	ErrChainTooLong = errors.New("proxy chain too long")
	// ErrInvalidForwardedHeader indicates a malformed RFC7239 Forwarded header.
	// ClassifyError: ResultMalformed.
	ErrInvalidForwardedHeader = errors.New("invalid Forwarded header")
)
// ExtractionError is the base error for extraction failures: it pairs an
// underlying error with the Source that produced it. The richer error types
// in this package embed it, so errors.Is and errors.As reach Err through any
// of them.
type ExtractionError struct {
	// Err is the sentinel, or an error wrapping one, that describes the failure.
	Err error
	// Source identifies the extraction source that failed.
	Source Source
}

// Error renders the failure as "<source>: <err>".
func (e *ExtractionError) Error() string {
	name := e.Source.String()
	return fmt.Sprintf("%s: %v", name, e.Err)
}

// Unwrap exposes Err to errors.Is and errors.As.
func (e *ExtractionError) Unwrap() error { return e.Err }

// SourceName returns the string form of the failing source.
func (e *ExtractionError) SourceName() string {
	return e.SourceValue().String()
}

// SourceValue returns the failing source.
func (e *ExtractionError) SourceValue() Source { return e.Source }
// MultipleHeadersError reports that a source expecting exactly one header line
// received several.
type MultipleHeadersError struct {
	ExtractionError
	// HeaderCount is how many header-line values arrived for HeaderName.
	HeaderCount int
	// HeaderName is the canonical header name, when known.
	HeaderName string
	// RemoteAddr is the request RemoteAddr seen while validating the header.
	RemoteAddr string
}

// Error renders the failure with the header count and remote address, and
// with the header name when it is known.
func (e *MultipleHeadersError) Error() string {
	name := e.Source.String()
	if e.HeaderName == "" {
		return fmt.Sprintf("%s: %v (header_count=%d, remote_addr=%s)",
			name, e.Err, e.HeaderCount, e.RemoteAddr)
	}
	return fmt.Sprintf("%s: %v (header=%q, header_count=%d, remote_addr=%s)",
		name, e.Err, e.HeaderName, e.HeaderCount, e.RemoteAddr)
}
// ProxyValidationError reports a trusted-proxy failure: an untrusted
// immediate peer, or a trusted-hop count outside the configured bounds.
type ProxyValidationError struct {
	ExtractionError
	// Chain is the proxy chain (or raw header value) as a single string.
	Chain string
	// TrustedProxyCount is the number of trusted proxies found in the chain.
	TrustedProxyCount int
	// MinTrustedProxies is the configured minimum trusted-proxy count.
	MinTrustedProxies int
	// MaxTrustedProxies is the configured maximum trusted-proxy count.
	MaxTrustedProxies int
}

// Error renders the failure with the quoted chain and the count bounds.
func (e *ProxyValidationError) Error() string {
	name := e.Source.String()
	return fmt.Sprintf(
		"%s: %v (chain=%q, trusted_count=%d, min=%d, max=%d)",
		name, e.Err, e.Chain,
		e.TrustedProxyCount, e.MinTrustedProxies, e.MaxTrustedProxies,
	)
}
// InvalidIPError reports an invalid or implausible extracted client IP.
type InvalidIPError struct {
	ExtractionError
	// Chain is the parsed proxy chain, set only when the IP came from a chain
	// source.
	Chain string
	// ExtractedIP is the rejected IP string.
	ExtractedIP string
	// Index is the selected client index in Chain, or 0 for non-chain sources.
	Index int
	// TrustedProxies is the number of trusted proxies found in Chain.
	TrustedProxies int
}

// Error includes chain details when Chain is set, otherwise just the IP when
// one was extracted, and otherwise falls back to the base error text.
func (e *InvalidIPError) Error() string {
	switch {
	case e.Chain != "":
		return fmt.Sprintf("%s: %v (chain=%q, extracted_ip=%q, index=%d, trusted_proxies=%d)",
			e.Source.String(), e.Err, e.Chain, e.ExtractedIP, e.Index, e.TrustedProxies)
	case e.ExtractedIP != "":
		return fmt.Sprintf("%s: %v (ip=%q)", e.Source.String(), e.Err, e.ExtractedIP)
	default:
		return e.ExtractionError.Error()
	}
}
// RemoteAddrError reports an invalid or implausible Request.RemoteAddr value.
type RemoteAddrError struct {
	ExtractionError
	// RemoteAddr is the remote address string exactly as received.
	RemoteAddr string
}

// Error renders the failure with the quoted remote address.
func (e *RemoteAddrError) Error() string {
	name := e.Source.String()
	return fmt.Sprintf("%s: %v (remote_addr=%q)", name, e.Err, e.RemoteAddr)
}
// ChainTooLongError reports a Forwarded/X-Forwarded-For chain longer than the
// configured limit.
type ChainTooLongError struct {
	ExtractionError
	// ChainLength is the number of parsed chain entries.
	ChainLength int
	// MaxLength is the configured maximum chain length.
	MaxLength int
}

// Error renders the failure with the observed and maximum lengths.
func (e *ChainTooLongError) Error() string {
	name := e.Source.String()
	return fmt.Sprintf(
		"%s: %v (chain_length=%d, max_length=%d)",
		name, e.Err, e.ChainLength, e.MaxLength,
	)
}
// ChainDebugInfo describes parsed chain-analysis details for diagnostics.
// It is attached only to successful chain extractions, and the slices are
// copies owned by the caller.
type ChainDebugInfo struct {
	// FullChain contains the parsed Forwarded or X-Forwarded-For chain.
	FullChain []string
	// ClientIndex is the index selected as the client candidate.
	ClientIndex int
	// TrustedIndices are the indexes identified as trusted proxies, recorded
	// nearest hop first (descending). It may be empty.
	TrustedIndices []int
}
// Extraction contains extraction metadata.
//
// On error, Source may still be set when available.
//
// For additional diagnostics (such as chain details or trusted-proxy counts),
// inspect typed errors like ProxyValidationError and InvalidIPError.
type Extraction struct {
	// IP is the normalized client IP when extraction succeeds.
	IP netip.Addr
	// Source identifies where IP came from. On error it may identify the source
	// that failed.
	Source Source
	// TrustedProxyCount is the number of trusted proxies observed in a chain
	// source. Single-header and RemoteAddr extractions leave it at 0.
	TrustedProxyCount int
	// DebugInfo contains optional parsed chain details when WithDebugInfo is
	// enabled and a chain source succeeds; nil otherwise.
	DebugInfo *ChainDebugInfo
}
// ParseCIDRs parses one or more CIDR strings, in order.
//
// The returned prefixes are suitable for WithTrustedProxies or
// WithAllowedReservedClientPrefixes. Prefixes are returned as parsed, not
// masked. With no arguments it returns an empty, non-nil slice. The first
// unparseable entry aborts with a nil slice and an error naming that entry.
func ParseCIDRs(cidrs ...string) ([]netip.Prefix, error) {
	parsed := make([]netip.Prefix, len(cidrs))
	for i, raw := range cidrs {
		prefix, err := netip.ParsePrefix(raw)
		if err != nil {
			return nil, fmt.Errorf("invalid CIDR %q: %w", raw, err)
		}
		parsed[i] = prefix
	}
	return parsed, nil
}