gpt4 book ai didi

Gorm源码学习-创建行记录

转载 作者:我是一只小鸟 更新时间:2022-12-17 22:33:32 25 4
gpt4 key购买 nike

1. 前言

Gorm源码学习系列 。

  • Gorm源码学习-数据库连接

此文是Gorm源码学习系列的第二篇,主要梳理下通过Gorm创建表的流程.

  。

2. 创建行记录代码示例

gorm提供了以下几个接口来创建行记录 。

  • 一次创建一行  func (db *DB) Create(value interface{}) (tx *DB)
  • 批量创建  func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
  • 数据库不存在主键时创建,存在时更新  func (db *DB) Save(value interface{}) (tx *DB)

详细请看 教程 及源码 finisher_api.go ,这里使用 func (db *DB) Create(value interface{}) (tx *DB) 来说明创建行记录等大致流程.

  。

2.1 声明模型

                      
                        type Stu struct {
	ID     int64 `gorm:"column:id; primary_key" json:"id"`
	Age    int64 `gorm:"column:age;"`
	Height int64 `gorm:"column:height;"`
	Weight int64 `gorm:"column:weight;"`
}

// 设置表名
func (Stu) TableName() string {
	return "t_student"
}
                      
                    

模型代码的主要用途如下, 。

  • 申明的表中有哪些列及每列的名称、特性等,如 gorm 标签指定每个字断对于的表的列名
  • 通过实现 Tabler 接口指定了固定的表名,接口定义如下
                      
                        type Tabler interface {
	TableName() string
}
                      
                    

关于模型定义中更多的约定和约束等,请看 教程 .

出于分表等业务场景,我们并不希望固定模型等表名,gorm提供了 func (db *DB) Table(name string, args ...interface{}) (tx *DB) 等方法 。

来动态指定表名,详情请看 教程 .

  。

2.2 创建行

                      
                        func main() {
	// 数据库连接, 具体查看https://www.cnblogs.com/amos01/p/16890747.html 连接数据库代码示例
	db, _ := dbOpen()
	// 打开调试模式、会打印DML
	db = db.Debug()
	stu := &Stu{
		Age:    18,
		Height: 185,
		Weight: 70,
	}
	db = db.Create(stu)
	fmt.Printf("Error:%v ID:%v RowsAffected:%v\n", db.Error, stu.ID, db.RowsAffected)
}
                      
                    

代码输出如下 。

                      
                        $ go run main.go
2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33
[1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)
Error:<nil> ID:1027 RowsAffected:1
                      
                    

从代码输出可以看,行记录的ID为1027,连接数据库查询,结果如下.

                      
                        mysql> select * from t_student where id = 1027\G
*************************** 1. row ***************************
    id: 1027
   age: 18
height: 185
weight: 70
1 row in set (0.01 sec)
                      
                    

因此,我们带着以下问题来梳理下Gorm创建行记录的流程 。

  • 如何从model到DML语句的
  • 如何将ID写入到model的 

  。

3. 从Model到DML

func (db *DB) Create(value interface{}) (tx *DB) 的实现如下 。

                      
                        // Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
	if db.CreateBatchSize > 0 {
		return db.CreateInBatches(value, db.CreateBatchSize)
	}

	tx = db.getInstance()
	tx.Statement.Dest = value
	return tx.callbacks.Create().Execute(tx)
}
                      
                    

func (p *processor) Execute(db *DB) *DB  的实现比较长,具体代码见 github 。

总结下来,做了两件主要的事情, 。

  • 解析model获取表名、每列的定义等
  • 执行钩子函数以及创建行函数

X 。

3.1 数据结构理解

  • gorm.Statement
查看gorm.Statement代码
                        
                          // Statement statement
