diff --git a/models/cloudbrain.go b/models/cloudbrain.go index e1dc2cbd0..e28ba3ea5 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -1566,6 +1566,11 @@ func GetCloudbrainCountByUserID(userID int64, jobType string) (int, error) { return int(count), err } +func GetBenchmarkCountByUserID(userID int64) (int, error) { + count, err := x.In("status", JobWaiting, JobRunning).And("(job_type = ? or job_type = ? or job_type = ?) and user_id = ? and type = ?", string(JobTypeBenchmark), string(JobTypeBrainScore), string(JobTypeSnn4imagenet), userID, TypeCloudBrainOne).Count(new(Cloudbrain)) + return int(count), err +} + func GetCloudbrainNotebookCountByUserID(userID int64) (int, error) { count, err := x.In("status", ModelArtsCreateQueue, ModelArtsCreating, ModelArtsStarting, ModelArtsReadyToStart, ModelArtsResizing, ModelArtsStartQueuing, ModelArtsRunning, ModelArtsRestarting). And("job_type = ? and user_id = ? and type = ?", JobTypeDebug, userID, TypeCloudBrainTwo).Count(new(Cloudbrain)) diff --git a/modules/cloudbrain/cloudbrain.go b/modules/cloudbrain/cloudbrain.go index 417611ddf..3dee3a3fa 100755 --- a/modules/cloudbrain/cloudbrain.go +++ b/modules/cloudbrain/cloudbrain.go @@ -349,7 +349,7 @@ func GenerateTask(ctx *context.Context, displayJobName, jobName, image, command, } stringId := strconv.FormatInt(task.ID, 10) - if string(models.JobTypeBenchmark) == jobType { + if IsBenchmarkJob(jobType) { notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, stringId, displayJobName, models.ActionCreateBenchMarkTask) } else if string(models.JobTypeTrain) == jobType { notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, displayJobName, models.ActionCreateGPUTrainTask) @@ -360,6 +360,10 @@ func GenerateTask(ctx *context.Context, displayJobName, jobName, image, command, return nil } +func IsBenchmarkJob(jobType string) bool { + return string(models.JobTypeBenchmark) == jobType || string(models.JobTypeBrainScore) == jobType || string(models.JobTypeSnn4imagenet) == jobType +} + func RestartTask(ctx *context.Context, task *models.Cloudbrain, newID *string) error { dataActualPath := setting.Attachment.Minio.RealPath + setting.Attachment.Minio.Bucket + "/" + diff --git a/routers/repo/cloudbrain.go b/routers/repo/cloudbrain.go index 67e65f6f9..d4e2acdaa 100755 --- a/routers/repo/cloudbrain.go +++ b/routers/repo/cloudbrain.go @@ -1903,7 +1903,7 @@ func BenchMarkAlgorithmCreate(ctx *context.Context, form auth.CreateCloudBrainFo return } - count, err := models.GetCloudbrainCountByUserID(ctx.User.ID, string(models.JobTypeBenchmark)) + count, err := models.GetBenchmarkCountByUserID(ctx.User.ID) if err != nil { log.Error("GetCloudbrainCountByUserID failed:%v", err, ctx.Data["MsgID"]) cloudBrainNewDataPrepare(ctx) @@ -2040,7 +2040,7 @@ func ModelBenchmarkCreate(ctx *context.Context, form auth.CreateCloudBrainForm) return } - count, err := models.GetCloudbrainCountByUserID(ctx.User.ID, jobType) + count, err := models.GetBenchmarkCountByUserID(ctx.User.ID) if err != nil { log.Error("GetCloudbrainCountByUserID failed:%v", err, ctx.Data["MsgID"]) cloudBrainNewDataPrepare(ctx)