blob: 1a35540ca51e7d065cc08a3cac9cc68652cae7b7 [file] [log] [blame]
package grpcreflect
import (
"bytes"
"context"
"fmt"
"io"
"reflect"
"runtime"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
refv1 "google.golang.org/grpc/reflection/grpc_reflection_v1"
refv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/internal"
)
// If we try the v1 reflection API and get back "not implemented", we'll wait
// this long before trying v1 again. This allows a long-lived client to
// dynamically switch from v1alpha to v1 if the underlying server is updated
// to support it. But it also prevents every stream request from always trying
// v1 first: if we try it and see it fail, we shouldn't continually retry it
// if we expect it will fail again.
const durationBetweenV1Attempts = time.Hour
// elementNotFoundError is the error returned by reflective operations where the
// server does not recognize a given file name, symbol name, or extension.
type elementNotFoundError struct {
name string
kind elementKind
symType symbolType // only used when kind == elementKindSymbol
tag int32 // only used when kind == elementKindExtension
// only errors with a kind of elementKindFile will have a cause, which means
// the named file count not be resolved because of a dependency that could
// not be found where cause describes the missing dependency
cause *elementNotFoundError
}
type elementKind int
const (
elementKindSymbol elementKind = iota
elementKindFile
elementKindExtension
)
type symbolType string
const (
symbolTypeService = "Service"
symbolTypeMessage = "Message"
symbolTypeEnum = "Enum"
symbolTypeUnknown = "Symbol"
)
func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error {
if cause != nil && cause.kind == elementKindSymbol && cause.name == symbol {
// no need to wrap
if symType != symbolTypeUnknown && cause.symType == symbolTypeUnknown {
// We previously didn't know symbol type but now do?
// Create a new error that has the right symbol type.
return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol}
}
return cause
}
return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol, cause: cause}
}
func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error {
if cause != nil && cause.kind == elementKindExtension && cause.name == extendee && cause.tag == tag {
// no need to wrap
return cause
}
return &elementNotFoundError{name: extendee, tag: tag, kind: elementKindExtension, cause: cause}
}
func fileNotFound(file string, cause *elementNotFoundError) error {
if cause != nil && cause.kind == elementKindFile && cause.name == file {
// no need to wrap
return cause
}
return &elementNotFoundError{name: file, kind: elementKindFile, cause: cause}
}
func (e *elementNotFoundError) Error() string {
first := true
var b bytes.Buffer
for ; e != nil; e = e.cause {
if first {
first = false
} else {
_, _ = fmt.Fprint(&b, "\ncaused by: ")
}
switch e.kind {
case elementKindSymbol:
_, _ = fmt.Fprintf(&b, "%s not found: %s", e.symType, e.name)
case elementKindExtension:
_, _ = fmt.Fprintf(&b, "Extension not found: tag %d for %s", e.tag, e.name)
default:
_, _ = fmt.Fprintf(&b, "File not found: %s", e.name)
}
}
return b.String()
}
// IsElementNotFoundError determines if the given error indicates that a file
// name, symbol name, or extension field was could not be found by the server.
func IsElementNotFoundError(err error) bool {
_, ok := err.(*elementNotFoundError)
return ok
}
// ProtocolError is an error returned when the server sends a response of the
// wrong type.
type ProtocolError struct {
missingType reflect.Type
}
func (p ProtocolError) Error() string {
return fmt.Sprintf("Protocol error: response was missing %v", p.missingType)
}
type extDesc struct {
extendedMessageName string
extensionNumber int32
}
type resolvers struct {
descriptorResolver protodesc.Resolver
extensionResolver protoregistry.ExtensionTypeResolver
}
type fileEntry struct {
fd *desc.FileDescriptor
fallback bool
}
// Client is a client connection to a server for performing reflection calls
// and resolving remote symbols.
type Client struct {
ctx context.Context
now func() time.Time
stubV1 refv1.ServerReflectionClient
stubV1Alpha refv1alpha.ServerReflectionClient
allowMissing atomic.Bool
fallbackResolver atomic.Pointer[resolvers]
connMu sync.Mutex
cancel context.CancelFunc
stream refv1.ServerReflection_ServerReflectionInfoClient
useV1Alpha bool
lastTriedV1 time.Time
cacheMu sync.RWMutex
protosByName map[string]*descriptorpb.FileDescriptorProto
filesByName map[string]fileEntry
filesBySymbol map[string]fileEntry
filesByExtension map[extDesc]fileEntry
}
// NewClient creates a new Client with the given root context and using the
// given RPC stub for talking to the server.
//
// Deprecated: Use NewClientV1Alpha if you are intentionally pinning the
// v1alpha version of the reflection service. Otherwise, use NewClientAuto
// instead.
func NewClient(ctx context.Context, stub refv1alpha.ServerReflectionClient) *Client {
return NewClientV1Alpha(ctx, stub)
}
// NewClientV1Alpha creates a new Client using the v1alpha version of reflection
// with the given root context and using the given RPC stub for talking to the
// server.
func NewClientV1Alpha(ctx context.Context, stub refv1alpha.ServerReflectionClient) *Client {
return newClient(ctx, nil, stub)
}
// NewClientV1 creates a new Client using the v1 version of reflection with the
// given root context and using the given RPC stub for talking to the server.
func NewClientV1(ctx context.Context, stub refv1.ServerReflectionClient) *Client {
return newClient(ctx, stub, nil)
}
func newClient(ctx context.Context, stubv1 refv1.ServerReflectionClient, stubv1alpha refv1alpha.ServerReflectionClient) *Client {
cr := &Client{
ctx: ctx,
now: time.Now,
stubV1: stubv1,
stubV1Alpha: stubv1alpha,
protosByName: map[string]*descriptorpb.FileDescriptorProto{},
filesByName: map[string]fileEntry{},
filesBySymbol: map[string]fileEntry{},
filesByExtension: map[extDesc]fileEntry{},
}
// don't leak a grpc stream
runtime.SetFinalizer(cr, (*Client).Reset)
return cr
}
// NewClientAuto creates a new Client that will use either v1 or v1alpha version
// of reflection (based on what the server supports) with the given root context
// and using the given client connection.
//
// It will first the v1 version of the reflection service. If it gets back an
// "Unimplemented" error, it will fall back to using the v1alpha version. It
// will remember which version the server supports for any subsequent operations
// that need to re-invoke the streaming RPC. But, if it's a very long-lived
// client, it will periodically retry the v1 version (in case the server is
// updated to support it also). The period for these retries is every hour.
func NewClientAuto(ctx context.Context, cc grpc.ClientConnInterface) *Client {
stubv1 := refv1.NewServerReflectionClient(cc)
stubv1alpha := refv1alpha.NewServerReflectionClient(cc)
return newClient(ctx, stubv1, stubv1alpha)
}
// AllowMissingFileDescriptors configures the client to allow missing files
// when building descriptors when possible. Missing files are often fatal
// errors, but with this option they can sometimes be worked around. Building
// a schema can only succeed with some files missing if the files in question
// only provide custom options and/or other unused types.
func (cr *Client) AllowMissingFileDescriptors() {
cr.allowMissing.Store(true)
}
// AllowFallbackResolver configures the client to allow falling back to the
// given resolvers if the server is unable to supply descriptors for a particular
// query. This allows working around issues where servers' reflection service
// provides an incomplete set of descriptors, but the client has knowledge of
// the missing descriptors from another source. It is usually most appropriate
// to pass [protoregistry.GlobalFiles] and [protoregistry.GlobalTypes] as the
// resolver values.
//
// The first value is used as a fallback for FileByFilename and FileContainingSymbol
// queries. The second value is used as a fallback for FileContainingExtension. It
// can also be used as a fallback for AllExtensionNumbersForType if it provides
// a method with the following signature (which *[protoregistry.Types] provides):
//
// RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
func (cr *Client) AllowFallbackResolver(descriptors protodesc.Resolver, exts protoregistry.ExtensionTypeResolver) {
if descriptors == nil && exts == nil {
cr.fallbackResolver.Store(nil)
} else {
cr.fallbackResolver.Store(&resolvers{
descriptorResolver: descriptors,
extensionResolver: exts,
})
}
}
// FileByFilename asks the server for a file descriptor for the proto file with
// the given name.
func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) {
// hit the cache first
cr.cacheMu.RLock()
if entry, ok := cr.filesByName[filename]; ok {
cr.cacheMu.RUnlock()
return entry.fd, nil
}
// not there? see if we've downloaded the proto
fdp, ok := cr.protosByName[filename]
cr.cacheMu.RUnlock()
if ok {
return cr.descriptorFromProto(fdp)
}
req := &refv1.ServerReflectionRequest{
MessageRequest: &refv1.ServerReflectionRequest_FileByFilename{
FileByFilename: filename,
},
}
accept := func(fd *desc.FileDescriptor) bool {
return fd.GetName() == filename
}
fd, err := cr.getAndCacheFileDescriptors(req, filename, "", accept)
if isNotFound(err) {
// File not found? see if we can look up via alternate name
if alternate, ok := internal.StdFileAliases[filename]; ok {
req := &refv1.ServerReflectionRequest{
MessageRequest: &refv1.ServerReflectionRequest_FileByFilename{
FileByFilename: alternate,
},
}
fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename, accept)
}
}
if isNotFound(err) {
// Still no? See if we can use a fallback resolver
resolver := cr.fallbackResolver.Load()
if resolver != nil && resolver.descriptorResolver != nil {
fileDesc, fallbackErr := resolver.descriptorResolver.FindFileByPath(filename)
if fallbackErr == nil {
var wrapErr error
fd, wrapErr = desc.WrapFile(fileDesc)
if wrapErr == nil {
fd = cr.cacheFile(fd, true)
err = nil // clear error since we've succeeded via the fallback
}
}
}
}
if isNotFound(err) {
err = fileNotFound(filename, nil)
} else if e, ok := err.(*elementNotFoundError); ok {
err = fileNotFound(filename, e)
}
return fd, err
}
// FileContainingSymbol asks the server for a file descriptor for the proto file
// that declares the given fully-qualified symbol.
func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) {
// hit the cache first
cr.cacheMu.RLock()
entry, ok := cr.filesBySymbol[symbol]
cr.cacheMu.RUnlock()
if ok {
return entry.fd, nil
}
req := &refv1.ServerReflectionRequest{
MessageRequest: &refv1.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: symbol,
},
}
accept := func(fd *desc.FileDescriptor) bool {
return fd.FindSymbol(symbol) != nil
}
fd, err := cr.getAndCacheFileDescriptors(req, "", "", accept)
if isNotFound(err) {
// Symbol not found? See if we can use a fallback resolver
resolver := cr.fallbackResolver.Load()
if resolver != nil && resolver.descriptorResolver != nil {
d, fallbackErr := resolver.descriptorResolver.FindDescriptorByName(protoreflect.FullName(symbol))
if fallbackErr == nil {
var wrapErr error
fd, wrapErr = desc.WrapFile(d.ParentFile())
if wrapErr == nil {
fd = cr.cacheFile(fd, true)
err = nil // clear error since we've succeeded via the fallback
}
}
}
}
if isNotFound(err) {
err = symbolNotFound(symbol, symbolTypeUnknown, nil)
} else if e, ok := err.(*elementNotFoundError); ok {
err = symbolNotFound(symbol, symbolTypeUnknown, e)
}
return fd, err
}
// FileContainingExtension asks the server for a file descriptor for the proto
// file that declares an extension with the given number for the given
// fully-qualified message name.
func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) {
// hit the cache first
cr.cacheMu.RLock()
entry, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}]
cr.cacheMu.RUnlock()
if ok {
return entry.fd, nil
}
req := &refv1.ServerReflectionRequest{
MessageRequest: &refv1.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &refv1.ExtensionRequest{
ContainingType: extendedMessageName,
ExtensionNumber: extensionNumber,
},
},
}
accept := func(fd *desc.FileDescriptor) bool {
return fd.FindExtension(extendedMessageName, extensionNumber) != nil
}
fd, err := cr.getAndCacheFileDescriptors(req, "", "", accept)
if isNotFound(err) {
// Extension not found? See if we can use a fallback resolver
resolver := cr.fallbackResolver.Load()
if resolver != nil && resolver.extensionResolver != nil {
extType, fallbackErr := resolver.extensionResolver.FindExtensionByNumber(protoreflect.FullName(extendedMessageName), protoreflect.FieldNumber(extensionNumber))
if fallbackErr == nil {
var wrapErr error
fd, wrapErr = desc.WrapFile(extType.TypeDescriptor().ParentFile())
if wrapErr == nil {
fd = cr.cacheFile(fd, true)
err = nil // clear error since we've succeeded via the fallback
}
}
}
}
if isNotFound(err) {
err = extensionNotFound(extendedMessageName, extensionNumber, nil)
} else if e, ok := err.(*elementNotFoundError); ok {
err = extensionNotFound(extendedMessageName, extensionNumber, e)
}
return fd, err
}
func (cr *Client) getAndCacheFileDescriptors(req *refv1.ServerReflectionRequest, expectedName, alias string, accept func(*desc.FileDescriptor) bool) (*desc.FileDescriptor, error) {
resp, err := cr.send(req)
if err != nil {
return nil, err
}
fdResp := resp.GetFileDescriptorResponse()
if fdResp == nil {
return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()}
}
// Response can contain the result file descriptor, but also its transitive
// deps. Furthermore, protocol states that subsequent requests do not need
// to send transitive deps that have been sent in prior responses. So we
// need to cache all file descriptors and then return the first one (which
// should be the answer). If we're looking for a file by name, we can be
// smarter and make sure to grab one by name instead of just grabbing the
// first one.
var fds []*descriptorpb.FileDescriptorProto
for _, fdBytes := range fdResp.FileDescriptorProto {
fd := &descriptorpb.FileDescriptorProto{}
if err = proto.Unmarshal(fdBytes, fd); err != nil {
return nil, err
}
if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName {
// we found a file was aliased, so we need to update the proto to reflect that
fd.Name = proto.String(alias)
}
cr.cacheMu.Lock()
// store in cache of raw descriptor protos, but don't overwrite existing protos
if existingFd, ok := cr.protosByName[fd.GetName()]; ok {
fd = existingFd
} else {
cr.protosByName[fd.GetName()] = fd
}
cr.cacheMu.Unlock()
fds = append(fds, fd)
}
// find the right result from the files returned
for _, fd := range fds {
result, err := cr.descriptorFromProto(fd)
if err != nil {
return nil, err
}
if accept(result) {
return result, nil
}
}
return nil, status.Errorf(codes.NotFound, "response does not include expected file")
}
func (cr *Client) descriptorFromProto(fd *descriptorpb.FileDescriptorProto) (*desc.FileDescriptor, error) {
allowMissing := cr.allowMissing.Load()
deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
var deferredErr error
var missingDeps []int
for i, depName := range fd.GetDependency() {
if dep, err := cr.FileByFilename(depName); err != nil {
if _, ok := err.(*elementNotFoundError); !ok || !allowMissing {
return nil, err
}
// We'll ignore for now to see if the file is really necessary.
// (If it only supplies custom options, we can get by without it.)
if deferredErr == nil {
deferredErr = err
}
missingDeps = append(missingDeps, i)
} else {
deps = append(deps, dep)
}
}
if len(missingDeps) > 0 {
fd = fileWithoutDeps(fd, missingDeps)
}
d, err := desc.CreateFileDescriptor(fd, deps...)
if err != nil {
if deferredErr != nil {
// assume the issue is the missing dep
return nil, deferredErr
}
return nil, err
}
d = cr.cacheFile(d, false)
return d, nil
}
func (cr *Client) cacheFile(fd *desc.FileDescriptor, fallback bool) *desc.FileDescriptor {
cr.cacheMu.Lock()
defer cr.cacheMu.Unlock()
// Cache file descriptor by name. If we can't overwrite an existing
// entry, return it. (Existing entry could come from concurrent caller.)
if existing, ok := cr.filesByName[fd.GetName()]; ok && !canOverwrite(existing, fallback) {
return existing.fd
}
entry := fileEntry{fd: fd, fallback: fallback}
cr.filesByName[fd.GetName()] = entry
// also cache by symbols and extensions
for _, m := range fd.GetMessageTypes() {
cr.cacheMessageLocked(m, entry)
}
for _, e := range fd.GetEnumTypes() {
if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) {
continue
}
for _, v := range e.GetValues() {
cr.maybeCacheFileBySymbol(v.GetFullyQualifiedName(), entry)
}
}
for _, e := range fd.GetExtensions() {
if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) {
continue
}
cr.maybeCacheFileByExtension(extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}, entry)
}
for _, s := range fd.GetServices() {
if !cr.maybeCacheFileBySymbol(s.GetFullyQualifiedName(), entry) {
continue
}
for _, m := range s.GetMethods() {
cr.maybeCacheFileBySymbol(m.GetFullyQualifiedName(), entry)
}
}
return fd
}
func (cr *Client) cacheMessageLocked(md *desc.MessageDescriptor, entry fileEntry) {
if !cr.maybeCacheFileBySymbol(md.GetFullyQualifiedName(), entry) {
return
}
for _, f := range md.GetFields() {
cr.maybeCacheFileBySymbol(f.GetFullyQualifiedName(), entry)
}
for _, o := range md.GetOneOfs() {
cr.maybeCacheFileBySymbol(o.GetFullyQualifiedName(), entry)
}
for _, e := range md.GetNestedEnumTypes() {
if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) {
continue
}
for _, v := range e.GetValues() {
cr.maybeCacheFileBySymbol(v.GetFullyQualifiedName(), entry)
}
}
for _, e := range md.GetNestedExtensions() {
if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) {
continue
}
cr.maybeCacheFileByExtension(extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}, entry)
}
for _, m := range md.GetNestedMessageTypes() {
cr.cacheMessageLocked(m, entry) // recurse
}
}
func canOverwrite(existing fileEntry, fallback bool) bool {
return !fallback && existing.fallback
}
func (cr *Client) maybeCacheFileBySymbol(symbol string, entry fileEntry) bool {
existing, ok := cr.filesBySymbol[symbol]
if ok && !canOverwrite(existing, entry.fallback) {
return false
}
cr.filesBySymbol[symbol] = entry
return true
}
func (cr *Client) maybeCacheFileByExtension(ext extDesc, entry fileEntry) {
existing, ok := cr.filesByExtension[ext]
if ok && !canOverwrite(existing, entry.fallback) {
return
}
cr.filesByExtension[ext] = entry
}
// AllExtensionNumbersForType asks the server for all known extension numbers
// for the given fully-qualified message name.
func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) {
req := &refv1.ServerReflectionRequest{
MessageRequest: &refv1.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: extendedMessageName,
},
}
resp, err := cr.send(req)
var exts []int32
if err != nil && !isNotFound(err) {
// If the server doesn't know about the message type and returns "not found",
// we'll treat that as "no known extensions" instead of returning an error.
return nil, err
}
if err == nil {
extResp := resp.GetAllExtensionNumbersResponse()
if extResp == nil {
return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()}
}
exts = extResp.ExtensionNumber
}
resolver := cr.fallbackResolver.Load()
if resolver != nil && resolver.extensionResolver != nil {
type extRanger interface {
RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
}
if ranger, ok := resolver.extensionResolver.(extRanger); ok {
// Merge results with fallback resolver
extSet := map[int32]struct{}{}
ranger.RangeExtensionsByMessage(protoreflect.FullName(extendedMessageName), func(extType protoreflect.ExtensionType) bool {
extSet[int32(extType.TypeDescriptor().Number())] = struct{}{}
return true
})
if len(extSet) > 0 {
// De-dupe with the set of extension numbers we got
// from the server and merge the results back into exts.
for _, ext := range exts {
extSet[ext] = struct{}{}
}
exts = make([]int32, 0, len(extSet))
for ext := range extSet {
exts = append(exts, ext)
}
sort.Slice(exts, func(i, j int) bool { return exts[i] < exts[j] })
}
}
}
return exts, nil
}
// ListServices asks the server for the fully-qualified names of all exposed
// services.
func (cr *Client) ListServices() ([]string, error) {
req := &refv1.ServerReflectionRequest{
MessageRequest: &refv1.ServerReflectionRequest_ListServices{
// proto doesn't indicate any purpose for this value and server impl
// doesn't actually use it...
ListServices: "*",
},
}
resp, err := cr.send(req)
if err != nil {
return nil, err
}
listResp := resp.GetListServicesResponse()
if listResp == nil {
return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()}
}
serviceNames := make([]string, len(listResp.Service))
for i, s := range listResp.Service {
serviceNames[i] = s.Name
}
return serviceNames, nil
}
func (cr *Client) send(req *refv1.ServerReflectionRequest) (*refv1.ServerReflectionResponse, error) {
// we allow one immediate retry, in case we have a stale stream
// (e.g. closed by server)
resp, err := cr.doSend(req)
if err != nil {
return nil, err
}
// convert error response messages into errors
errResp := resp.GetErrorResponse()
if errResp != nil {
return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage)
}
return resp, nil
}
func isNotFound(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
return ok && s.Code() == codes.NotFound
}
func (cr *Client) doSend(req *refv1.ServerReflectionRequest) (*refv1.ServerReflectionResponse, error) {
// TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery
// (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus
// delivered in correct oder.
cr.connMu.Lock()
defer cr.connMu.Unlock()
return cr.doSendLocked(0, nil, req)
}
func (cr *Client) doSendLocked(attemptCount int, prevErr error, req *refv1.ServerReflectionRequest) (*refv1.ServerReflectionResponse, error) {
if attemptCount >= 3 && prevErr != nil {
return nil, prevErr
}
if (status.Code(prevErr) == codes.Unimplemented ||
status.Code(prevErr) == codes.Unavailable) &&
cr.useV1() {
// If v1 is unimplemented, fallback to v1alpha.
// We also fallback on unavailable because some servers have been
// observed to close the connection/cancel the stream, w/out sending
// back status or headers, when the service name is not known. When
// this happens, the RPC status code is unavailable.
// See https://github.com/fullstorydev/grpcurl/issues/434
cr.useV1Alpha = true
cr.lastTriedV1 = cr.now()
}
attemptCount++
if err := cr.initStreamLocked(); err != nil {
return nil, err
}
if err := cr.stream.Send(req); err != nil {
if err == io.EOF {
// if send returns EOF, must call Recv to get real underlying error
_, err = cr.stream.Recv()
}
cr.resetLocked()
return cr.doSendLocked(attemptCount, err, req)
}
resp, err := cr.stream.Recv()
if err != nil {
cr.resetLocked()
return cr.doSendLocked(attemptCount, err, req)
}
return resp, nil
}
func (cr *Client) initStreamLocked() error {
if cr.stream != nil {
return nil
}
var newCtx context.Context
newCtx, cr.cancel = context.WithCancel(cr.ctx)
if cr.useV1Alpha == true && cr.now().Sub(cr.lastTriedV1) > durationBetweenV1Attempts {
// we're due for periodic retry of v1
cr.useV1Alpha = false
}
if cr.useV1() {
// try the v1 API
streamv1, err := cr.stubV1.ServerReflectionInfo(newCtx)
if err == nil {
cr.stream = streamv1
return nil
}
if status.Code(err) != codes.Unimplemented {
return err
}
// oh well, fall through below to try v1alpha and update state
// so we skip straight to v1alpha next time
cr.useV1Alpha = true
cr.lastTriedV1 = cr.now()
}
var err error
streamv1alpha, err := cr.stubV1Alpha.ServerReflectionInfo(newCtx)
if err == nil {
cr.stream = adaptStreamFromV1Alpha{streamv1alpha}
return nil
}
return err
}
func (cr *Client) useV1() bool {
return !cr.useV1Alpha && cr.stubV1 != nil
}
// Reset ensures that any active stream with the server is closed, releasing any
// resources.
func (cr *Client) Reset() {
cr.connMu.Lock()
defer cr.connMu.Unlock()
cr.resetLocked()
}
func (cr *Client) resetLocked() {
if cr.stream != nil {
_ = cr.stream.CloseSend()
for {
// drain the stream, this covers io.EOF too
if _, err := cr.stream.Recv(); err != nil {
break
}
}
cr.stream = nil
}
if cr.cancel != nil {
cr.cancel()
cr.cancel = nil
}
}
// ResolveService asks the server to resolve the given fully-qualified service
// name into a service descriptor.
func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) {
file, err := cr.FileContainingSymbol(serviceName)
if err != nil {
return nil, setSymbolType(err, serviceName, symbolTypeService)
}
d := file.FindSymbol(serviceName)
if d == nil {
return nil, symbolNotFound(serviceName, symbolTypeService, nil)
}
if s, ok := d.(*desc.ServiceDescriptor); ok {
return s, nil
} else {
return nil, symbolNotFound(serviceName, symbolTypeService, nil)
}
}
// ResolveMessage asks the server to resolve the given fully-qualified message
// name into a message descriptor.
func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) {
file, err := cr.FileContainingSymbol(messageName)
if err != nil {
return nil, setSymbolType(err, messageName, symbolTypeMessage)
}
d := file.FindSymbol(messageName)
if d == nil {
return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
}
if s, ok := d.(*desc.MessageDescriptor); ok {
return s, nil
} else {
return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
}
}
// ResolveEnum asks the server to resolve the given fully-qualified enum name
// into an enum descriptor.
func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) {
file, err := cr.FileContainingSymbol(enumName)
if err != nil {
return nil, setSymbolType(err, enumName, symbolTypeEnum)
}
d := file.FindSymbol(enumName)
if d == nil {
return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
}
if s, ok := d.(*desc.EnumDescriptor); ok {
return s, nil
} else {
return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
}
}
func setSymbolType(err error, name string, symType symbolType) error {
if e, ok := err.(*elementNotFoundError); ok {
if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown {
e.symType = symType
}
}
return err
}
// ResolveEnumValues asks the server to resolve the given fully-qualified enum
// name into a map of names to numbers that represents the enum's values.
func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) {
enumDesc, err := cr.ResolveEnum(enumName)
if err != nil {
return nil, err
}
vals := map[string]int32{}
for _, valDesc := range enumDesc.GetValues() {
vals[valDesc.GetName()] = valDesc.GetNumber()
}
return vals, nil
}
// ResolveExtension asks the server to resolve the given extension number and
// fully-qualified message name into a field descriptor.
func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) {
file, err := cr.FileContainingExtension(extendedType, extensionNumber)
if err != nil {
return nil, err
}
d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file})
if d == nil {
return nil, extensionNotFound(extendedType, extensionNumber, nil)
} else {
return d, nil
}
}
func fileWithoutDeps(fd *descriptorpb.FileDescriptorProto, missingDeps []int) *descriptorpb.FileDescriptorProto {
// We need to rebuild the file without the missing deps.
fd = proto.Clone(fd).(*descriptorpb.FileDescriptorProto)
newNumDeps := len(fd.GetDependency()) - len(missingDeps)
newDeps := make([]string, 0, newNumDeps)
remapped := make(map[int]int, newNumDeps)
missingIdx := 0
for i, dep := range fd.GetDependency() {
if missingIdx < len(missingDeps) {
if i == missingDeps[missingIdx] {
// This dep was missing. Skip it.
missingIdx++
continue
}
}
remapped[i] = len(newDeps)
newDeps = append(newDeps, dep)
}
// Also rebuild public and weak import slices.
newPublic := make([]int32, 0, len(fd.GetPublicDependency()))
for _, idx := range fd.GetPublicDependency() {
newIdx, ok := remapped[int(idx)]
if ok {
newPublic = append(newPublic, int32(newIdx))
}
}
newWeak := make([]int32, 0, len(fd.GetWeakDependency()))
for _, idx := range fd.GetWeakDependency() {
newIdx, ok := remapped[int(idx)]
if ok {
newWeak = append(newWeak, int32(newIdx))
}
}
fd.Dependency = newDeps
fd.PublicDependency = newPublic
fd.WeakDependency = newWeak
return fd
}
func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor {
// search extensions in this scope
for _, ext := range scope.extensions() {
if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType {
return ext
}
}
// if not found, search nested scopes
for _, nested := range scope.nestedScopes() {
ext := findExtension(extendedType, extensionNumber, nested)
if ext != nil {
return ext
}
}
return nil
}
type extensionScope interface {
extensions() []*desc.FieldDescriptor
nestedScopes() []extensionScope
}
// fileDescriptorExtensions implements extensionHolder interface on top of
// FileDescriptorProto
type fileDescriptorExtensions struct {
proto *desc.FileDescriptor
}
func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor {
return fde.proto.GetExtensions()
}
func (fde fileDescriptorExtensions) nestedScopes() []extensionScope {
scopes := make([]extensionScope, len(fde.proto.GetMessageTypes()))
for i, m := range fde.proto.GetMessageTypes() {
scopes[i] = msgDescriptorExtensions{m}
}
return scopes
}
// msgDescriptorExtensions implements extensionHolder interface on top of
// DescriptorProto
type msgDescriptorExtensions struct {
proto *desc.MessageDescriptor
}
func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor {
return mde.proto.GetNestedExtensions()
}
func (mde msgDescriptorExtensions) nestedScopes() []extensionScope {
scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes()))
for i, m := range mde.proto.GetNestedMessageTypes() {
scopes[i] = msgDescriptorExtensions{m}
}
return scopes
}
type adaptStreamFromV1Alpha struct {
refv1alpha.ServerReflection_ServerReflectionInfoClient
}
func (a adaptStreamFromV1Alpha) Send(request *refv1.ServerReflectionRequest) error {
v1req := toV1AlphaRequest(request)
return a.ServerReflection_ServerReflectionInfoClient.Send(v1req)
}
func (a adaptStreamFromV1Alpha) Recv() (*refv1.ServerReflectionResponse, error) {
v1resp, err := a.ServerReflection_ServerReflectionInfoClient.Recv()
if err != nil {
return nil, err
}
return toV1Response(v1resp), nil
}