type Statement struct {
	*DB
	TableExpr            *clause.Expr
	Table                string      // 表名
	Model                interface{} // model定义
	Unscoped             bool
	Dest                 interface{} // model的另外一种表达,如map
	ReflectValue         reflect.Value
	Clauses              map[string]clause.Clause
	BuildClauses         []string
	Distinct             bool
	Selects              []string // selected columns
	Omits                []string // omit columns
	Joins                []join
	Preloads             map[string][]interface{}
	Settings             sync.Map
	ConnPool             ConnPool       // 数据库连接
	Schema               *schema.Schema // 表结构化信息
	Context              context.Context
	RaiseErrorOnNotFound bool
	SkipHooks            bool
	SQL                  strings.Builder // 最终的DML语句
	Vars                 []interface{}   // DML语句的参数值
	CurDestIndex         int             // 批量创建/更新时,gorm当前操作的数组/slice的下标
	attrs                []interface{}
	assigns              []interface{}
	scopes               []func(*DB) *DB
}
                        
                      
  • schema.Schem
查看schema.Schema代码
                        
                          type Schema struct {
	Name                      string
	ModelType                 reflect.Type
	Table                     string // 表名
	PrioritizedPrimaryField   *Field
	DBNames                   []string // 表每列的名字
	PrimaryFields             []*Field
	PrimaryFieldDBNames       []string // 表的主键列明
	Fields                    []*Field // gorm自定义的model每个字短
	FieldsByName              map[string]*Field
	FieldsByDBName            map[string]*Field
	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
	Relationships             Relationships
	CreateClauses             []clause.Interface // 创建行的子句
	QueryClauses              []clause.Interface
	UpdateClauses             []clause.Interface
	DeleteClauses             []clause.Interface
	BeforeCreate, AfterCreate bool
	BeforeUpdate, AfterUpdate bool
	BeforeDelete, AfterDelete bool
	BeforeSave, AfterSave     bool
	AfterFind                 bool
	err                       error
	initialized               chan struct{}
	namer                     Namer
	cacheStore                *sync.Map
}
                        
                      
  • schema.Field
查看schema.Field代码
                        
                          // Field is the representation of model schema's field
type Field struct {
	Name                   string // model的字段名
	DBName                 string // 对应表的列名
	BindNames              []string
	DataType               DataType
	GORMDataType           DataType
	PrimaryKey             bool
	AutoIncrement          bool
	AutoIncrementIncrement int64
	Creatable              bool
	Updatable              bool
	Readable               bool
	AutoCreateTime         TimeType
	AutoUpdateTime         TimeType
	HasDefaultValue        bool
	DefaultValue           string
	DefaultValueInterface  interface{}
	NotNull                bool
	Unique                 bool
	Comment                string
	Size                   int
	Precision              int
	Scale                  int
	IgnoreMigration        bool
	FieldType              reflect.Type // 反射类型
	IndirectFieldType      reflect.Type // 反射类型
	StructField            reflect.StructField // model字段信息
	Tag                    reflect.StructTag // tag
	TagSettings            map[string]string
	Schema                 *Schema
	EmbeddedSchema         *Schema
	OwnerSchema            *Schema
	ReflectValueOf         func(context.Context, reflect.Value) reflect.Value                  // 通过反射获取该字段的反射对象
	ValueOf                func(context.Context, reflect.Value) (value interface{}, zero bool) // 通过反射获取该字段的值 get方法
	Set                    func(context.Context, reflect.Value, interface{}) error             // 通过反射设置该字段的值 set方法
	Serializer             SerializerInterface
	NewValuePool           FieldNewValuePool
}
                        
                      
  • clause.Interface clause.Clause

gorm定义了多种clause,包括 。

查看clause.Interface代码
                        
                          // Interface clause interface
type Interface interface {
	Name() string
	Build(Builder)
	MergeClause(*Clause)
}
                        
                      
查看clause.Clause代码
                        
                          // Clause
type Clause struct {
	Name                string // WHERE
	BeforeExpression    Expression
	AfterNameExpression Expression
	AfterExpression     Expression
	Expression          Expression
	Builder             ClauseBuilder
}
                        
                      

  。

3.2 解析Model

通过调用 stmt.Parse(stmt.Model) 进行model解析 。

stmt.Parse(stmt.Model) 会调用到函数 func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) 进行解析.

