04GORM源碼解讀

簡介

GORM 源碼解讀, 基於 v1.9.11 版本.html

查詢

上一節中, 咱們已經探究過了模型是如何定義的, 以及數據表是如何建立的. 此次, 看一下查詢是如何實現的.mysql

查詢涉及到很大的一塊內容, 由於要支持各類類型的方法. 先看一下官方文檔中提供的最簡單的幾個查詢方法.git

// 根據主鍵查詢第一條記錄
db.First(&user)
//// SELECT * FROM users ORDER BY id LIMIT 1;

// 隨機獲取一條記錄
db.Take(&user)
//// SELECT * FROM users LIMIT 1;

// 根據主鍵查詢最後一條記錄
db.Last(&user)
//// SELECT * FROM users ORDER BY id DESC LIMIT 1;

// 查詢全部的記錄
db.Find(&users)
//// SELECT * FROM users;

// 查詢指定的某條記錄(僅當主鍵爲整型時可用)
db.First(&user, 10)
//// SELECT * FROM users WHERE id = 10;
複製代碼

First 方法爲例, 看一下它的實現:github

// First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB {
	newScope := s.NewScope(out)
	newScope.Search.Limit(1)

	return newScope.Set("gorm:order_by_primary_key", "ASC").
		inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
複製代碼

First 方法從數據庫中獲取第一條數據, 以 primary key 升序排序.sql

前面介紹過, 具體的數據庫操做實現是依靠 callbacks 的. 這裏用到了 callbacks.queries.數據庫

在默認的 callbacks 中, 註冊了三個不一樣的 query 回調函數.express

// Define callbacks for querying
func init() {
	DefaultCallback.Query().Register("gorm:query", queryCallback)
	DefaultCallback.Query().Register("gorm:preload", preloadCallback)
	DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
}
複製代碼

查詢流程

先來看一下最主要的 queryCallback 函數.app

// queryCallback used to query data from database
func queryCallback(scope *Scope) {
	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
		return
	}

	//we are only preloading relations, dont touch base model
	if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
		return
	}

	defer scope.trace(scope.db.nowFunc())

	var (
		isSlice, isPtr bool
		resultType     reflect.Type
		results        = scope.IndirectValue()
	)

	if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
		if primaryField := scope.PrimaryField(); primaryField != nil {
			scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
		}
	}

	if value, ok := scope.Get("gorm:query_destination"); ok {
		results = indirect(reflect.ValueOf(value))
	}

	if kind := results.Kind(); kind == reflect.Slice {
		isSlice = true
		resultType = results.Type().Elem()
		results.Set(reflect.MakeSlice(results.Type(), 0, 0))

		if resultType.Kind() == reflect.Ptr {
			isPtr = true
			resultType = resultType.Elem()
		}
	} else if kind != reflect.Struct {
		scope.Err(errors.New("unsupported destination, should be slice or struct"))
		return
	}

	scope.prepareQuerySQL()

	if !scope.HasError() {
		scope.db.RowsAffected = 0
		if str, ok := scope.Get("gorm:query_option"); ok {
			scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
		}

		if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
			defer rows.Close()

			columns, _ := rows.Columns()
			for rows.Next() {
				scope.db.RowsAffected++

				elem := results
				if isSlice {
					elem = reflect.New(resultType).Elem()
				}

				scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())

				if isSlice {
					if isPtr {
						results.Set(reflect.Append(results, elem.Addr()))
					} else {
						results.Set(reflect.Append(results, elem))
					}
				}
			}

			if err := rows.Err(); err != nil {
				scope.Err(err)
			} else if scope.db.RowsAffected == 0 && !isSlice {
				scope.Err(ErrRecordNotFound)
			}
		}
	}
}
複製代碼

核心的步驟在於 scope.prepareQuerySQL() 構建 SQL 語句. 而後經過 rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...), 執行了數據庫查詢.函數

那麼查詢到的結果是如何傳遞的, 傳遞給誰呢?ui

函數的開頭定義了 results = scope.IndirectValue(), 這就是最終查詢結果的歸屬地.

results 只能是結構體或者是結構體的切片.

