Skip to content

Commit

Permalink
fix:parse all field names declared in a row (#1872)
Browse files Browse the repository at this point in the history
* fix:parse all fields names declared in a row
  • Loading branch information
sdghchj authored Oct 17, 2024
1 parent 4fd8a36 commit a74d34c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 51 deletions.
51 changes: 27 additions & 24 deletions field_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,35 +65,38 @@ func (ps *tagBaseFieldParser) ShouldSkip() bool {
return false
}

func (ps *tagBaseFieldParser) FieldName() (string, error) {
var name string
func (ps *tagBaseFieldParser) FieldNames() ([]string, error) {
if len(ps.field.Names) <= 1 {
// if embedded but with a json/form name ??
if ps.field.Tag != nil {
// json:"tag,hoge"
name := strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0])
if name != "" {
return []string{name}, nil
}

if ps.field.Tag != nil {
// json:"tag,hoge"
name = strings.TrimSpace(strings.Split(ps.tag.Get(jsonTag), ",")[0])
if name != "" {
return name, nil
// use "form" tag over json tag
name = ps.FormName()
if name != "" {
return []string{name}, nil
}
}

// use "form" tag over json tag
name = ps.FormName()
if name != "" {
return name, nil
if len(ps.field.Names) == 0 {
return nil, nil
}
}

if ps.field.Names == nil {
return "", nil
}

switch ps.p.PropNamingStrategy {
case SnakeCase:
return toSnakeCase(ps.field.Names[0].Name), nil
case PascalCase:
return ps.field.Names[0].Name, nil
default:
return toLowerCamelCase(ps.field.Names[0].Name), nil
var names = make([]string, 0, len(ps.field.Names))
for _, name := range ps.field.Names {
switch ps.p.PropNamingStrategy {
case SnakeCase:
names = append(names, toSnakeCase(name.Name))
case PascalCase:
names = append(names, name.Name)
default:
names = append(names, toLowerCamelCase(name.Name))
}
}
return names, nil
}

func (ps *tagBaseFieldParser) firstTagValue(tag string) string {
Expand Down
51 changes: 36 additions & 15 deletions field_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,11 @@ func TestValidTags(t *testing.T) {
schema.Type = []string{"integer"}
err = newTagBaseFieldParser(
&Parser{},
&ast.Field{Tag: &ast.BasicLit{
Value: `json:"test" validate:"required,oneof=one two"`,
}},
&ast.Field{
Names: []*ast.Ident{{Name: "Test"}},
Tag: &ast.BasicLit{
Value: `json:"test" validate:"required,oneof=one two"`,
}},
).ComplementSchema(&schema)
assert.NoError(t, err)
assert.Empty(t, schema.Enum)
Expand All @@ -687,22 +689,41 @@ func TestValidTags(t *testing.T) {
t.Run("Form Filed Name", func(t *testing.T) {
t.Parallel()

filedname, err := newTagBaseFieldParser(
filednames, err := newTagBaseFieldParser(
&Parser{},
&ast.Field{Tag: &ast.BasicLit{
Value: `form:"test[]"`,
}},
).FieldName()
&ast.Field{
Names: []*ast.Ident{{Name: "Test"}},
Tag: &ast.BasicLit{
Value: `form:"test[]"`,
}},
).FieldNames()
assert.NoError(t, err)
assert.Equal(t, "test", filedname)
assert.Equal(t, "test", filednames[0])

filedname, err = newTagBaseFieldParser(
filednames, err = newTagBaseFieldParser(
&Parser{},
&ast.Field{Tag: &ast.BasicLit{
Value: `form:"test"`,
}},
).FieldName()
&ast.Field{
Names: []*ast.Ident{{Name: "Test"}},
Tag: &ast.BasicLit{
Value: `form:"test"`,
}},
).FieldNames()
assert.NoError(t, err)
assert.Equal(t, "test", filednames[0])
})

t.Run("Two Names", func(t *testing.T) {
t.Parallel()

fieldnames, err := newTagBaseFieldParser(
&Parser{},
&ast.Field{
Names: []*ast.Ident{{Name: "X"}, {Name: "Y"}},
},
).FieldNames()
assert.NoError(t, err)
assert.Equal(t, "test", filedname)
assert.Equal(t, 2, len(fieldnames))
assert.Equal(t, "x", fieldnames[0])
assert.Equal(t, "y", fieldnames[1])
})
}
27 changes: 15 additions & 12 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ type FieldParserFactory func(ps *Parser, field *ast.Field) FieldParser
// FieldParser parse struct field.
type FieldParser interface {
ShouldSkip() bool
FieldName() (string, error)
FieldNames() ([]string, error)
FormName() string
HeaderName() string
PathName() string
Expand Down Expand Up @@ -1527,20 +1527,20 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
return nil, nil, nil
}

fieldName, err := ps.FieldName()
fieldNames, err := ps.FieldNames()
if err != nil {
return nil, nil, err
}

if fieldName == "" {
if len(fieldNames) == 0 {
typeName, err := getFieldType(file, field.Type, nil)
if err != nil {
return nil, nil, fmt.Errorf("%s: %w", fieldName, err)
return nil, nil, err
}

schema, err := parser.getTypeSchema(typeName, file, false)
if err != nil {
return nil, nil, fmt.Errorf("%s: %w", fieldName, err)
return nil, nil, err
}

if len(schema.Type) > 0 && schema.Type[0] == OBJECT {
Expand All @@ -1562,7 +1562,7 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st

schema, err := ps.CustomSchema()
if err != nil {
return nil, nil, fmt.Errorf("%s: %w", fieldName, err)
return nil, nil, fmt.Errorf("%v: %w", fieldNames, err)
}

if schema == nil {
Expand All @@ -1576,24 +1576,24 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
}

if err != nil {
return nil, nil, fmt.Errorf("%s: %w", fieldName, err)
return nil, nil, fmt.Errorf("%v: %w", fieldNames, err)
}
}

err = ps.ComplementSchema(schema)
if err != nil {
return nil, nil, fmt.Errorf("%s: %w", fieldName, err)
return nil, nil, fmt.Errorf("%v: %w", fieldNames, err)
}

var tagRequired []string

required, err := ps.IsRequired()
if err != nil {
return nil, nil, fmt.Errorf("%s: %w", fieldName, err)
return nil, nil, fmt.Errorf("%v: %w", fieldNames, err)
}

if required {
tagRequired = append(tagRequired, fieldName)
tagRequired = append(tagRequired, fieldNames...)
}

if schema.Extensions == nil {
Expand All @@ -1608,8 +1608,11 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
if pathName := ps.PathName(); len(pathName) > 0 {
schema.Extensions["path"] = pathName
}

return map[string]spec.Schema{fieldName: *schema}, tagRequired, nil
fields := make(map[string]spec.Schema)
for _, name := range fieldNames {
fields[name] = *schema
}
return fields, tagRequired, nil
}

func getFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
Expand Down

0 comments on commit a74d34c

Please sign in to comment.