详细代码见 schema.go ,下面列举重要的几个点.

  • 通过反射判断 dest interface{} 是否为 reflect.Struct
  • 通过接口获取表名,其中stu实现了 Tabler 接口
                      
                        	// 获取表名
	modelValue := reflect.New(modelType)
	tableName := namer.TableName(modelType.Name())
	if tabler, ok := modelValue.Interface().(Tabler); ok {
		tableName = tabler.TableName()
	}
	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
		tableName = tabler.TableName(namer)
	}
	if en, ok := namer.(embeddedNamer); ok {
		tableName = en.Table
	}
	if specialTableName != "" && specialTableName != tableName {
		tableName = specialTableName
	}
                      
                    
  • 解析model每个字段
                      
                        // 通过反射获取每个字段
for i := 0; i < modelType.NumField(); i++ {
    if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
        // 解析每个字段
        if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
            schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
        } else {
            schema.Fields = append(schema.Fields, field)
        }
    }
}
                      
                    
  • 放到map方便查找,并且通过 func (field *Field) setupValuerAndSetter() 初始化每个Field的 ReflectValueOf ValueOf Set 方法。
                              
                                for _, field := range schema.Fields {
        if field.DBName == "" && field.DataType != "" {
            field.DBName = namer.ColumnName(schema.Table, field.Name)
        }
        if field.DBName != "" {
            // nonexistence or shortest path or first appear prioritized if has permission
            if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
                if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
                    schema.DBNames = append(schema.DBNames, field.DBName)
                }
                // gorm tag字段到field的映射
                schema.FieldsByDBName[field.DBName] = field
                // model 字段到field的映射
                schema.FieldsByName[field.Name] = field
                if v != nil && v.PrimaryKey {
                    for idx, f := range schema.PrimaryFields {
                        if f == v {
                            schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
                        }
                    }
                }
                // 主键
                if field.PrimaryKey {
                    schema.PrimaryFields = append(schema.PrimaryFields, field)
                }
            }
        }
        if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
            schema.FieldsByName[field.Name] = field
        }
        // 挂载字段的set方法和get方法
        field.setupValuerAndSetter()
    }
                              
                            

  。

值得一提的是,每个model解析后的结果是一致,可以将结果解析的结构缓存下来,并且通过 chan 来解决并发的问题.

解析model之后,通过 process 获取到钩子函数及创建行的函数,具体代码见 Github 。

                      
                        	for _, f := range p.fns {
		f(db)
	}
                      
                    

  。

3.3 执行钩子函数及创建行的函数

创建行的函数及对应的钩子函数位于 create.go 。

  • 创建行记录
                      
                        if db.Statement.SQL.Len() == 0 {
    db.Statement.SQL.Grow(180)
    db.Statement.AddClauseIfNotExists(clause.Insert{})
    db.Statement.AddClause(ConvertToCreateValues(db.Statement))

    db.Statement.Build(db.Statement.BuildClauses...)
}
                      
                    

 这里插入两个 clause.Clause ,分别为 clause.Insert 以及 clause.Values ,然后调用这两种 clause.Clause 的 build 方法生成 SQL 语句.

首先,看下 ConvertToCreateValues 的实现,这里只截取部分代码 。

                      
                        values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
// 获取每一列的名字
for _, db := range stmt.Schema.DBNames {
    if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
	    if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
		    values.Columns = append(values.Columns, clause.Column{Name: db})
		}
	}
}

// 获取每一列对应的值
switch stmt.ReflectValue.Kind() {
    case reflect.Slice, reflect.Array:
    case reflect.Struct:
        values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
        for idx, column := range values.Columns {
            field := stmt.Schema.FieldsByDBName[column.Name]
            // func (field *Field) setupValuerAndSetter() 挂载的方法
            if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
                if field.DefaultValueInterface != nil {
                    values.Values[0][idx] = field.DefaultValueInterface
                    stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
                } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
                    stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
                    values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
                }
            } else if field.AutoUpdateTime > 0 && updateTrackTime {
                stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
                values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
            }
        }
                      
                    

通过 ConvertToCreateValues 获取了每一列的名称及对应的值.

接下来,看 clause.Clause 到 SQL 语句的过程.

遍历加入 clause ,此时分别为 clause.Insert 以及 clause.Values 。

                      
                        // Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) {
    var firstClauseWritten bool
    for _, name := range clauses {
        if c, ok := stmt.Clauses[name]; ok {
            // 代码有删减
            c.Build(stmt)
        }
    }
}
                      
                    

