You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
322 lines
9.4 KiB
322 lines
9.4 KiB
package main |
|
|
|
import (
	"fmt"
	"os"
	"regexp"
	"sort"
	"strings"

	"google.golang.org/genproto/googleapis/api/annotations"
	"google.golang.org/protobuf/compiler/protogen"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/descriptorpb"
)
|
|
|
const ( |
|
contextPackage = protogen.GoImportPath("context") |
|
transportHTTPPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") |
|
bindingPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http/binding") |
|
) |
|
|
|
// methodSets counts, per Go method name, how many HTTP routes have been built
// so far; buildMethodDesc reads it to number repeated routes of one method.
var methodSets = map[string]int{}
|
|
|
// generateFile generates a _http.pb.go file containing kratos errors definitions. |
|
func generateFile(gen *protogen.Plugin, file *protogen.File, omitempty bool) *protogen.GeneratedFile { |
|
if len(file.Services) == 0 || (omitempty && !hasHTTPRule(file.Services)) { |
|
return nil |
|
} |
|
filename := file.GeneratedFilenamePrefix + "_http.pb.go" |
|
g := gen.NewGeneratedFile(filename, file.GoImportPath) |
|
g.P("// Code generated by protoc-gen-go-http. DO NOT EDIT.") |
|
g.P("// versions:") |
|
g.P(fmt.Sprintf("// - protoc-gen-go-http %s", release)) |
|
g.P("// - protoc ", protocVersion(gen)) |
|
if file.Proto.GetOptions().GetDeprecated() { |
|
g.P("// ", file.Desc.Path(), " is a deprecated file.") |
|
} else { |
|
g.P("// source: ", file.Desc.Path()) |
|
} |
|
g.P() |
|
g.P("package ", file.GoPackageName) |
|
g.P() |
|
generateFileContent(gen, file, g, omitempty) |
|
return g |
|
} |
|
|
|
// generateFileContent generates the kratos errors definitions, excluding the package statement. |
|
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, omitempty bool) { |
|
if len(file.Services) == 0 { |
|
return |
|
} |
|
g.P("// This is a compile-time assertion to ensure that this generated file") |
|
g.P("// is compatible with the kratos package it is being compiled against.") |
|
g.P("var _ = new(", contextPackage.Ident("Context"), ")") |
|
g.P("var _ = ", bindingPackage.Ident("EncodeURL")) |
|
g.P("const _ = ", transportHTTPPackage.Ident("SupportPackageIsVersion1")) |
|
g.P() |
|
|
|
for _, service := range file.Services { |
|
genService(gen, file, g, service, omitempty) |
|
} |
|
} |
|
|
|
func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, omitempty bool) { |
|
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { |
|
g.P("//") |
|
g.P(deprecationComment) |
|
} |
|
// HTTP Server. |
|
sd := &serviceDesc{ |
|
ServiceType: service.GoName, |
|
ServiceName: string(service.Desc.FullName()), |
|
Metadata: file.Desc.Path(), |
|
} |
|
for _, method := range service.Methods { |
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { |
|
continue |
|
} |
|
rule, ok := proto.GetExtension(method.Desc.Options(), annotations.E_Http).(*annotations.HttpRule) |
|
if rule != nil && ok { |
|
for _, bind := range rule.AdditionalBindings { |
|
sd.Methods = append(sd.Methods, buildHTTPRule(g, method, bind)) |
|
} |
|
sd.Methods = append(sd.Methods, buildHTTPRule(g, method, rule)) |
|
} else if !omitempty { |
|
path := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) |
|
sd.Methods = append(sd.Methods, buildMethodDesc(g, method, "POST", path)) |
|
} |
|
} |
|
if len(sd.Methods) != 0 { |
|
g.P(sd.execute()) |
|
} |
|
} |
|
|
|
func hasHTTPRule(services []*protogen.Service) bool { |
|
for _, service := range services { |
|
for _, method := range service.Methods { |
|
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { |
|
continue |
|
} |
|
rule, ok := proto.GetExtension(method.Desc.Options(), annotations.E_Http).(*annotations.HttpRule) |
|
if rule != nil && ok { |
|
return true |
|
} |
|
} |
|
} |
|
return false |
|
} |
|
|
|
func buildHTTPRule(g *protogen.GeneratedFile, m *protogen.Method, rule *annotations.HttpRule) *methodDesc { |
|
var ( |
|
path string |
|
method string |
|
body string |
|
responseBody string |
|
) |
|
|
|
switch pattern := rule.Pattern.(type) { |
|
case *annotations.HttpRule_Get: |
|
path = pattern.Get |
|
method = "GET" |
|
case *annotations.HttpRule_Put: |
|
path = pattern.Put |
|
method = "PUT" |
|
case *annotations.HttpRule_Post: |
|
path = pattern.Post |
|
method = "POST" |
|
case *annotations.HttpRule_Delete: |
|
path = pattern.Delete |
|
method = "DELETE" |
|
case *annotations.HttpRule_Patch: |
|
path = pattern.Patch |
|
method = "PATCH" |
|
case *annotations.HttpRule_Custom: |
|
path = pattern.Custom.Path |
|
method = pattern.Custom.Kind |
|
} |
|
body = rule.Body |
|
responseBody = rule.ResponseBody |
|
md := buildMethodDesc(g, m, method, path) |
|
if method == "GET" || method == "DELETE" { |
|
if body != "" { |
|
_, _ = fmt.Fprintf(os.Stderr, "\u001B[31mWARN\u001B[m: %s %s body should not be declared.\n", method, path) |
|
} |
|
} else { |
|
if body == "" { |
|
_, _ = fmt.Fprintf(os.Stderr, "\u001B[31mWARN\u001B[m: %s %s does not declare a body.\n", method, path) |
|
} |
|
} |
|
if body == "*" { |
|
md.HasBody = true |
|
md.Body = "" |
|
} else if body != "" { |
|
md.HasBody = true |
|
md.Body = "." + camelCaseVars(body) |
|
} else { |
|
md.HasBody = false |
|
} |
|
if responseBody == "*" { |
|
md.ResponseBody = "" |
|
} else if responseBody != "" { |
|
md.ResponseBody = "." + camelCaseVars(responseBody) |
|
} |
|
return md |
|
} |
|
|
|
func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path string) *methodDesc { |
|
defer func() { methodSets[m.GoName]++ }() |
|
|
|
vars := buildPathVars(path) |
|
|
|
for v, s := range vars { |
|
fields := m.Input.Desc.Fields() |
|
|
|
if s != nil { |
|
path = replacePath(v, *s, path) |
|
} |
|
for _, field := range strings.Split(v, ".") { |
|
if strings.TrimSpace(field) == "" { |
|
continue |
|
} |
|
if strings.Contains(field, ":") { |
|
field = strings.Split(field, ":")[0] |
|
} |
|
fd := fields.ByName(protoreflect.Name(field)) |
|
if fd == nil { |
|
fmt.Fprintf(os.Stderr, "\u001B[31mERROR\u001B[m: The corresponding field '%s' declaration in message could not be found in '%s'\n", v, path) |
|
os.Exit(2) |
|
} |
|
if fd.IsMap() { |
|
fmt.Fprintf(os.Stderr, "\u001B[31mWARN\u001B[m: The field in path:'%s' shouldn't be a map.\n", v) |
|
} else if fd.IsList() { |
|
fmt.Fprintf(os.Stderr, "\u001B[31mWARN\u001B[m: The field in path:'%s' shouldn't be a list.\n", v) |
|
} else if fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind { |
|
fields = fd.Message().Fields() |
|
} |
|
} |
|
} |
|
return &methodDesc{ |
|
Name: m.GoName, |
|
OriginalName: string(m.Desc.Name()), |
|
Num: methodSets[m.GoName], |
|
Request: g.QualifiedGoIdent(m.Input.GoIdent), |
|
Reply: g.QualifiedGoIdent(m.Output.GoIdent), |
|
Path: path, |
|
Method: method, |
|
HasVars: len(vars) > 0, |
|
} |
|
} |
|
|
|
// buildPathVars extracts the variables declared in an HTTP route template.
//
// The result has one entry per {name} or {name=pattern} segment, keyed by the
// whitespace-trimmed variable name. For a {name=pattern} segment with a
// non-empty pattern, the value points at the pattern text. For every other
// segment the value is nil; this includes plain {name} and segments without
// "=", such as the router form {name:regex}.
//
// A warning is printed to stderr when path ends with "/".
func buildPathVars(path string) map[string]*string {
	if strings.HasSuffix(path, "/") {
		fmt.Fprintf(os.Stderr, "\u001B[31mWARN\u001B[m: Path %s should not end with \"/\" \n", path)
	}
	// Groups: 1 = variable name, 2 = "=" when a pattern is declared,
	// 3 = whatever follows up to the closing brace.
	pattern := regexp.MustCompile(`(?i)\{([a-z\.0-9_\s]*)(=?)([^{}]*)\}`)
	matches := pattern.FindAllStringSubmatch(path, -1)
	res := make(map[string]*string, len(matches))
	for _, m := range matches {
		name := strings.TrimSpace(m[1])
		// Any non-empty name qualifies, single-character ones included, but a
		// pattern only exists in the explicit google.api.http form {name=...}.
		if name != "" && m[2] == "=" && m[3] != "" {
			value := m[3]
			res[name] = &value
		} else {
			res[name] = nil
		}
	}
	return res
}
|
|
|
// replacePath rewrites the first {name} or {name=pattern} segment of path into
// the router form {name:regex}, where regex is value with every "*" replaced
// by ".*". Whitespace around the name inside the braces is tolerated.
//
// The name is matched literally, so a "." in it is not a regex wildcard. It
// must be followed by optional whitespace and then "=" or "}", so a variable
// such as "parent" never matches a longer one such as "parent_id". path is
// returned unchanged when no segment matches.
func replacePath(name string, value string, path string) string {
	pattern := regexp.MustCompile(`(?i)\{\s*` + regexp.QuoteMeta(name) + `\s*(?:=[^{}]*)?\}`)
	loc := pattern.FindStringIndex(path)
	if loc == nil {
		return path
	}
	return path[:loc[0]] + "{" + name + ":" + strings.ReplaceAll(value, "*", ".*") + "}" + path[loc[1]:]
}
|
|
|
func camelCaseVars(s string) string { |
|
subs := strings.Split(s, ".") |
|
vars := make([]string, 0, len(subs)) |
|
for _, sub := range subs { |
|
vars = append(vars, camelCase(sub)) |
|
} |
|
return strings.Join(vars, ".") |
|
} |
|
|
|
// camelCase returns the CamelCased name. |
|
// If there is an interior underscore followed by a lower case letter, |
|
// drop the underscore and convert the letter to upper case. |
|
// There is a remote possibility of this rewrite causing a name collision, |
|
// but it's so remote we're prepared to pretend it's nonexistent - since the |
|
// C++ generator lowercases names, it's extremely unlikely to have two fields |
|
// with different capitalizations. |
|
// In short, _my_field_name_2 becomes XMyFieldName_2. |
|
func camelCase(s string) string { |
|
if s == "" { |
|
return "" |
|
} |
|
t := make([]byte, 0, 32) |
|
i := 0 |
|
if s[0] == '_' { |
|
// Need a capital letter; drop the '_'. |
|
t = append(t, 'X') |
|
i++ |
|
} |
|
// Invariant: if the next letter is lower case, it must be converted |
|
// to upper case. |
|
// That is, we process a word at a time, where words are marked by _ or |
|
// upper case letter. Digits are treated as words. |
|
for ; i < len(s); i++ { |
|
c := s[i] |
|
if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) { |
|
continue // Skip the underscore in s. |
|
} |
|
if isASCIIDigit(c) { |
|
t = append(t, c) |
|
continue |
|
} |
|
// Assume we have a letter now - if not, it's a bogus identifier. |
|
// The next word is a sequence of characters that must start upper case. |
|
if isASCIILower(c) { |
|
c ^= ' ' // Make it a capital letter. |
|
} |
|
t = append(t, c) // Guaranteed not lower case. |
|
// Accept lower case sequence that follows. |
|
for i+1 < len(s) && isASCIILower(s[i+1]) { |
|
i++ |
|
t = append(t, s[i]) |
|
} |
|
} |
|
return string(t) |
|
} |
|
|
|
// isASCIILower reports whether c is an ASCII lower-case letter ('a' to 'z').
func isASCIILower(c byte) bool {
	return c >= 'a' && c <= 'z'
}
|
|
|
// isASCIIDigit reports whether c is an ASCII decimal digit ('0' to '9').
func isASCIIDigit(c byte) bool {
	return c >= '0' && c <= '9'
}
|
|
|
func protocVersion(gen *protogen.Plugin) string { |
|
v := gen.Request.GetCompilerVersion() |
|
if v == nil { |
|
return "(unknown)" |
|
} |
|
var suffix string |
|
if s := v.GetSuffix(); s != "" { |
|
suffix = "-" + s |
|
} |
|
return fmt.Sprintf("v%d.%d.%d%s", v.GetMajor(), v.GetMinor(), v.GetPatch(), suffix) |
|
} |
|
|
|
const (
	// deprecationComment is the doc-comment line written above the generated
	// code of services whose proto options mark them as deprecated.
	deprecationComment = "// Deprecated: Do not use."
)
|
|
|