if kind := results.Kind(); kind == reflect.Slice {
  isSlice = true
  resultType = results.Type().Elem()
  results.Set(reflect.MakeSlice(results.Type(), 0, 0))

  if resultType.Kind() == reflect.Ptr {
    isPtr = true
    resultType = resultType.Elem()
  }
} else if kind != reflect.Struct {
  scope.Err(errors.New("unsupported destination, should be slice or struct"))
  return
}
複製代碼

具體如何處理查詢到的結果是在下面這部分代碼中:

columns, _ := rows.Columns()
for rows.Next() {
  scope.db.RowsAffected++

  elem := results
  if isSlice {
    elem = reflect.New(resultType).Elem()
  }

  scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())

  if isSlice {
    if isPtr {
      results.Set(reflect.Append(results, elem.Addr()))
    } else {
      results.Set(reflect.Append(results, elem))
    }
  }
}
複製代碼

這部分代碼的核心語句在於 scope.scan, 看一下這個方法的定義:

func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
	var (
		ignored            interface{}
		values             = make([]interface{}, len(columns))
		selectFields       []*Field
		selectedColumnsMap = map[string]int{}
		resetFields        = map[int]*Field{}
	)

	for index, column := range columns {
		values[index] = &ignored

		selectFields = fields
		offset := 0
		if idx, ok := selectedColumnsMap[column]; ok {
			offset = idx + 1
			selectFields = selectFields[offset:]
		}

		for fieldIndex, field := range selectFields {
			if field.DBName == column {
				if field.Field.Kind() == reflect.Ptr {
					values[index] = field.Field.Addr().Interface()
				} else {
					reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
					reflectValue.Elem().Set(field.Field.Addr())
					values[index] = reflectValue.Interface()
					resetFields[index] = field
				}

				selectedColumnsMap[column] = offset + fieldIndex

				if field.IsNormal {
					break
				}
			}
		}
	}

	scope.Err(rows.Scan(values...))

	for index, field := range resetFields {
		if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
			field.Field.Set(v)
		}
	}
}
複製代碼

就和它的名字暗示的那樣, 實際上就是調用了 rows.Scan(values...), 將查詢到的數據複製到對應的字段中.

由此, 咱們就瞭解了查詢時的主要流程了.

前面專一於流程, 略過了構建 SQL 語句的細節, 來仔細看看 prepareQuerySQL 方法.

構建查詢 SQL 語句

func (scope *Scope) prepareQuerySQL() {
	if scope.Search.raw {
		scope.Raw(scope.CombinedConditionSql())
	} else {
		scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql()))
	}
	return
}
複製代碼

內部分支中都使用到了 scope.Raw, 看一下它的實現:

// Raw set raw sql
func (scope *Scope) Raw(sql string) *Scope {
	scope.SQL = strings.Replace(sql, "$$$", "?", -1)
	return scope
}
複製代碼

它的做用是將獲取到的 sql 語句賦值到 scope.SQL 字段上, 其中替換了全部的 $$$?.

回到 prepareQuerySQL 上來, 重要的部分是實際上是 Raw 的參數. if 的後半部分更好理解點, 就是構建了 SELECT 表達式.

SELECT 表達式須要三個變量, 字段名, 表名, 條件.

將每一個都看一下吧.

func (scope *Scope) selectSQL() string {
	if len(scope.Search.selects) == 0 {
		if len(scope.Search.joinConditions) > 0 {
			return fmt.Sprintf("%v.*", scope.QuotedTableName())
		}
		return "*"
	}
	return scope.buildSelectQuery(scope.Search.selects)
}

func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
	switch value := clause["query"].(type) {
	case string:
		str = value
	case []string:
		str = strings.Join(value, ", ")
	}

	args := clause["args"].([]interface{})
	replacements := []string{}
	for _, arg := range args {
		switch reflect.ValueOf(arg).Kind() {
		case reflect.Slice:
			values := reflect.ValueOf(arg)
			var tempMarks []string
			for i := 0; i < values.Len(); i++ {
				tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
			}
			replacements = append(replacements, strings.Join(tempMarks, ","))
		default:
			if valuer, ok := interface{}(arg).(driver.Valuer); ok {
				arg, _ = valuer.Value()
			}
			replacements = append(replacements, scope.AddToVars(arg))
		}
	}

	buff := bytes.NewBuffer([]byte{})
	i := 0
	for pos, char := range str {
		if str[pos] == '?' {
			buff.WriteString(replacements[i])
			i++
		} else {
			buff.WriteRune(char)
		}
	}

	str = buff.String()

	return
}
複製代碼