接着调用 func (c Clause) Build(builder Builder) 。

                      
                        // Build build clause 
func (c Clause) Build(builder Builder) {
    // 有删减
    // c为clause.Insert以及clause.Values
    if c.Name != "" {
        // builder写入 INSERT 或者 VALUES
        builder.WriteString(c.Name)
        builder.WriteByte(' ')
    }
    // 通过clause.Insert以及clause.Values的MergeClause函数,c.Expression为clause.Insert以及clause.Values
    // 因此,这里调用clause.Insert或者clause.Values的Build的方法
    c.Expression.Build(builder)
}
                      
                    

接下来分别看 clause.Insert 以及 clause.Values 。

                      
                        // Build build insert clause
func (insert Insert) Build(builder Builder) {
    // builder写入INTO,此时builder为INSERT INTO
    builder.WriteString("INTO ")
    // builder写入表名
    builder.WriteQuoted(currentTable)
}
                      
                    

从调用的链路可以得出,这里 builder 为 stmt *Statement ,并且 currentTable 类型为 clause.Table ,因此 。

                      
                        // WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) {
    stmt.QuoteTo(&stmt.SQL, value)
}

// QuoteTo write quoted value to writer 有删减
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
    write := func(raw bool, str string) {
        // mysql驱动Dialector
        stmt.DB.Dialector.QuoteTo(writer, str)
    }
    switch v := field.(type) {
    case clause.Table:
        write(v.Raw, stmt.Table)
    }
}
                      
                    

至此, builder 已经拼装出 INSERT INTO `t_student`  ,解析来再看 clause.Values 的 build 方法 。

                      
                        // Build build from clause
func (values Values) Build(builder Builder) {
    if len(values.Columns) > 0 {
        builder.WriteByte('(')
        for idx, column := range values.Columns {
            if idx > 0 {
                builder.WriteByte(',')
            }
            builder.WriteQuoted(column)
        }
        builder.WriteByte(')')
        builder.WriteString(" VALUES ")
        for idx, value := range values.Values {
            if idx > 0 {
                builder.WriteByte(',')
            }
            builder.WriteByte('(')
            builder.AddVar(builder, value...)
            builder.WriteByte(')')
        }
    } else {
        builder.WriteString("DEFAULT VALUES")
    }
}
                      
                    

func (values Values) Build(builder Builder) 取出所有列名和列对应的值 。

最终 builder 拼装成例子的完整SQL语句 INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70) 。

有了SQL语句,就可以执行了 。

                      
                        result, err := db.Statement.ConnPool.ExecContext(
    db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
                      
                    

通过前一面学习, db.Statement.ConnPool 的值为 sql.DB ,实际执行的函数为 func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) 。

至此,从Model到DML到流程已经完成.

  。

4. 将ID写入到model的 

看返回参数 sql.Result ,因此通过 LastInsertId() (int64, error) 可以获取到插入行的ID值.

                      
                        // A Result summarizes an executed SQL command.
type Result interface {
	// LastInsertId returns the integer generated by the database
	// in response to a command. Typically this will be from an
	// "auto increment" column when inserting a new row. Not all
	// databases support this feature, and the syntax of such
	// statements varies.
	LastInsertId() (int64, error)

	// RowsAffected returns the number of rows affected by an
	// update, insert, or delete. Not every database or database
	// driver may support this.
	RowsAffected() (int64, error)
}
                      
                    

获取到刚插入的行ID值,再通过反射写入model的ID字段即可.

                      
                        db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
    db.Statement.Schema.PrioritizedPrimaryField != nil &&
    db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
    insertID, err := result.LastInsertId()
    switch db.Statement.ReflectValue.Kind() {
    case reflect.Struct:
        _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
        if isZero {
            // 通过反射更新ID
            db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
        }
    }
}
                      
                    

  。

5. 总结

使用反射解析Model,获得每个成员对应的表的列名、值等信息.

定义SQL各个关键词如 INSERT 、 VALUES 、 FROM 、 DELETE 的结构体,并实现 clause.Interface 接口 。

进而对SQL语句的构造进行抽象封装.

  。

最后此篇关于Gorm源码学习-创建行记录的文章就讲到这里了,如果你想了解更多关于Gorm源码学习-创建行记录的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com