blob: 8fd09ac18f46a24c1baa74d6ff4337a1a5c3a16c [file] [log] [blame]
package desc
import (
"errors"
"fmt"
"reflect"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"github.com/jhump/protoreflect/desc/sourceinfo"
"github.com/jhump/protoreflect/internal"
)
// The global cache is used to store descriptors that wrap items in
// protoregistry.GlobalTypes and protoregistry.GlobalFiles. This prevents
// repeating work to re-wrap underlying global descriptors.
var (
// We put all wrapped file and message descriptors in this cache.
loadedDescriptors = lockingCache{cache: mapCache{}}
// Unfortunately, we need a different mechanism for enums for
// compatibility with old APIs, which required that they were
// registered in a different way :(
loadedEnumsMu sync.RWMutex
loadedEnums = map[reflect.Type]*EnumDescriptor{}
)
// LoadFileDescriptor creates a file descriptor using the bytes returned by
// proto.FileDescriptor. Descriptors are cached so that they do not need to be
// re-processed if the same file is fetched again later.
func LoadFileDescriptor(file string) (*FileDescriptor, error) {
d, err := sourceinfo.GlobalFiles.FindFileByPath(file)
if errors.Is(err, protoregistry.NotFound) {
// for backwards compatibility, see if this matches a known old
// alias for the file (older versions of libraries that registered
// the files using incorrect/non-canonical paths)
if alt := internal.StdFileAliases[file]; alt != "" {
d, err = sourceinfo.GlobalFiles.FindFileByPath(alt)
}
}
if err != nil {
if !errors.Is(err, protoregistry.NotFound) {
return nil, internal.ErrNoSuchFile(file)
}
return nil, err
}
if fd := loadedDescriptors.get(d); fd != nil {
return fd.(*FileDescriptor), nil
}
var fd *FileDescriptor
loadedDescriptors.withLock(func(cache descriptorCache) {
fd, err = wrapFile(d, cache)
})
return fd, err
}
// LoadMessageDescriptor loads descriptor using the encoded descriptor proto returned by
// Message.Descriptor() for the given message type. If the given type is not recognized,
// then a nil descriptor is returned.
func LoadMessageDescriptor(message string) (*MessageDescriptor, error) {
mt, err := sourceinfo.GlobalTypes.FindMessageByName(protoreflect.FullName(message))
if err != nil {
if errors.Is(err, protoregistry.NotFound) {
return nil, nil
}
return nil, err
}
return loadMessageDescriptor(mt.Descriptor())
}
func loadMessageDescriptor(md protoreflect.MessageDescriptor) (*MessageDescriptor, error) {
d := loadedDescriptors.get(md)
if d != nil {
return d.(*MessageDescriptor), nil
}
var err error
loadedDescriptors.withLock(func(cache descriptorCache) {
d, err = wrapMessage(md, cache)
})
if err != nil {
return nil, err
}
return d.(*MessageDescriptor), err
}
// LoadMessageDescriptorForType loads descriptor using the encoded descriptor proto returned
// by message.Descriptor() for the given message type. If the given type is not recognized,
// then a nil descriptor is returned.
func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) {
m, err := messageFromType(messageType)
if err != nil {
return nil, err
}
return LoadMessageDescriptorForMessage(m)
}
// LoadMessageDescriptorForMessage loads descriptor using the encoded descriptor proto
// returned by message.Descriptor(). If the given type is not recognized, then a nil
// descriptor is returned.
func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) {
// efficiently handle dynamic messages
type descriptorable interface {
GetMessageDescriptor() *MessageDescriptor
}
if d, ok := message.(descriptorable); ok {
return d.GetMessageDescriptor(), nil
}
var md protoreflect.MessageDescriptor
if m, ok := message.(protoreflect.ProtoMessage); ok {
md = m.ProtoReflect().Descriptor()
} else {
md = proto.MessageReflect(message).Descriptor()
}
return loadMessageDescriptor(sourceinfo.WrapMessage(md))
}
func messageFromType(mt reflect.Type) (proto.Message, error) {
if mt.Kind() != reflect.Ptr {
mt = reflect.PtrTo(mt)
}
m, ok := reflect.Zero(mt).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", mt)
}
return m, nil
}
// interface implemented by all generated enums
type protoEnum interface {
EnumDescriptor() ([]byte, []int)
}
// NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name because
// it is not useful since protoc-gen-go does not expose the name anywhere in generated
// code or register it in a way that is it accessible for reflection code. This also
// means we have to cache enum descriptors differently -- we can only cache them as
// they are requested, as opposed to caching all enum types whenever a file descriptor
// is cached. This is because we need to know the generated type of the enums, and we
// don't know that at the time of caching file descriptors.
// LoadEnumDescriptorForType loads descriptor using the encoded descriptor proto returned
// by enum.EnumDescriptor() for the given enum type.
func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) {
// we cache descriptors using non-pointer type
if enumType.Kind() == reflect.Ptr {
enumType = enumType.Elem()
}
e := getEnumFromCache(enumType)
if e != nil {
return e, nil
}
enum, err := enumFromType(enumType)
if err != nil {
return nil, err
}
return loadEnumDescriptor(enumType, enum)
}
func getEnumFromCache(t reflect.Type) *EnumDescriptor {
loadedEnumsMu.RLock()
defer loadedEnumsMu.RUnlock()
return loadedEnums[t]
}
func putEnumInCache(t reflect.Type, d *EnumDescriptor) {
loadedEnumsMu.Lock()
defer loadedEnumsMu.Unlock()
loadedEnums[t] = d
}
// LoadEnumDescriptorForEnum loads descriptor using the encoded descriptor proto
// returned by enum.EnumDescriptor().
func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) {
et := reflect.TypeOf(enum)
// we cache descriptors using non-pointer type
if et.Kind() == reflect.Ptr {
et = et.Elem()
enum = reflect.Zero(et).Interface().(protoEnum)
}
e := getEnumFromCache(et)
if e != nil {
return e, nil
}
return loadEnumDescriptor(et, enum)
}
func enumFromType(et reflect.Type) (protoEnum, error) {
e, ok := reflect.Zero(et).Interface().(protoEnum)
if !ok {
if et.Kind() != reflect.Ptr {
et = et.Elem()
}
e, ok = reflect.Zero(et).Interface().(protoEnum)
}
if !ok {
return nil, fmt.Errorf("failed to create enum from type: %v", et)
}
return e, nil
}
func getDescriptorForEnum(enum protoEnum) (*descriptorpb.FileDescriptorProto, []int, error) {
fdb, path := enum.EnumDescriptor()
name := fmt.Sprintf("%T", enum)
fd, err := internal.DecodeFileDescriptor(name, fdb)
return fd, path, err
}
func loadEnumDescriptor(et reflect.Type, enum protoEnum) (*EnumDescriptor, error) {
fdp, path, err := getDescriptorForEnum(enum)
if err != nil {
return nil, err
}
fd, err := LoadFileDescriptor(fdp.GetName())
if err != nil {
return nil, err
}
ed := findEnum(fd, path)
putEnumInCache(et, ed)
return ed, nil
}
func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor {
if len(path) == 1 {
return fd.GetEnumTypes()[path[0]]
}
md := fd.GetMessageTypes()[path[0]]
for _, i := range path[1 : len(path)-1] {
md = md.GetNestedMessageTypes()[i]
}
return md.GetNestedEnumTypes()[path[len(path)-1]]
}
// LoadFieldDescriptorForExtension loads the field descriptor that corresponds to the given
// extension description.
func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) {
file, err := LoadFileDescriptor(ext.Filename)
if err != nil {
return nil, err
}
field, ok := file.FindSymbol(ext.Name).(*FieldDescriptor)
// make sure descriptor agrees with attributes of the ExtensionDesc
if !ok || !field.IsExtension() || field.GetOwner().GetFullyQualifiedName() != proto.MessageName(ext.ExtendedType) ||
field.GetNumber() != ext.Field {
return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name)
}
return field, nil
}