scope.Search.selects 爲空的時候, 比較簡單. 只要根據是否有連表查詢, 返回 table.**.

buildSelectQuery 就是根據 scope.Search.selects 構建查詢字段名.

前面半部分一看就明白.

switch value := clause["query"].(type) {
case string:
  str = value
case []string:
  str = strings.Join(value, ", ")
}
複製代碼

重點是遇到參數時如何處理, 也就是後半段代碼.

args := clause["args"].([]interface{})
replacements := []string{}
for _, arg := range args {
  switch reflect.ValueOf(arg).Kind() {
  case reflect.Slice:
    values := reflect.ValueOf(arg)
    var tempMarks []string
    for i := 0; i < values.Len(); i++ {
      tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
    }
    replacements = append(replacements, strings.Join(tempMarks, ","))
  default:
    if valuer, ok := interface{}(arg).(driver.Valuer); ok {
      arg, _ = valuer.Value()
    }
    replacements = append(replacements, scope.AddToVars(arg))
  }
}

buff := bytes.NewBuffer([]byte{})
i := 0
for pos, char := range str {
  if str[pos] == '?' {
    buff.WriteString(replacements[i])
    i++
  } else {
    buff.WriteRune(char)
  }
}
複製代碼

主要的過程是遍歷 args := clause["args"].([]interface{}), 建立了一個 replacements 切片. 而後將 str 中全部的 ?, 替換爲了對應的字段.

到此, 構建 SELECT 字段的過程就結束了.

獲取表名的過程相對簡單, 直接展現代碼吧:

// QuotedTableName return quoted table name
func (scope *Scope) QuotedTableName() (name string) {
	if scope.search != nil && len(scope.Search.tableName) > 0 {
		if strings.Contains(scope.Search.tableName, " ") {
			return scope.Search.tableName
		}
		return scope.Quote(scope.Search.tableName)
	}

	return scope.Quote(scope.TableName())
}
複製代碼

條件語句

更多的關注點在於如何構建篩選條件, 即 CombinedConditionSql 方法.

// CombinedConditionSql return combined condition sql
func (scope *Scope) CombinedConditionSql() string {
	joinSQL := scope.joinsSQL()
	whereSQL := scope.whereSQL()
	if scope.Search.raw {
		whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")
	}
	return joinSQL + whereSQL + scope.groupSQL() +
		scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
}
複製代碼

短小的代碼中是精簡的邏輯, 條件語句有不少模塊, 這裏總共有 6 個子句. 都看一遍吧, 看完以後應該對如何構建條件語句不會陌生了.

func (scope *Scope) joinsSQL() string {
	var joinConditions []string
	for _, clause := range scope.Search.joinConditions {
		if sql := scope.buildCondition(clause, true); sql != "" {
			joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
		}
	}

	return strings.Join(joinConditions, " ") + " "
}
複製代碼

建立 joinSQL 的過程當中主要用到了 buildCondition, 繼續深刻:

func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {
	var (
		quotedTableName  = scope.QuotedTableName()
		quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
		equalSQL         = "="
		inSQL            = "IN"
	)

	// If building not conditions
	if !include {
		equalSQL = "<>"
		inSQL = "NOT IN"
	}

	switch value := clause["query"].(type) {
	case sql.NullInt64:
		return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)
	case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
		if !include && reflect.ValueOf(value).Len() == 0 {
			return
		}
		str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)
		clause["args"] = []interface{}{value}
	case string:
		if isNumberRegexp.MatchString(value) {
			return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))
		}

		if value != "" {
			if !include {
				if comparisonRegexp.MatchString(value) {
					str = fmt.Sprintf("NOT (%v)", value)
				} else {
					str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))
				}
			} else {
				str = fmt.Sprintf("(%v)", value)
			}
		}
	case map[string]interface{}:
		var sqls []string
		for key, value := range value {
			if value != nil {
				sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))
			} else {
				if !include {
					sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))
				} else {
					sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))
				}
			}
		}
		return strings.Join(sqls, " AND ")
	case interface{}:
		var sqls []string
		newScope := scope.New(value)

		if len(newScope.Fields()) == 0 {
			scope.Err(fmt.Errorf("invalid query condition: %v", value))
			return
		}
		scopeQuotedTableName := newScope.QuotedTableName()
		for _, field := range newScope.Fields() {
			if !field.IsIgnored && !field.IsBlank {
				sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
			}
		}
		return strings.Join(sqls, " AND ")
	default:
		scope.Err(fmt.Errorf("invalid query condition: %v", value))
		return
	}

	replacements := []string{}
	args := clause["args"].([]interface{})
	for _, arg := range args {
		var err error
		switch reflect.ValueOf(arg).Kind() {
		case reflect.Slice: // For where("id in (?)", []int64{1,2})
			if scanner, ok := interface{}(arg).(driver.Valuer); ok {
				arg, err = scanner.Value()
				replacements = append(replacements, scope.AddToVars(arg))
			} else if b, ok := arg.([]byte); ok {
				replacements = append(replacements, scope.AddToVars(b))
			} else if as, ok := arg.([][]interface{}); ok {
				var tempMarks []string
				for _, a := range as {
					var arrayMarks []string
					for _, v := range a {
						arrayMarks = append(arrayMarks, scope.AddToVars(v))
					}

					if len(arrayMarks) > 0 {
						tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ",")))
					}
				}

				if len(tempMarks) > 0 {
					replacements = append(replacements, strings.Join(tempMarks, ","))
				}
			} else if values := reflect.ValueOf(arg); values.Len() > 0 {
				var tempMarks []string
				for i := 0; i < values.Len(); i++ {
					tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
				}
				replacements = append(replacements, strings.Join(tempMarks, ","))
			} else {
				replacements = append(replacements, scope.AddToVars(Expr("NULL")))
			}
		default:
			if valuer, ok := interface{}(arg).(driver.Valuer); ok {
				arg, err = valuer.Value()
			}

			replacements = append(replacements, scope.AddToVars(arg))
		}

		if err != nil {
			scope.Err(err)
		}
	}

	buff := bytes.NewBuffer([]byte{})
	i := 0
	for _, s := range str {
		if s == '?' && len(replacements) > i {
			buff.WriteString(replacements[i])
			i++
		} else {
			buff.WriteRune(s)
		}
	}

	str = buff.String()

	return
}
複製代碼

開頭是一個精妙的選擇, 基於 include, 實現了 not 條件.

var (
  quotedTableName  = scope.QuotedTableName()
  quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
  equalSQL         = "="
  inSQL            = "IN"
)

// If building not conditions
if !include {
  equalSQL = "<>"
  inSQL = "NOT IN"
}
複製代碼

中間是一個 switch value := clause["query"].(type) 選擇. 在這個 switch 選擇中, 大部分的條件都會直接返回. 剩餘的部分, 則會構建 str 字符串變量.

而這會繼續進入到結尾部分, 這部分的代碼和咱們上面看過的很是相似, 就是根據 clause["args"] 構建 replacements 切片, 用來替換 str 變量中的 ?.

接着看下一個 whereSQL 方法.

func (scope *Scope) whereSQL() (sql string) {
	var (
		quotedTableName                                = scope.QuotedTableName()
		deletedAtField, hasDeletedAtField              = scope.FieldByName("DeletedAt")
		primaryConditions, andConditions, orConditions []string
	)

	if !scope.Search.Unscoped && hasDeletedAtField {
		sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))
		primaryConditions = append(primaryConditions, sql)
	}

	if !scope.PrimaryKeyZero() {
		for _, field := range scope.PrimaryFields() {
			sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
			primaryConditions = append(primaryConditions, sql)
		}
	}

	for _, clause := range scope.Search.whereConditions {
		if sql := scope.buildCondition(clause, true); sql != "" {
			andConditions = append(andConditions, sql)
		}
	}

	for _, clause := range scope.Search.orConditions {
		if sql := scope.buildCondition(clause, true); sql != "" {
			orConditions = append(orConditions, sql)
		}
	}

	for _, clause := range scope.Search.notConditions {
		if sql := scope.buildCondition(clause, false); sql != "" {
			andConditions = append(andConditions, sql)
		}
	}

	orSQL := strings.Join(orConditions, " OR ")
	combinedSQL := strings.Join(andConditions, " AND ")
	if len(combinedSQL) > 0 {
		if len(orSQL) > 0 {
			combinedSQL = combinedSQL + " OR " + orSQL
		}
	} else {
		combinedSQL = orSQL
	}

	if len(primaryConditions) > 0 {
		sql = "WHERE " + strings.Join(primaryConditions, " AND ")
		if len(combinedSQL) > 0 {
			sql = sql + " AND (" + combinedSQL + ")"
		}
	} else if len(combinedSQL) > 0 {
		sql = "WHERE " + combinedSQL
	}
	return
}
複製代碼

