You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelarts.go 16 kB

4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. package modelarts
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "path"
  6. "strconv"
  7. "code.gitea.io/gitea/models"
  8. "code.gitea.io/gitea/modules/context"
  9. "code.gitea.io/gitea/modules/log"
  10. "code.gitea.io/gitea/modules/setting"
  11. "code.gitea.io/gitea/modules/storage"
  12. )
  // Notebook (debug job) settings.
  const (
  	// storageTypeOBS is the storage type sent in the notebook spec.
  	storageTypeOBS = "obs"
  	// autoStopDuration is the auto-stop duration sent in the notebook spec
  	// (4*60*60; presumably seconds, i.e. four hours — confirm against the API).
  	autoStopDuration = 4 * 60 * 60

  	DataSetMountPath = "/home/ma-user/work"
  	NotebookEnv      = "Python3"
  	NotebookType     = "Ascend"
  	FlavorInfo       = "Ascend: 1*Ascend 910 CPU: 24 核 96GiB (modelarts.kat1.xlarge)"
  )

  // Train-job configuration that used to be hard-coded here. The samples are
  // kept, commented out, as a reference for the expected JSON layout.
  //
  // ResourcePools = "{\"resource_pool\":[{\"id\":\"pool1328035d\", \"value\":\"专属资源池\"}]}"
  // Engines = "{\"engine\":[{\"id\":1, \"value\":\"Ascend-Powered-Engine\"}]}"
  // EngineVersions = "{\"version\":[{\"id\":118,\"value\":\"MindSpore-1.0.0-c75-python3.7-euleros2.8-aarch64\"}," +
  // "{\"id\":119,\"value\":\"MindSpore-1.1.1-c76-python3.7-euleros2.8-aarch64\"}," +
  // "{\"id\":120,\"value\":\"MindSpore-1.1.1-c76-tr5-python3.7-euleros2.8-aarch64\"}," +
  // "{\"id\":117,\"value\":\"TF-1.15-c75-python3.7-euleros2.8-aarch64\"}" +
  // "]}"
  // TrainJobFlavorInfo = "{\"flavor\":[{\"code\":\"modelarts.bm.910.arm.public.2\",\"value\":\"Ascend : 2 * Ascend 910 CPU:48 核 512GiB\"}," +
  // "{\"code\":\"modelarts.bm.910.arm.public.8\",\"value\":\"Ascend : 8 * Ascend 910 CPU:192 核 2048GiB\"}," +
  // "{\"code\":\"modelarts.bm.910.arm.public.4\",\"value\":\"Ascend : 4 * Ascend 910 CPU:96 核 1024GiB\"}," +
  // "{\"code\":\"modelarts.bm.910.arm.public.1\",\"value\":\"Ascend : 1 * Ascend 910 CPU:24 核 256GiB\"}" +
  // "]}"

  // Path segments used by train jobs.
  const (
  	CodePath   = "/code/"
  	OutputPath = "/output/"
  	ResultPath = "/result/"
  	LogPath    = "/log/"
  	JobPath    = "/job/"
  )

  // Ordering and paging values.
  const (
  	OrderDesc = "desc" // query downward
  	OrderAsc  = "asc"  // query upward
  	Lines     = 500
  	PerPage   = 10

  	SortByCreateTime = "create_time"
  )

  // Names of the URL parameters.
  const (
  	TrainUrl  = "train_url"
  	DataUrl   = "data_url"
  	ResultUrl = "result_url"
  	CkptUrl   = "ckpt_url"
  )

  // Version bookkeeping and miscellaneous job values.
  const (
  	// IsLatestVersion and NotLatestVersion are the string flags stored in a
  	// task's IsLatestVersion field.
  	IsLatestVersion  = "1"
  	NotLatestVersion = "0"

  	DebugType         = -1
  	VersionCount      = 1
  	ConfigTypeCustom  = "custom"
  	TotalVersionCount = 1
  )
  55. var (
  56. poolInfos *models.PoolInfos
  57. FlavorInfos *models.FlavorInfos
  58. )
  59. type GenerateTrainJobReq struct {
  60. JobName string
  61. Uuid string
  62. Description string
  63. CodeObsPath string
  64. BootFile string
  65. BootFileUrl string
  66. DataUrl string
  67. TrainUrl string
  68. FlavorCode string
  69. LogUrl string
  70. PoolID string
  71. WorkServerNumber int
  72. EngineID int64
  73. Parameters []models.Parameter
  74. CommitID string
  75. IsLatestVersion string
  76. Params string
  77. BranchName string
  78. PreVersionId int64
  79. PreVersionName string
  80. FlavorName string
  81. VersionCount int
  82. EngineName string
  83. TotalVersionCount int
  84. }
  85. type GenerateTrainJobVersionReq struct {
  86. JobName string
  87. Uuid string
  88. Description string
  89. CodeObsPath string
  90. BootFile string
  91. BootFileUrl string
  92. DataUrl string
  93. TrainUrl string
  94. FlavorCode string
  95. LogUrl string
  96. PoolID string
  97. WorkServerNumber int
  98. EngineID int64
  99. Parameters []models.Parameter
  100. Params string
  101. PreVersionId int64
  102. CommitID string
  103. BranchName string
  104. FlavorName string
  105. EngineName string
  106. PreVersionName string
  107. TotalVersionCount int
  108. }
  109. type GenerateInferenceJobReq struct {
  110. JobName string
  111. Uuid string
  112. Description string
  113. CodeObsPath string
  114. BootFile string
  115. BootFileUrl string
  116. DataUrl string
  117. TrainUrl string
  118. FlavorCode string
  119. LogUrl string
  120. PoolID string
  121. WorkServerNumber int
  122. EngineID int64
  123. Parameters []models.Parameter
  124. CommitID string
  125. Params string
  126. BranchName string
  127. FlavorName string
  128. EngineName string
  129. LabelName string
  130. IsLatestVersion string
  131. VersionCount int
  132. TotalVersionCount int
  133. ModelName string
  134. ModelVersion string
  135. CkptName string
  136. ResultUrl string
  137. }
  // VersionInfo is the JSON shape {"version":[{"id":..., "value":...}]}: a
  // list of numeric IDs with their display values.
  type VersionInfo struct {
  	// Version lists the entries found under the "version" key.
  	Version []struct {
  		ID    int    `json:"id"`
  		Value string `json:"value"`
  	} `json:"version"`
  }
  // Flavor is the JSON shape {"flavor":[{"code":..., "value":...}]}: a list
  // of flavor codes with their display values.
  type Flavor struct {
  	// Info lists the entries found under the "flavor" key.
  	Info []struct {
  		Code  string `json:"code"`
  		Value string `json:"value"`
  	} `json:"flavor"`
  }
  // Engine is the JSON shape {"engine":[{"id":..., "value":...}]}: a list of
  // numeric engine IDs with their display values.
  type Engine struct {
  	// Info lists the entries found under the "engine" key.
  	Info []struct {
  		ID    int    `json:"id"`
  		Value string `json:"value"`
  	} `json:"engine"`
  }
  // ResourcePool is the JSON shape {"resource_pool":[{"id":..., "value":...}]}:
  // a list of pool IDs (strings) with their display values.
  type ResourcePool struct {
  	// Info lists the entries found under the "resource_pool" key.
  	Info []struct {
  		ID    string `json:"id"`
  		Value string `json:"value"`
  	} `json:"resource_pool"`
  }
  162. // type Parameter struct {
  163. // Label string `json:"label"`
  164. // Value string `json:"value"`
  165. // }
  166. // type Parameters struct {
  167. // Parameter []Parameter `json:"parameter"`
  168. // }
  // Parameters is the JSON shape {"parameter":[{"label":..., "value":...}]}:
  // a list of label/value pairs.
  type Parameters struct {
  	// Parameter lists the entries found under the "parameter" key.
  	Parameter []struct {
  		Label string `json:"label"`
  		Value string `json:"value"`
  	} `json:"parameter"`
  }
  175. func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error {
  176. var dataActualPath string
  177. if uuid != "" {
  178. dataActualPath = setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + "/"
  179. } else {
  180. userPath := setting.UserBasePath + ctx.User.Name + "/"
  181. isExist, err := storage.ObsHasObject(userPath)
  182. if err != nil {
  183. log.Error("ObsHasObject failed:%v", err.Error(), ctx.Data["MsgID"])
  184. return err
  185. }
  186. if !isExist {
  187. if err = storage.ObsCreateObject(userPath); err != nil {
  188. log.Error("ObsCreateObject failed:%v", err.Error(), ctx.Data["MsgID"])
  189. return err
  190. }
  191. }
  192. dataActualPath = setting.Bucket + "/" + userPath
  193. }
  194. if poolInfos == nil {
  195. json.Unmarshal([]byte(setting.PoolInfos), &poolInfos)
  196. }
  197. jobResult, err := CreateJob(models.CreateNotebookParams{
  198. JobName: jobName,
  199. Description: description,
  200. ProfileID: setting.ProfileID,
  201. Flavor: flavor,
  202. Pool: models.Pool{
  203. ID: poolInfos.PoolInfo[0].PoolId,
  204. Name: poolInfos.PoolInfo[0].PoolName,
  205. Type: poolInfos.PoolInfo[0].PoolType,
  206. },
  207. Spec: models.Spec{
  208. Storage: models.Storage{
  209. Type: storageTypeOBS,
  210. Location: models.Location{
  211. Path: dataActualPath,
  212. },
  213. },
  214. AutoStop: models.AutoStop{
  215. Enable: true,
  216. Duration: autoStopDuration,
  217. },
  218. },
  219. })
  220. if err != nil {
  221. log.Error("CreateJob failed: %v", err.Error())
  222. return err
  223. }
  224. err = models.CreateCloudbrain(&models.Cloudbrain{
  225. Status: string(models.JobWaiting),
  226. UserID: ctx.User.ID,
  227. RepoID: ctx.Repo.Repository.ID,
  228. JobID: jobResult.ID,
  229. JobName: jobName,
  230. JobType: string(models.JobTypeDebug),
  231. Type: models.TypeCloudBrainTwo,
  232. Uuid: uuid,
  233. ComputeResource: models.NPUResource,
  234. })
  235. if err != nil {
  236. return err
  237. }
  238. return nil
  239. }
  240. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  241. jobResult, err := createTrainJob(models.CreateTrainJobParams{
  242. JobName: req.JobName,
  243. Description: req.Description,
  244. Config: models.Config{
  245. WorkServerNum: req.WorkServerNumber,
  246. AppUrl: req.CodeObsPath,
  247. BootFileUrl: req.BootFileUrl,
  248. DataUrl: req.DataUrl,
  249. EngineID: req.EngineID,
  250. TrainUrl: req.TrainUrl,
  251. LogUrl: req.LogUrl,
  252. PoolID: req.PoolID,
  253. CreateVersion: true,
  254. Flavor: models.Flavor{
  255. Code: req.FlavorCode,
  256. },
  257. Parameter: req.Parameters,
  258. },
  259. })
  260. if err != nil {
  261. log.Error("CreateJob failed: %v", err.Error())
  262. return err
  263. }
  264. attach, err := models.GetAttachmentByUUID(req.Uuid)
  265. if err != nil {
  266. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  267. return err
  268. }
  269. err = models.CreateCloudbrain(&models.Cloudbrain{
  270. Status: TransTrainJobStatus(jobResult.Status),
  271. UserID: ctx.User.ID,
  272. RepoID: ctx.Repo.Repository.ID,
  273. JobID: strconv.FormatInt(jobResult.JobID, 10),
  274. JobName: req.JobName,
  275. JobType: string(models.JobTypeTrain),
  276. Type: models.TypeCloudBrainTwo,
  277. VersionID: jobResult.VersionID,
  278. VersionName: jobResult.VersionName,
  279. Uuid: req.Uuid,
  280. DatasetName: attach.Name,
  281. CommitID: req.CommitID,
  282. IsLatestVersion: req.IsLatestVersion,
  283. ComputeResource: models.NPUResource,
  284. EngineID: req.EngineID,
  285. TrainUrl: req.TrainUrl,
  286. BranchName: req.BranchName,
  287. Parameters: req.Params,
  288. BootFile: req.BootFile,
  289. DataUrl: req.DataUrl,
  290. LogUrl: req.LogUrl,
  291. FlavorCode: req.FlavorCode,
  292. Description: req.Description,
  293. WorkServerNumber: req.WorkServerNumber,
  294. FlavorName: req.FlavorName,
  295. EngineName: req.EngineName,
  296. VersionCount: req.VersionCount,
  297. TotalVersionCount: req.TotalVersionCount,
  298. })
  299. if err != nil {
  300. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  301. return err
  302. }
  303. return nil
  304. }
  305. func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, jobId string) (err error) {
  306. jobResult, err := createTrainJobVersion(models.CreateTrainJobVersionParams{
  307. Description: req.Description,
  308. Config: models.TrainJobVersionConfig{
  309. WorkServerNum: req.WorkServerNumber,
  310. AppUrl: req.CodeObsPath,
  311. BootFileUrl: req.BootFileUrl,
  312. DataUrl: req.DataUrl,
  313. EngineID: req.EngineID,
  314. TrainUrl: req.TrainUrl,
  315. LogUrl: req.LogUrl,
  316. PoolID: req.PoolID,
  317. Flavor: models.Flavor{
  318. Code: req.FlavorCode,
  319. },
  320. Parameter: req.Parameters,
  321. PreVersionId: req.PreVersionId,
  322. },
  323. }, jobId)
  324. if err != nil {
  325. log.Error("CreateJob failed: %v", err.Error())
  326. return err
  327. }
  328. attach, err := models.GetAttachmentByUUID(req.Uuid)
  329. if err != nil {
  330. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  331. return err
  332. }
  333. var jobTypes []string
  334. jobTypes = append(jobTypes, string(models.JobTypeTrain))
  335. repo := ctx.Repo.Repository
  336. VersionTaskList, VersionListCount, err := models.CloudbrainsVersionList(&models.CloudbrainsOptions{
  337. RepoID: repo.ID,
  338. Type: models.TypeCloudBrainTwo,
  339. JobTypes: jobTypes,
  340. JobID: strconv.FormatInt(jobResult.JobID, 10),
  341. })
  342. if err != nil {
  343. ctx.ServerError("Cloudbrain", err)
  344. return err
  345. }
  346. //将当前版本的isLatestVersion设置为"1"和任务数量更新,任务数量包括当前版本数VersionCount和历史创建的总版本数TotalVersionCount
  347. err = models.CreateCloudbrain(&models.Cloudbrain{
  348. Status: TransTrainJobStatus(jobResult.Status),
  349. UserID: ctx.User.ID,
  350. RepoID: ctx.Repo.Repository.ID,
  351. JobID: strconv.FormatInt(jobResult.JobID, 10),
  352. JobName: req.JobName,
  353. JobType: string(models.JobTypeTrain),
  354. Type: models.TypeCloudBrainTwo,
  355. VersionID: jobResult.VersionID,
  356. VersionName: jobResult.VersionName,
  357. Uuid: req.Uuid,
  358. DatasetName: attach.Name,
  359. CommitID: req.CommitID,
  360. IsLatestVersion: req.IsLatestVersion,
  361. PreVersionName: req.PreVersionName,
  362. ComputeResource: models.NPUResource,
  363. EngineID: req.EngineID,
  364. TrainUrl: req.TrainUrl,
  365. BranchName: req.BranchName,
  366. Parameters: req.Params,
  367. BootFile: req.BootFile,
  368. DataUrl: req.DataUrl,
  369. LogUrl: req.LogUrl,
  370. PreVersionId: req.PreVersionId,
  371. FlavorCode: req.FlavorCode,
  372. Description: req.Description,
  373. WorkServerNumber: req.WorkServerNumber,
  374. FlavorName: req.FlavorName,
  375. EngineName: req.EngineName,
  376. TotalVersionCount: VersionTaskList[0].TotalVersionCount + 1,
  377. VersionCount: VersionListCount + 1,
  378. })
  379. if err != nil {
  380. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  381. return err
  382. }
  383. //将训练任务的上一版本的isLatestVersion设置为"0"
  384. err = models.SetVersionCountAndLatestVersion(strconv.FormatInt(jobResult.JobID, 10), VersionTaskList[0].VersionName, VersionCount, NotLatestVersion, TotalVersionCount)
  385. if err != nil {
  386. ctx.ServerError("Update IsLatestVersion failed", err)
  387. return err
  388. }
  389. return err
  390. }
  // trainJobStatusNames maps a numeric train-job status code (the index) to
  // its name.
  var trainJobStatusNames = [...]string{
  	0:  "UNKNOWN",
  	1:  "INIT",
  	2:  "IMAGE_CREATING",
  	3:  "IMAGE_FAILED",
  	4:  "SUBMIT_TRYING",
  	5:  "SUBMIT_FAILED",
  	6:  "DELETE_FAILED",
  	7:  "WAITING",
  	8:  "RUNNING",
  	9:  "KILLING",
  	10: "COMPLETED",
  	11: "FAILED",
  	12: "KILLED",
  	13: "CANCELED",
  	14: "LOST",
  	15: "SCALING",
  	16: "SUBMIT_MODEL_FAILED",
  	17: "DEPLOY_SERVICE_FAILED",
  	18: "CHECK_INIT",
  	19: "CHECK_RUNNING",
  	20: "CHECK_RUNNING_COMPLETED",
  	21: "CHECK_FAILED",
  }

  // TransTrainJobStatus translates a numeric train-job status code into its
  // name. Codes outside 0..21 are returned as their decimal representation.
  func TransTrainJobStatus(status int) string {
  	if status < 0 || status >= len(trainJobStatusNames) {
  		return strconv.Itoa(status)
  	}
  	return trainJobStatusNames[status]
  }
  // GetOutputPathByCount builds the output directory name for a version: the
  // letter "V" followed by TotalVersionCount zero-padded to at least four
  // characters, e.g. 1 -> "V0001".
  func GetOutputPathByCount(TotalVersionCount int) (VersionOutputPath string) {
  	return fmt.Sprintf("V%04d", TotalVersionCount)
  }
  446. func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) {
  447. jobResult, err := createInferenceJob(models.CreateInferenceJobParams{
  448. JobName: req.JobName,
  449. Description: req.Description,
  450. InfConfig: models.InfConfig{
  451. WorkServerNum: req.WorkServerNumber,
  452. AppUrl: req.CodeObsPath,
  453. BootFileUrl: req.BootFileUrl,
  454. DataUrl: req.DataUrl,
  455. EngineID: req.EngineID,
  456. // TrainUrl: req.TrainUrl,
  457. LogUrl: req.LogUrl,
  458. PoolID: req.PoolID,
  459. CreateVersion: true,
  460. Flavor: models.Flavor{
  461. Code: req.FlavorCode,
  462. },
  463. Parameter: req.Parameters,
  464. },
  465. })
  466. if err != nil {
  467. log.Error("CreateJob failed: %v", err.Error())
  468. return err
  469. }
  470. attach, err := models.GetAttachmentByUUID(req.Uuid)
  471. if err != nil {
  472. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  473. return err
  474. }
  475. err = models.CreateCloudbrain(&models.Cloudbrain{
  476. Status: TransTrainJobStatus(jobResult.Status),
  477. UserID: ctx.User.ID,
  478. RepoID: ctx.Repo.Repository.ID,
  479. JobID: strconv.FormatInt(jobResult.JobID, 10),
  480. JobName: req.JobName,
  481. JobType: string(models.JobTypeInference),
  482. Type: models.TypeCloudBrainTwo,
  483. VersionID: jobResult.VersionID,
  484. VersionName: jobResult.VersionName,
  485. Uuid: req.Uuid,
  486. DatasetName: attach.Name,
  487. CommitID: req.CommitID,
  488. EngineID: req.EngineID,
  489. TrainUrl: req.TrainUrl,
  490. BranchName: req.BranchName,
  491. Parameters: req.Params,
  492. BootFile: req.BootFile,
  493. DataUrl: req.DataUrl,
  494. LogUrl: req.LogUrl,
  495. FlavorCode: req.FlavorCode,
  496. Description: req.Description,
  497. WorkServerNumber: req.WorkServerNumber,
  498. FlavorName: req.FlavorName,
  499. EngineName: req.EngineName,
  500. LabelName: req.LabelName,
  501. IsLatestVersion: req.IsLatestVersion,
  502. VersionCount: req.VersionCount,
  503. TotalVersionCount: req.TotalVersionCount,
  504. ModelName: req.ModelName,
  505. ModelVersion: req.ModelVersion,
  506. CkptName: req.CkptName,
  507. ResultUrl: req.ResultUrl,
  508. })
  509. if err != nil {
  510. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  511. return err
  512. }
  513. return nil
  514. }