From ef13bbaf7de048ecf71f1ca5f15818e417a744b3 Mon Sep 17 00:00:00 2001
From: Lunny Xiao <xiaolunwen@gmail.com>
Date: Fri, 3 Mar 2017 00:36:47 +0800
Subject: [PATCH] Don't rewrite non-gitea public keys (#906)

* don't rewrite non-gitea public keys

* add comment for public key
---
 models/migrations/migrations.go |  2 ++
 models/migrations/v21.go        | 53 +++++++++++++++++++++++++++++++++++++++++
 models/ssh_key.go               | 34 ++++++++++++++++++++++----
 3 files changed, 85 insertions(+), 4 deletions(-)
 create mode 100644 models/migrations/v21.go

diff --git a/models/migrations/migrations.go b/models/migrations/migrations.go
index b3e7fcc8c6..bcf6285923 100644
--- a/models/migrations/migrations.go
+++ b/models/migrations/migrations.go
@@ -90,6 +90,8 @@ var migrations = []Migration{
 	NewMigration("generate and migrate Git hooks", generateAndMigrateGitHooks),
 	// v20 -> v21
 	NewMigration("use new avatar path name for security reason", useNewNameAvatars),
+	// v21 -> v22
+	NewMigration("rewrite authorized_keys file via new format", useNewPublickeyFormat),
 }
 
 // Migrate database to current version
diff --git a/models/migrations/v21.go b/models/migrations/v21.go
new file mode 100644
index 0000000000..f7f01f062b
--- /dev/null
+++ b/models/migrations/v21.go
@@ -0,0 +1,53 @@
+// Copyright 2017 Gitea. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package migrations
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+
+	"code.gitea.io/gitea/modules/setting"
+
+	"github.com/go-xorm/xorm"
+)
+
+const (
+	tplCommentPrefix = `# gitea public key`
+	tplPublicKey     = tplCommentPrefix + "\n" + `command="%s serv key-%d --config='%s'",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty %s` + "\n"
+)
+
+func useNewPublickeyFormat(x *xorm.Engine) error {
+	fpath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
+	tmpPath := fpath + ".tmp"
+	f, err := os.OpenFile(tmpPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		f.Close()
+		os.Remove(tmpPath)
+	}()
+
+	type PublicKey struct {
+		ID      int64
+		Content string
+	}
+
+	err = x.Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
+		key := bean.(*PublicKey)
+		_, err = f.WriteString(fmt.Sprintf(tplPublicKey, setting.AppPath, key.ID, setting.CustomConf, key.Content))
+		return err
+	})
+	if err != nil {
+		return err
+	}
+
+	f.Close()
+	if err = os.Rename(tmpPath, fpath); err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/models/ssh_key.go b/models/ssh_key.go
index e82fd3aad3..802333f48c 100644
--- a/models/ssh_key.go
+++ b/models/ssh_key.go
@@ -5,6 +5,7 @@
 package models
 
 import (
+	"bufio"
 	"encoding/base64"
 	"encoding/binary"
 	"errors"
@@ -28,7 +29,8 @@ import (
 )
 
 const (
-	tplPublicKey = `command="%s serv key-%d --config='%s'",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty %s` + "\n"
+	tplCommentPrefix = `# gitea public key`
+	tplPublicKey     = tplCommentPrefix + "\n" + `command="%s serv key-%d --config='%s'",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty %s` + "\n"
 )
 
 var sshOpLocker sync.Mutex
@@ -553,22 +555,46 @@ func RewriteAllPublicKeys() error {
 	if err != nil {
 		return err
 	}
-	defer os.Remove(tmpPath)
+	defer func() {
+		f.Close()
+		os.Remove(tmpPath)
+	}()
 
 	err = x.Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) {
 		_, err = f.WriteString((bean.(*PublicKey)).AuthorizedString())
 		return err
 	})
-	f.Close()
 	if err != nil {
 		return err
 	}
 
 	if com.IsExist(fpath) {
-		if err = os.Remove(fpath); err != nil {
+		bakPath := fpath + fmt.Sprintf("_%d.gitea_bak", time.Now().Unix())
+		if err = com.Copy(fpath, bakPath); err != nil {
+			return err
+		}
+
+		p, err := os.Open(bakPath)
+		if err != nil {
 			return err
 		}
+		defer p.Close()
+
+		scanner := bufio.NewScanner(p)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if strings.HasPrefix(line, tplCommentPrefix) {
+				scanner.Scan()
+				continue
+			}
+			_, err = f.WriteString(line + "\n")
+			if err != nil {
+				return err
+			}
+		}
 	}
+
+	f.Close()
 	if err = os.Rename(tmpPath, fpath); err != nil {
 		return err
 	}