主要構建了三個部分, primaryConditions, andConditions, orConditions.

if !scope.Search.Unscoped && hasDeletedAtField {
  sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName))
  primaryConditions = append(primaryConditions, sql)
}

if !scope.PrimaryKeyZero() {
  for _, field := range scope.PrimaryFields() {
    sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))
    primaryConditions = append(primaryConditions, sql)
  }
}
複製代碼

前面兩個 if 構建了 primaryConditions 條件.

for _, clause := range scope.Search.whereConditions {
  if sql := scope.buildCondition(clause, true); sql != "" {
    andConditions = append(andConditions, sql)
  }
}

for _, clause := range scope.Search.orConditions {
  if sql := scope.buildCondition(clause, true); sql != "" {
    orConditions = append(orConditions, sql)
  }
}

for _, clause := range scope.Search.notConditions {
  if sql := scope.buildCondition(clause, false); sql != "" {
    andConditions = append(andConditions, sql)
  }
}
複製代碼

而後三個 for 循環都使用了 buildCondition 方法. 注意到 scope.Search.notConditions 是算在 andConditions 中的.

orSQL := strings.Join(orConditions, " OR ")
combinedSQL := strings.Join(andConditions, " AND ")
if len(combinedSQL) > 0 {
  if len(orSQL) > 0 {
    combinedSQL = combinedSQL + " OR " + orSQL
  }
} else {
  combinedSQL = orSQL
}
複製代碼

結合 orConditionsandConditions 生成了條件語句.

if len(primaryConditions) > 0 {
  sql = "WHERE " + strings.Join(primaryConditions, " AND ")
  if len(combinedSQL) > 0 {
    sql = sql + " AND (" + combinedSQL + ")"
  }
} else if len(combinedSQL) > 0 {
  sql = "WHERE " + combinedSQL
}
return
複製代碼

最後, 結合 primaryConditions 生成最終的 WHERE 子句.

接着看另外一個:

func (scope *Scope) groupSQL() string {
	if len(scope.Search.group) == 0 {
		return ""
	}
	return " GROUP BY " + scope.Search.group
}
複製代碼

GROUP BY 子句比較簡單, 直接就能構建.

繼續:

func (scope *Scope) havingSQL() string {
	if len(scope.Search.havingConditions) == 0 {
		return ""
	}

	var andConditions []string
	for _, clause := range scope.Search.havingConditions {
		if sql := scope.buildCondition(clause, true); sql != "" {
			andConditions = append(andConditions, sql)
		}
	}

	combinedSQL := strings.Join(andConditions, " AND ")
	if len(combinedSQL) == 0 {
		return ""
	}

	return " HAVING " + combinedSQL
}
複製代碼

HAVING 子句也不算難, 構建完條件以後用 AND 鏈接, 而後在最前面加上 HAVING 就好了.

繼續:

func (scope *Scope) orderSQL() string {
	if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery {
		return ""
	}

	var orders []string
	for _, order := range scope.Search.orders {
		if str, ok := order.(string); ok {
			orders = append(orders, scope.quoteIfPossible(str))
		} else if expr, ok := order.(*expr); ok {
			exp := expr.expr
			for _, arg := range expr.args {
				exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
			}
			orders = append(orders, exp)
		}
	}
	return " ORDER BY " + strings.Join(orders, ",")
}
複製代碼

結構也是相似, 遍歷 scope.Search.orders 切片, order 有兩種不一樣的類型, 字符串或者 expr 結構體. 後者用於處理帶參數的狀況.

最後還有一個 limitAndOffsetSQL 方法:

func (scope *Scope) limitAndOffsetSQL() string {
	return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
}
複製代碼

這直接調用了具體數據庫驅動中的 LimitAndOffsetSQL 方法.

看兩個具體的實現, 一個是通用中的實現, 另外一個是 mysql 中的實現.

func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
	if limit != nil {
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
			sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
		}
	}
	if offset != nil {
		if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
			sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
		}
	}
	return
}
複製代碼

直接將 limit 和 offset 解析爲 int 類型, 而後鏈接對應的關鍵字便可.

接着看一下 mysql 中的實現:

func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
	if limit != nil {
		if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
			sql += fmt.Sprintf(" LIMIT %d", parsedLimit)

			if offset != nil {
				if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
					sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
				}
			}
		}
	}
	return
}
複製代碼

二者的區別在於 offset 的嵌套, mysql 中 offset 必須和 limit 一塊兒使用.

就這樣, CombinedConditionSql 中的全部子句都看完了. 說到底其實也沒什麼魔法, 不過是根據不一樣的條件, 構建不一樣的 SQL 語句.

小結

一路從 First 深刻到查詢的內部細節. 在瞭解了底層細節以後, 其餘相似的方法也就不難理解了.

// Take return a record that match given conditions, the order will depend on the database implementation
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
	newScope := s.NewScope(out)
	newScope.Search.Limit(1)
	return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

// Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
	newScope := s.NewScope(out)
	newScope.Search.Limit(1)
	return newScope.Set("gorm:order_by_primary_key", "DESC").
		inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}

// Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
	return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
複製代碼

search 結構體

前面的過程當中, 咱們只看到了最簡單的查詢是如何產生的. 在這個過程當中, 沒有仔細研究查詢條件是如何存儲的.

看一下如何使用 Where 方法添加查詢條件.

// Get first matched record
db.Where("name = ?", "jinzhu").First(&user)
//// SELECT * FROM users WHERE name = 'jinzhu' limit 1;

// Get all matched records
db.Where("name = ?", "jinzhu").Find(&users)
//// SELECT * FROM users WHERE name = 'jinzhu';
複製代碼

上面的例子來自於官方文檔. GORM 使用鏈式調用的風格, 能夠串聯多個 Where 方法, 或是其餘的查詢條件.

// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
	return s.clone().search.Where(query, args...).db
}
複製代碼

上面是 Where 方法的代碼, 在它的源碼附近有不少相似的的方法.

// Or filter records that match before conditions or this one, similar to `Where`
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
	return s.clone().search.Or(query, args...).db
}

// Not filter records that don't match current conditions, similar to `Where`
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
	return s.clone().search.Not(query, args...).db
}
複製代碼

能夠很容易的發現, 這一切的源頭都是 search 對象.

結構體 DB 定義的時候, 有個字段就是 search:

search            *search
複製代碼

search 的定義

這就是用於存儲查詢條件的地方. 它的定義以下:

type search struct {
	db               *DB
	whereConditions  []map[string]interface{}
	orConditions     []map[string]interface{}
	notConditions    []map[string]interface{}
	havingConditions []map[string]interface{}
	joinConditions   []map[string]interface{}
	initAttrs        []interface{}
	assignAttrs      []interface{}
	selects          map[string]interface{}
	omits            []string
	orders           []interface{}
	preload          []searchPreload
	offset           interface{}
	limit            interface{}
	group            string
	tableName        string
	raw              bool
	Unscoped         bool
	ignoreOrderQuery bool
}

type searchPreload struct {
	schema     string
	conditions []interface{}
}
複製代碼

這裏有不少類型爲 []map[string]interface{} 的字段, 結合前面關於條件查詢的代碼, 就能回憶起這就是存儲各類條件的地方.

另外一些字段好比 offsetlimit 也很容易明白它的做用.

search 的方法

search 下有不少方法, 雖然方法數量比較多, 但基本都很短, 總共也就一百行出頭.

func (s *search) clone() *search {
	clone := *s
	return &clone
}
複製代碼

這個克隆方法有點獨特, 彷佛什麼也沒作, 也多是我見識少.

func (s *search) Where(query interface{}, values ...interface{}) *search {
	s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
	return s
}

func (s *search) Not(query interface{}, values ...interface{}) *search {
	s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
	return s
}

