You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inference.go 21 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. package cloudbrainTask
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "net/http"
  10. "os"
  11. "path"
  12. "strconv"
  13. "strings"
  14. "unicode/utf8"
  15. "code.gitea.io/gitea/modules/modelarts"
  16. "code.gitea.io/gitea/modules/git"
  17. api "code.gitea.io/gitea/modules/structs"
  18. "code.gitea.io/gitea/models"
  19. "code.gitea.io/gitea/modules/cloudbrain"
  20. "code.gitea.io/gitea/modules/context"
  21. "code.gitea.io/gitea/modules/log"
  22. "code.gitea.io/gitea/modules/redis/redis_key"
  23. "code.gitea.io/gitea/modules/redis/redis_lock"
  24. "code.gitea.io/gitea/modules/setting"
  25. "code.gitea.io/gitea/modules/storage"
  26. "code.gitea.io/gitea/modules/util"
  27. "code.gitea.io/gitea/services/cloudbrain/resource"
  28. "code.gitea.io/gitea/services/reward/point/account"
  29. )
  // CLONE_FILE_PREFIX is prepended to the repository's local path when it is
  // cloned in downloadCode, so that the requested clone depth is valid.
  30. const CLONE_FILE_PREFIX = "file:///"
  31. func CloudBrainInferenceJobCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  32. displayJobName := option.DisplayJobName
  33. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  34. image := strings.TrimSpace(option.Image)
  35. uuid := option.Attachment
  36. jobType := string(models.JobTypeInference)
  37. codePath := setting.JobPath + jobName + cloudbrain.CodeMountPath
  38. branchName := option.BranchName
  39. bootFile := strings.TrimSpace(option.BootFile)
  40. labelName := option.LabelName
  41. repo := ctx.Repo.Repository
  42. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), jobType, displayJobName))
  43. defer lock.UnLock()
  44. isOk, err := lock.Lock(models.CloudbrainKeyDuration)
  45. if !isOk {
  46. log.Error("lock processed failed:%v", err, ctx.Data["MsgID"])
  47. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("repo.cloudbrain_samejob_err")))
  48. return
  49. }
  50. ckptUrl := setting.Attachment.Minio.RealPath + option.PreTrainModelUrl + option.CkptName
  51. log.Info("ckpt url:" + ckptUrl)
  52. command, err := getInferenceJobCommand(option)
  53. if err != nil {
  54. log.Error("getTrainJobCommand failed: %v", err)
  55. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  56. return
  57. }
  58. tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, jobType, displayJobName)
  59. if err == nil {
  60. if len(tasks) != 0 {
  61. log.Error("the job name did already exist", ctx.Data["MsgID"])
  62. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("the job name did already exist"))
  63. return
  64. }
  65. } else {
  66. if !models.IsErrJobNotExist(err) {
  67. log.Error("system error, %v", err, ctx.Data["MsgID"])
  68. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("system error"))
  69. return
  70. }
  71. }
  72. if !jobNamePattern.MatchString(displayJobName) {
  73. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("repo.cloudbrain_jobname_err")))
  74. return
  75. }
  76. bootFileExist, err := ctx.Repo.FileExists(bootFile, branchName)
  77. if err != nil || !bootFileExist {
  78. log.Error("Get bootfile error:", err, ctx.Data["MsgID"])
  79. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("repo.cloudbrain_bootfile_err")))
  80. return
  81. }
  82. count, err := GetNotFinalStatusTaskCount(ctx.User.ID, models.TypeCloudBrainOne, jobType)
  83. if err != nil {
  84. log.Error("GetCloudbrainCountByUserID failed:%v", err, ctx.Data["MsgID"])
  85. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("system error"))
  86. return
  87. } else {
  88. if count >= 1 {
  89. log.Error("the user already has running or waiting task", ctx.Data["MsgID"])
  90. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("repo.cloudbrain.morethanonejob")))
  91. return
  92. }
  93. }
  94. if branchName == "" {
  95. branchName = cloudbrain.DefaultBranchName
  96. }
  97. errStr := loadCodeAndMakeModelPath(repo, codePath, branchName, jobName, cloudbrain.ResultPath)
  98. if errStr != "" {
  99. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr(errStr)))
  100. return
  101. }
  102. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  103. datasetInfos, datasetNames, err := models.GetDatasetInfo(uuid)
  104. if err != nil {
  105. log.Error("GetDatasetInfo failed: %v", err, ctx.Data["MsgID"])
  106. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.error.dataset_select")))
  107. return
  108. }
  109. spec, err := resource.GetAndCheckSpec(ctx.User.ID, option.SpecId, models.FindSpecsOptions{
  110. JobType: models.JobTypeInference,
  111. ComputeResource: models.GPU,
  112. Cluster: models.OpenICluster,
  113. AiCenterCode: models.AICenterOfCloudBrainOne})
  114. if err != nil || spec == nil {
  115. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Resource specification is not available"))
  116. return
  117. }
  118. if !account.IsPointBalanceEnough(ctx.User.ID, spec.UnitPrice) {
  119. log.Error("point balance is not enough,userId=%d specId=%d", ctx.User.ID, spec.ID)
  120. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("points.insufficient_points_balance")))
  121. return
  122. }
  123. req := cloudbrain.GenerateCloudBrainTaskReq{
  124. Ctx: ctx,
  125. DisplayJobName: displayJobName,
  126. JobName: jobName,
  127. Image: image,
  128. Command: command,
  129. Uuids: uuid,
  130. DatasetNames: datasetNames,
  131. DatasetInfos: datasetInfos,
  132. CodePath: storage.GetMinioPath(jobName, cloudbrain.CodeMountPath+"/"),
  133. ModelPath: setting.Attachment.Minio.RealPath + option.PreTrainModelUrl,
  134. BenchmarkPath: storage.GetMinioPath(jobName, cloudbrain.BenchMarkMountPath+"/"),
  135. Snn4ImageNetPath: storage.GetMinioPath(jobName, cloudbrain.Snn4imagenetMountPath+"/"),
  136. BrainScorePath: storage.GetMinioPath(jobName, cloudbrain.BrainScoreMountPath+"/"),
  137. JobType: jobType,
  138. Description: option.Description,
  139. BranchName: branchName,
  140. BootFile: option.BootFile,
  141. Params: option.Params,
  142. CommitID: commitID,
  143. ResultPath: storage.GetMinioPath(jobName, cloudbrain.ResultPath+"/"),
  144. ModelName: option.ModelName,
  145. ModelVersion: option.ModelVersion,
  146. CkptName: option.CkptName,
  147. TrainUrl: option.PreTrainModelUrl,
  148. LabelName: labelName,
  149. Spec: spec,
  150. }
  151. jobId, err := cloudbrain.GenerateTask(req)
  152. if err != nil {
  153. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  154. return
  155. }
  156. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  157. }
  158. func ModelArtsInferenceJobCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  159. ctx.Data["PageIsTrainJob"] = true
  160. VersionOutputPath := modelarts.GetOutputPathByCount(modelarts.TotalVersionCount)
  161. displayJobName := option.DisplayJobName
  162. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  163. uuid := option.Attachment
  164. description := option.Description
  165. workServerNumber := option.WorkServerNumber
  166. engineID, _ := strconv.Atoi(option.ImageID)
  167. bootFile := strings.TrimSpace(option.BootFile)
  168. params := option.Params
  169. repo := ctx.Repo.Repository
  170. codeLocalPath := setting.JobPath + jobName + modelarts.CodePath
  171. codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath
  172. resultObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.ResultPath + VersionOutputPath + "/"
  173. logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
  174. //dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
  175. branchName := option.BranchName
  176. EngineName := option.Image
  177. LabelName := option.LabelName
  178. isLatestVersion := modelarts.IsLatestVersion
  179. VersionCount := modelarts.VersionCountOne
  180. trainUrl := option.PreTrainModelUrl
  181. modelName := option.ModelName
  182. modelVersion := option.ModelVersion
  183. ckptName := option.CkptName
  184. ckptUrl := "/" + option.PreTrainModelUrl + option.CkptName
  185. errStr := checkInferenceJobMultiNode(ctx.User.ID, option.WorkServerNumber)
  186. if errStr != "" {
  187. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr(errStr)))
  188. return
  189. }
  190. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeInference), displayJobName))
  191. isOk, err := lock.Lock(models.CloudbrainKeyDuration)
  192. if !isOk {
  193. log.Error("lock processed failed:%v", err, ctx.Data["MsgID"])
  194. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("repo.cloudbrain_samejob_err")))
  195. return
  196. }
  197. defer lock.UnLock()
  198. count, err := GetNotFinalStatusTaskCount(ctx.User.ID, models.TypeCloudBrainTwo, string(models.JobTypeInference))
  199. if err != nil {
  200. log.Error("GetCloudbrainInferenceJobCountByUserID failed:%v", err, ctx.Data["MsgID"])
  201. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("system error"))
  202. return
  203. } else {
  204. if count >= 1 {
  205. log.Error("the user already has running or waiting inference task", ctx.Data["MsgID"])
  206. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("you have already a running or waiting inference task, can not create more"))
  207. return
  208. }
  209. }
  210. if err := paramCheckCreateInferenceJob(option); err != nil {
  211. log.Error("paramCheckCreateInferenceJob failed:(%v)", err)
  212. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  213. return
  214. }
  215. bootFileExist, err := ctx.Repo.FileExists(bootFile, branchName)
  216. if err != nil || !bootFileExist {
  217. log.Error("Get bootfile error:", err)
  218. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("repo.cloudbrain_bootfile_err")))
  219. return
  220. }
  221. //Determine whether the task name of the task in the project is duplicated
  222. tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, string(models.JobTypeInference), displayJobName)
  223. if err == nil {
  224. if len(tasks) != 0 {
  225. log.Error("the job name did already exist", ctx.Data["MsgID"])
  226. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("the job name did already exist"))
  227. return
  228. }
  229. } else {
  230. if !models.IsErrJobNotExist(err) {
  231. log.Error("system error, %v", err, ctx.Data["MsgID"])
  232. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("system error"))
  233. return
  234. }
  235. }
  236. spec, err := resource.GetAndCheckSpec(ctx.User.ID, option.SpecId, models.FindSpecsOptions{
  237. JobType: models.JobTypeInference,
  238. ComputeResource: models.NPU,
  239. Cluster: models.OpenICluster,
  240. AiCenterCode: models.AICenterOfCloudBrainTwo})
  241. if err != nil || spec == nil {
  242. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Resource specification not available"))
  243. return
  244. }
  245. if !account.IsPointBalanceEnough(ctx.User.ID, spec.UnitPrice) {
  246. log.Error("point balance is not enough,userId=%d specId=%d ", ctx.User.ID, spec.ID)
  247. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("points.insufficient_points_balance")))
  248. return
  249. }
  250. //todo: del the codeLocalPath
  251. _, err = ioutil.ReadDir(codeLocalPath)
  252. if err == nil {
  253. os.RemoveAll(codeLocalPath)
  254. }
  255. gitRepo, _ := git.OpenRepository(repo.RepoPath())
  256. commitID, _ := gitRepo.GetBranchCommitID(branchName)
  257. if err := downloadCode(repo, codeLocalPath, branchName); err != nil {
  258. log.Error("Create task failed, server timed out: %s (%v)", repo.FullName(), err)
  259. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  260. return
  261. }
  262. //todo: upload code (send to file_server todo this work?)
  263. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.ResultPath + VersionOutputPath + "/"); err != nil {
  264. log.Error("Failed to obsMkdir_result: %s (%v)", repo.FullName(), err)
  265. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to obsMkdir_result"))
  266. return
  267. }
  268. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.LogPath + VersionOutputPath + "/"); err != nil {
  269. log.Error("Failed to obsMkdir_log: %s (%v)", repo.FullName(), err)
  270. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to obsMkdir_log"))
  271. return
  272. }
  273. if err := uploadCodeToObs(codeLocalPath, jobName, ""); err != nil {
  274. log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err)
  275. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  276. return
  277. }
  278. var parameters models.Parameters
  279. param := make([]models.Parameter, 0)
  280. param = append(param, models.Parameter{
  281. Label: modelarts.ResultUrl,
  282. Value: "s3:/" + resultObsPath,
  283. }, models.Parameter{
  284. Label: modelarts.CkptUrl,
  285. Value: "s3:/" + ckptUrl,
  286. })
  287. datasUrlList, dataUrl, datasetNames, isMultiDataset, err := getDatasUrlListByUUIDS(uuid)
  288. if err != nil {
  289. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  290. return
  291. }
  292. dataPath := dataUrl
  293. jsondatas, err := json.Marshal(datasUrlList)
  294. if err != nil {
  295. log.Error("Failed to Marshal: %v", err)
  296. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("json error:"+err.Error()))
  297. return
  298. }
  299. if isMultiDataset {
  300. param = append(param, models.Parameter{
  301. Label: modelarts.MultiDataUrl,
  302. Value: string(jsondatas),
  303. })
  304. }
  305. existDeviceTarget := false
  306. if len(params) != 0 {
  307. err := json.Unmarshal([]byte(params), &parameters)
  308. if err != nil {
  309. log.Error("Failed to Unmarshal params: %s (%v)", params, err)
  310. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("运行参数错误"))
  311. return
  312. }
  313. for _, parameter := range parameters.Parameter {
  314. if parameter.Label == modelarts.DeviceTarget {
  315. existDeviceTarget = true
  316. }
  317. if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl {
  318. param = append(param, models.Parameter{
  319. Label: parameter.Label,
  320. Value: parameter.Value,
  321. })
  322. }
  323. }
  324. }
  325. if !existDeviceTarget {
  326. param = append(param, models.Parameter{
  327. Label: modelarts.DeviceTarget,
  328. Value: modelarts.Ascend,
  329. })
  330. }
  331. req := &modelarts.GenerateInferenceJobReq{
  332. JobName: jobName,
  333. DisplayJobName: displayJobName,
  334. DataUrl: dataPath,
  335. Description: description,
  336. CodeObsPath: codeObsPath,
  337. BootFileUrl: codeObsPath + bootFile,
  338. BootFile: bootFile,
  339. TrainUrl: trainUrl,
  340. WorkServerNumber: workServerNumber,
  341. EngineID: int64(engineID),
  342. LogUrl: logObsPath,
  343. PoolID: getPoolId(),
  344. Uuid: uuid,
  345. Parameters: param, //modelarts train parameters
  346. CommitID: commitID,
  347. BranchName: branchName,
  348. Params: option.Params,
  349. EngineName: EngineName,
  350. LabelName: LabelName,
  351. IsLatestVersion: isLatestVersion,
  352. VersionCount: VersionCount,
  353. TotalVersionCount: modelarts.TotalVersionCount,
  354. ModelName: modelName,
  355. ModelVersion: modelVersion,
  356. CkptName: ckptName,
  357. ResultUrl: resultObsPath,
  358. Spec: spec,
  359. DatasetName: datasetNames,
  360. JobType: string(models.JobTypeInference),
  361. }
  362. jobId, err := modelarts.GenerateInferenceJob(ctx, req)
  363. if err != nil {
  364. log.Error("GenerateTrainJob failed:%v", err.Error())
  365. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  366. return
  367. }
  368. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  369. }
  370. func getDatasUrlListByUUIDS(uuidStr string) ([]models.Datasurl, string, string, bool, error) {
  371. var isMultiDataset bool
  372. var dataUrl string
  373. var datasetNames string
  374. var datasUrlList []models.Datasurl
  375. uuids := strings.Split(uuidStr, ";")
  376. if len(uuids) > setting.MaxDatasetNum {
  377. log.Error("the dataset count(%d) exceed the limit", len(uuids))
  378. return datasUrlList, dataUrl, datasetNames, isMultiDataset, errors.New("the dataset count exceed the limit")
  379. }
  380. datasetInfos := make(map[string]models.DatasetInfo)
  381. attachs, err := models.GetAttachmentsByUUIDs(uuids)
  382. if err != nil || len(attachs) != len(uuids) {
  383. log.Error("GetAttachmentsByUUIDs failed: %v", err)
  384. return datasUrlList, dataUrl, datasetNames, isMultiDataset, errors.New("GetAttachmentsByUUIDs failed")
  385. }
  386. for i, tmpUuid := range uuids {
  387. var attach *models.Attachment
  388. for _, tmpAttach := range attachs {
  389. if tmpAttach.UUID == tmpUuid {
  390. attach = tmpAttach
  391. break
  392. }
  393. }
  394. if attach == nil {
  395. log.Error("GetAttachmentsByUUIDs failed: %v", err)
  396. return datasUrlList, dataUrl, datasetNames, isMultiDataset, errors.New("GetAttachmentsByUUIDs failed")
  397. }
  398. fileName := strings.TrimSuffix(strings.TrimSuffix(strings.TrimSuffix(attach.Name, ".zip"), ".tar.gz"), ".tgz")
  399. for _, datasetInfo := range datasetInfos {
  400. if fileName == datasetInfo.Name {
  401. log.Error("the dataset name is same: %v", attach.Name)
  402. return datasUrlList, dataUrl, datasetNames, isMultiDataset, errors.New("the dataset name is same")
  403. }
  404. }
  405. if len(attachs) <= 1 {
  406. dataUrl = "/" + setting.Bucket + "/" + setting.BasePath + path.Join(attach.UUID[0:1], attach.UUID[1:2]) + "/" + attach.UUID + attach.UUID + "/"
  407. isMultiDataset = false
  408. } else {
  409. dataUrl = "/" + setting.Bucket + "/" + setting.BasePath + path.Join(attachs[0].UUID[0:1], attachs[0].UUID[1:2]) + "/" + attachs[0].UUID + attachs[0].UUID + "/"
  410. datasetUrl := "s3://" + setting.Bucket + "/" + setting.BasePath + path.Join(attach.UUID[0:1], attach.UUID[1:2]) + "/" + attach.UUID + attach.UUID + "/"
  411. datasUrlList = append(datasUrlList, models.Datasurl{
  412. DatasetUrl: datasetUrl,
  413. DatasetName: fileName,
  414. })
  415. isMultiDataset = true
  416. }
  417. if i == 0 {
  418. datasetNames = attach.Name
  419. } else {
  420. datasetNames += ";" + attach.Name
  421. }
  422. }
  423. return datasUrlList, dataUrl, datasetNames, isMultiDataset, nil
  424. }
  425. func checkInferenceJobMultiNode(userId int64, serverNum int) string {
  426. if serverNum == 1 {
  427. return ""
  428. }
  429. return "repo.modelarts.no_node_right"
  430. }
  431. func paramCheckCreateInferenceJob(option api.CreateTrainJobOption) error {
  432. if !strings.HasSuffix(strings.TrimSpace(option.BootFile), ".py") {
  433. log.Error("the boot file(%s) must be a python file", strings.TrimSpace(option.BootFile))
  434. return errors.New("启动文件必须是python文件")
  435. }
  436. if option.ModelName == "" {
  437. log.Error("the ModelName(%d) must not be nil", option.ModelName)
  438. return errors.New("模型名称不能为空")
  439. }
  440. if option.ModelVersion == "" {
  441. log.Error("the ModelVersion(%d) must not be nil", option.ModelVersion)
  442. return errors.New("模型版本不能为空")
  443. }
  444. if option.CkptName == "" {
  445. log.Error("the CkptName(%d) must not be nil", option.CkptName)
  446. return errors.New("权重文件不能为空")
  447. }
  448. if option.BranchName == "" {
  449. log.Error("the Branch(%d) must not be nil", option.BranchName)
  450. return errors.New("分支名不能为空")
  451. }
  452. if utf8.RuneCountInString(option.Description) > 255 {
  453. log.Error("the Description length(%d) must not more than 255", option.Description)
  454. return errors.New("描述字符不能超过255个字符")
  455. }
  456. return nil
  457. }
  458. func loadCodeAndMakeModelPath(repo *models.Repository, codePath string, branchName string, jobName string, resultPath string) string {
  459. err := downloadCode(repo, codePath, branchName)
  460. if err != nil {
  461. return "cloudbrain.load_code_failed"
  462. }
  463. err = uploadCodeToMinio(codePath+"/", jobName, cloudbrain.CodeMountPath+"/")
  464. if err != nil {
  465. return "cloudbrain.load_code_failed"
  466. }
  467. modelPath := setting.JobPath + jobName + resultPath + "/"
  468. err = mkModelPath(modelPath)
  469. if err != nil {
  470. return "cloudbrain.load_code_failed"
  471. }
  472. err = uploadCodeToMinio(modelPath, jobName, resultPath+"/")
  473. if err != nil {
  474. return "cloudbrain.load_code_failed"
  475. }
  476. return ""
  477. }
  478. func downloadCode(repo *models.Repository, codePath, branchName string) error {
  479. //add "file:///" prefix to make the depth valid
  480. if err := git.Clone(CLONE_FILE_PREFIX+repo.RepoPath(), codePath, git.CloneRepoOptions{Branch: branchName, Depth: 1}); err != nil {
  481. log.Error("Failed to clone repository: %s (%v)", repo.FullName(), err)
  482. return err
  483. }
  484. configFile, err := os.OpenFile(codePath+"/.git/config", os.O_RDWR, 0666)
  485. if err != nil {
  486. log.Error("open file(%s) failed:%v", codePath+"/,git/config", err)
  487. return err
  488. }
  489. defer configFile.Close()
  490. pos := int64(0)
  491. reader := bufio.NewReader(configFile)
  492. for {
  493. line, err := reader.ReadString('\n')
  494. if err != nil {
  495. if err == io.EOF {
  496. log.Error("not find the remote-url")
  497. return nil
  498. } else {
  499. log.Error("read error: %v", err)
  500. return err
  501. }
  502. }
  503. if strings.Contains(line, "url") && strings.Contains(line, ".git") {
  504. originUrl := "\turl = " + repo.CloneLink().HTTPS + "\n"
  505. if len(line) > len(originUrl) {
  506. originUrl += strings.Repeat(" ", len(line)-len(originUrl))
  507. }
  508. bytes := []byte(originUrl)
  509. _, err := configFile.WriteAt(bytes, pos)
  510. if err != nil {
  511. log.Error("WriteAt failed:%v", err)
  512. return err
  513. }
  514. break
  515. }
  516. pos += int64(len(line))
  517. }
  518. return nil
  519. }
  520. func getInferenceJobCommand(option api.CreateTrainJobOption) (string, error) {
  521. var command string
  522. bootFile := strings.TrimSpace(option.BootFile)
  523. params := option.Params
  524. if !strings.HasSuffix(bootFile, ".py") {
  525. log.Error("bootFile(%s) format error", bootFile)
  526. return command, errors.New("bootFile format error")
  527. }
  528. var parameters models.Parameters
  529. var param string
  530. if len(params) != 0 {
  531. err := json.Unmarshal([]byte(params), &parameters)
  532. if err != nil {
  533. log.Error("Failed to Unmarshal params: %s (%v)", params, err)
  534. return command, err
  535. }
  536. for _, parameter := range parameters.Parameter {
  537. param += " --" + parameter.Label + "=" + parameter.Value
  538. }
  539. }
  540. param += " --modelname" + "=" + option.CkptName
  541. command += "python /code/" + bootFile + param + " > " + cloudbrain.ResultPath + "/" + option.DisplayJobName + "-" + cloudbrain.LogFile
  542. return command, nil
  543. }