You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
324 lines
8.3 KiB
324 lines
8.3 KiB
package http |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"crypto/tls" |
|
"fmt" |
|
"git.diulo.com/mogfee/kit/encoding" |
|
"git.diulo.com/mogfee/kit/errors" |
|
"git.diulo.com/mogfee/kit/internal/host" |
|
"git.diulo.com/mogfee/kit/internal/httputil" |
|
"git.diulo.com/mogfee/kit/middleware" |
|
"git.diulo.com/mogfee/kit/registry" |
|
"git.diulo.com/mogfee/kit/selector" |
|
"git.diulo.com/mogfee/kit/transport" |
|
"io" |
|
"net/http" |
|
"time" |
|
) |
|
|
|
type DecodeErrorFunc func(ctx context.Context, res *http.Response) error |
|
type EncodeRequestFunc func(ctx context.Context, contentType string, in any) (body []byte, err error) |
|
type DecodeResponseFunc func(ctx context.Context, res *http.Response, out any) error |
|
type ClientOption func(options *clientOptions) |
|
|
|
type clientOptions struct { |
|
ctx context.Context |
|
tlsConf *tls.Config |
|
timeout time.Duration |
|
endpoint string |
|
userAgent string |
|
encoder EncodeRequestFunc |
|
decoder DecodeResponseFunc |
|
errorDecoder DecodeErrorFunc |
|
transport http.RoundTripper |
|
nodeFilters []selector.NodeFilter |
|
discovery registry.Discovery |
|
middleware []middleware.Middleware |
|
block bool |
|
} |
|
|
|
// WithTransport with client transport. |
|
func WithTransport(trans http.RoundTripper) ClientOption { |
|
return func(o *clientOptions) { |
|
o.transport = trans |
|
} |
|
} |
|
|
|
// WithTimeout with client request timeout. |
|
func WithTimeout(d time.Duration) ClientOption { |
|
return func(o *clientOptions) { |
|
o.timeout = d |
|
} |
|
} |
|
|
|
// WithUserAgent with client user agent. |
|
func WithUserAgent(ua string) ClientOption { |
|
return func(o *clientOptions) { |
|
o.userAgent = ua |
|
} |
|
} |
|
|
|
// WithMiddleware with client middleware. |
|
func WithMiddleware(m ...middleware.Middleware) ClientOption { |
|
return func(o *clientOptions) { |
|
o.middleware = m |
|
} |
|
} |
|
|
|
// WithEndpoint with client addr. |
|
func WithEndpoint(endpoint string) ClientOption { |
|
return func(o *clientOptions) { |
|
o.endpoint = endpoint |
|
} |
|
} |
|
|
|
// WithRequestEncoder with client request encoder. |
|
func WithRequestEncoder(encoder EncodeRequestFunc) ClientOption { |
|
return func(o *clientOptions) { |
|
o.encoder = encoder |
|
} |
|
} |
|
|
|
// WithResponseDecoder with client response decoder. |
|
func WithResponseDecoder(decoder DecodeResponseFunc) ClientOption { |
|
return func(o *clientOptions) { |
|
o.decoder = decoder |
|
} |
|
} |
|
|
|
// WithErrorDecoder with client error decoder. |
|
func WithErrorDecoder(errorDecoder DecodeErrorFunc) ClientOption { |
|
return func(o *clientOptions) { |
|
o.errorDecoder = errorDecoder |
|
} |
|
} |
|
|
|
// WithDiscovery with client discovery. |
|
func WithDiscovery(d registry.Discovery) ClientOption { |
|
return func(o *clientOptions) { |
|
o.discovery = d |
|
} |
|
} |
|
|
|
// WithNodeFilter with select filters |
|
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { |
|
return func(o *clientOptions) { |
|
o.nodeFilters = filters |
|
} |
|
} |
|
|
|
// WithBlock with client block. |
|
func WithBlock() ClientOption { |
|
return func(o *clientOptions) { |
|
o.block = true |
|
} |
|
} |
|
|
|
// WithTLSConfig with tls config. |
|
func WithTLSConfig(c *tls.Config) ClientOption { |
|
return func(o *clientOptions) { |
|
o.tlsConf = c |
|
} |
|
} |
|
|
|
type Client struct { |
|
opts clientOptions |
|
targe *Target |
|
r *resolver |
|
cc *http.Client |
|
insecure bool |
|
selector selector.Selector |
|
} |
|
|
|
func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { |
|
options := clientOptions{ |
|
ctx: ctx, |
|
timeout: 2000 * time.Millisecond, |
|
encoder: DefaultrequestEncoder, |
|
decoder: DefaultResponseDecoder, |
|
errorDecoder: DefaultErrorDecoder, |
|
transport: http.DefaultTransport, |
|
} |
|
for _, o := range opts { |
|
o(&options) |
|
} |
|
if options.tlsConf != nil { |
|
if tr, ok := options.transport.(*http.Transport); ok { |
|
tr.TLSClientConfig = options.tlsConf |
|
} |
|
} |
|
insecure := options.tlsConf == nil |
|
target, err := parseTarget(options.endpoint, insecure) |
|
if err != nil { |
|
return nil, err |
|
} |
|
selector := selector.GlobalSelector().Build() |
|
var r *resolver |
|
if options.discovery != nil { |
|
if target.Scheme == "discovery" { |
|
if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil { |
|
return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) |
|
} |
|
} else if _, _, err = host.ExtractHostPort(options.endpoint); err != nil { |
|
return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) |
|
} |
|
} |
|
|
|
return &Client{ |
|
opts: options, |
|
targe: target, |
|
r: r, |
|
cc: &http.Client{ |
|
Timeout: options.timeout, |
|
Transport: options.transport, |
|
}, |
|
insecure: insecure, |
|
selector: selector, |
|
}, nil |
|
} |
|
func (client *Client) Invoke(ctx context.Context, method, path string, args any, reply any, opts ...CallOption) error { |
|
var ( |
|
contentType string |
|
body io.Reader |
|
) |
|
c := defaultCallInfo(path) |
|
for _, o := range opts { |
|
if err := o.before(&c); err != nil { |
|
return err |
|
} |
|
} |
|
if args != nil { |
|
data, err := client.opts.encoder(ctx, c.contentType, args) |
|
if err != nil { |
|
return err |
|
} |
|
contentType = c.contentType |
|
body = bytes.NewReader(data) |
|
} |
|
url := fmt.Sprintf("%s://%s%s", client.targe.Scheme, client.targe.Authority, path) |
|
req, err := http.NewRequest(method, url, body) |
|
if err != nil { |
|
return err |
|
} |
|
if contentType != "" { |
|
req.Header.Set("Content-Type", c.contentType) |
|
} |
|
if client.opts.userAgent != "" { |
|
req.Header.Set("User-Agent", client.opts.userAgent) |
|
} |
|
ctx = transport.NewClientContext(ctx, &Transport{ |
|
endpoint: client.opts.endpoint, |
|
operation: c.operation, |
|
reqHeader: headerCarrier(req.Header), |
|
request: req, |
|
pathTemplate: c.pathTemplate, |
|
}) |
|
return client.invoke(ctx, req, args, reply, c, opts...) |
|
} |
|
func (client *Client) invoke(ctx context.Context, req *http.Request, args any, reply any, c callInfo, opts ...CallOption) error { |
|
h := func(ctx context.Context, in any) (any, error) { |
|
res, err := client.do(req.WithContext(ctx)) |
|
if res != nil { |
|
cs := csAttempt{res: res} |
|
for _, o := range opts { |
|
o.after(&c, &cs) |
|
} |
|
} |
|
if err != nil { |
|
return nil, err |
|
} |
|
defer res.Body.Close() |
|
if err = client.opts.decoder(ctx, res, reply); err != nil { |
|
return nil, err |
|
} |
|
return reply, nil |
|
} |
|
var p selector.Peer |
|
ctx = selector.NewPeerContext(ctx, &p) |
|
if len(client.opts.middleware) > 0 { |
|
h = middleware.Chain(client.opts.middleware...)(h) |
|
} |
|
_, err := h(ctx, args) |
|
return err |
|
} |
|
func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, error) { |
|
c := defaultCallInfo(req.URL.Path) |
|
for _, o := range opts { |
|
if err := o.before(&c); err != nil { |
|
return nil, err |
|
} |
|
} |
|
return client.do(req) |
|
} |
|
func (client *Client) do(req *http.Request) (*http.Response, error) { |
|
var done func(context.Context, selector.DoneInfo) |
|
if client.r != nil { |
|
var ( |
|
err error |
|
node selector.Node |
|
) |
|
if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil { |
|
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) |
|
} |
|
if client.insecure { |
|
req.URL.Scheme = "http" |
|
} else { |
|
req.URL.Scheme = "https" |
|
} |
|
req.URL.Host = node.Address() |
|
req.Host = node.Address() |
|
} |
|
resp, err := client.cc.Do(req) |
|
if err == nil { |
|
err = client.opts.errorDecoder(req.Context(), resp) |
|
} |
|
if done != nil { |
|
done(req.Context(), selector.DoneInfo{Err: err}) |
|
} |
|
if err != nil { |
|
return nil, err |
|
} |
|
return resp, nil |
|
} |
|
func (client *Client) Close() error { |
|
if client.r != nil { |
|
return client.r.Close() |
|
} |
|
return nil |
|
} |
|
func DefaultrequestEncoder(ctx context.Context, contentType string, v any) ([]byte, error) { |
|
name := httputil.ContentSubtype(contentType) |
|
return encoding.GetCodec(name).Marshal(v) |
|
} |
|
func DefaultResponseDecoder(ctx context.Context, res *http.Response, v any) error { |
|
defer res.Body.Close() |
|
data, err := io.ReadAll(res.Body) |
|
if err != nil { |
|
return err |
|
} |
|
return CodecForResponse(res).Unmarshal(data, v) |
|
} |
|
func DefaultErrorDecoder(ctx context.Context, res *http.Response) error { |
|
if res.StatusCode >= 200 && res.StatusCode <= 299 { |
|
return nil |
|
} |
|
data, err := io.ReadAll(res.Body) |
|
defer res.Body.Close() |
|
if err == nil { |
|
e := new(errors.Error) |
|
if err = CodecForResponse(res).Unmarshal(data, e); err == nil { |
|
e.Code = int32(res.StatusCode) |
|
return e |
|
} |
|
} |
|
return errors.Newf(res.StatusCode, errors.UnknownReason, "").WithCause(err) |
|
} |
|
func CodecForResponse(r *http.Response) encoding.Codec { |
|
codec := encoding.GetCodec(httputil.ContentSubtype(r.Request.Header.Get("Content-Type"))) |
|
if codec != nil { |
|
return codec |
|
} |
|
return encoding.GetCodec("json") |
|
}
|
|
|