func (s *search) Or(query interface{}, values ...interface{}) *search {
	s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
	return s
}
複製代碼

上面這些方法都是用參數構建成一個 map 而後推入對應的切片中, 考慮到鏈式調用, 返回了自己.

func (s *search) Attrs(attrs ...interface{}) *search {
	s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
	return s
}

func (s *search) Assign(attrs ...interface{}) *search {
	s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
	return s
}

func toSearchableMap(attrs ...interface{}) (result interface{}) {
	if len(attrs) > 1 {
		if str, ok := attrs[0].(string); ok {
			result = map[string]interface{}{str: attrs[1]}
		}
	} else if len(attrs) == 1 {
		if attr, ok := attrs[0].(map[string]interface{}); ok {
			result = attr
		}

		if attr, ok := attrs[0].(interface{}); ok {
			result = attr
		}
	}
	return
}
複製代碼

這兩個方法也是相似, 並使用了 toSearchableMap 轉換參數.

func (s *search) Order(value interface{}, reorder ...bool) *search {
	if len(reorder) > 0 && reorder[0] {
		s.orders = []interface{}{}
	}

	if value != nil && value != "" {
		s.orders = append(s.orders, value)
	}
	return s
}
複製代碼

看到這個可能有點疑惑, 能夠從文檔和註釋中獲取解釋.

// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
// db.Order("name DESC")
// db.Order("name DESC", true) // reorder
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (s *DB) Order(value interface{}, reorder ...bool) *DB {
	return s.clone().search.Order(value, reorder...).db
}
複製代碼

第二個參數用於判斷是否覆蓋前面的排序條件.

可能有點奇怪的是爲何 reorder 是可變參數, 不知爲了兼容或者是歷史遺留.

另外一點是不能理解 []interface{}{}, 這其實能夠分爲兩部分, []interface{} 是類型, {} 構造了一個空的該類型實例.

func (s *search) Select(query interface{}, args ...interface{}) *search {
	s.selects = map[string]interface{}{"query": query, "args": args}
	return s
}

func (s *search) Omit(columns ...string) *search {
	s.omits = columns
	return s
}

func (s *search) Limit(limit interface{}) *search {
	s.limit = limit
	return s
}

func (s *search) Offset(offset interface{}) *search {
	s.offset = offset
	return s
}
複製代碼

這幾個就是替換型的了, 每次調用都只會保存最新值.

func (s *search) Group(query string) *search {
	s.group = s.getInterfaceAsSQL(query)
	return s
}

func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
	switch value.(type) {
	case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		str = fmt.Sprintf("%v", value)
	default:
		s.db.AddError(ErrInvalidSQL)
	}

	if str == "-1" {
		return ""
	}
	return
}
複製代碼

getInterfaceAsSQL 的一個特性是使用 -1 會重置.

func (s *search) Having(query interface{}, values ...interface{}) *search {
	if val, ok := query.(*expr); ok {
		s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
	} else {
		s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
	}
	return s
}

func (s *search) Joins(query string, values ...interface{}) *search {
	s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
	return s
}
複製代碼

這其實也比較相似前面看過的, 就很少解釋了.

func (s *search) Preload(schema string, values ...interface{}) *search {
	var preloads []searchPreload
	for _, preload := range s.preload {
		if preload.schema != schema {
			preloads = append(preloads, preload)
		}
	}
	preloads = append(preloads, searchPreload{schema, values})
	s.preload = preloads
	return s
}
複製代碼

Preload 須要防止重複, 因此開頭會從新遍歷一遍已經存在的 schema.

func (s *search) Raw(b bool) *search {
	s.raw = b
	return s
}

func (s *search) unscoped() *search {
	s.Unscoped = true
	return s
}

func (s *search) Table(name string) *search {
	s.tableName = name
	return s
}
複製代碼

最後幾個方法也沒什麼特殊的.

小結

search 結構體仍是挺簡單的, 定義加方法總共也就一百多行. 但用處卻不小, 查詢相關的條件都是存儲在這裏的.

總結

這部分主要查看了 SQL 查詢是如何發生的, 並在這個過程當中探索了各類查詢子句是如何實現的. 同時, 也研究了一下 search 結構體和它的做用.

相關文章
相關標籤/搜索