You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

modelarts.go 12 kB

4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. package modelarts
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "path"
  6. "strconv"
  7. "code.gitea.io/gitea/models"
  8. "code.gitea.io/gitea/modules/context"
  9. "code.gitea.io/gitea/modules/log"
  10. "code.gitea.io/gitea/modules/setting"
  11. "code.gitea.io/gitea/modules/storage"
  12. )
  // Constants used by the ModelArts notebook and train-job helpers in this package.
  const (
  	// notebook
  	storageTypeOBS   = "obs"
  	autoStopDuration = 4 * 60 * 60 // 14400; presumably seconds (4 hours) — TODO confirm unit
  	DataSetMountPath = "/home/ma-user/work"
  	NotebookEnv      = "Python3"
  	NotebookType     = "Ascend"
  	FlavorInfo       = "Ascend: 1*Ascend 910 CPU: 24 核 96GiB (modelarts.kat1.xlarge)"

  	// train-job
  	// The commented-out defaults below are kept verbatim; the Chinese text
  	// inside them is data ("专属资源池" = "dedicated resource pool", "核" = "cores").
  	// ResourcePools = "{\"resource_pool\":[{\"id\":\"pool1328035d\", \"value\":\"专属资源池\"}]}"
  	// Engines = "{\"engine\":[{\"id\":1, \"value\":\"Ascend-Powered-Engine\"}]}"
  	// EngineVersions = "{\"version\":[{\"id\":118,\"value\":\"MindSpore-1.0.0-c75-python3.7-euleros2.8-aarch64\"}," +
  	// "{\"id\":119,\"value\":\"MindSpore-1.1.1-c76-python3.7-euleros2.8-aarch64\"}," +
  	// "{\"id\":120,\"value\":\"MindSpore-1.1.1-c76-tr5-python3.7-euleros2.8-aarch64\"}," +
  	// "{\"id\":117,\"value\":\"TF-1.15-c75-python3.7-euleros2.8-aarch64\"}" +
  	// "]}"
  	// TrainJobFlavorInfo = "{\"flavor\":[{\"code\":\"modelarts.bm.910.arm.public.2\",\"value\":\"Ascend : 2 * Ascend 910 CPU:48 核 512GiB\"}," +
  	// "{\"code\":\"modelarts.bm.910.arm.public.8\",\"value\":\"Ascend : 8 * Ascend 910 CPU:192 核 2048GiB\"}," +
  	// "{\"code\":\"modelarts.bm.910.arm.public.4\",\"value\":\"Ascend : 4 * Ascend 910 CPU:96 核 1024GiB\"}," +
  	// "{\"code\":\"modelarts.bm.910.arm.public.1\",\"value\":\"Ascend : 1 * Ascend 910 CPU:24 核 256GiB\"}" +
  	// "]}"
  	CodePath   = "/code/"
  	OutputPath = "/output/"
  	LogPath    = "/log/"
  	JobPath    = "/job/"

  	OrderDesc = "desc" // query downward
  	OrderAsc  = "asc"  // query upward
  	Lines     = 500

  	TrainUrl = "train_url"
  	DataUrl  = "data_url"
  	PerPage  = 10

  	// Values stored in a record's IsLatestVersion field.
  	IsLatestVersion  = "1"
  	NotLatestVersion = "0"

  	// ComputeResource = "NPU"
  	NPUResource = "NPU"
  	GPUResource = "CPU/GPU"
  	AllResource = "all"

  	DebugType         = -1
  	VersionCount      = 1
  	SortByCreateTime  = "create_time"
  	ConfigTypeCustom  = "custom"
  	TotalVersionCount = 1
  )
  56. var (
  57. poolInfos *models.PoolInfos
  58. FlavorInfos *models.FlavorInfos
  59. )
  60. type GenerateTrainJobReq struct {
  61. JobName string
  62. Uuid string
  63. Description string
  64. CodeObsPath string
  65. BootFile string
  66. BootFileUrl string
  67. DataUrl string
  68. TrainUrl string
  69. FlavorCode string
  70. LogUrl string
  71. PoolID string
  72. WorkServerNumber int
  73. EngineID int64
  74. Parameters []models.Parameter
  75. CommitID string
  76. IsLatestVersion string
  77. Params string
  78. BranchName string
  79. PreVersionId int64
  80. PreVersionName string
  81. FlavorName string
  82. VersionCount int
  83. EngineName string
  84. TotalVersionCount int
  85. }
  86. type GenerateTrainJobVersionReq struct {
  87. JobName string
  88. Uuid string
  89. Description string
  90. CodeObsPath string
  91. BootFile string
  92. BootFileUrl string
  93. DataUrl string
  94. TrainUrl string
  95. FlavorCode string
  96. LogUrl string
  97. PoolID string
  98. WorkServerNumber int
  99. EngineID int64
  100. Parameters []models.Parameter
  101. Params string
  102. PreVersionId int64
  103. CommitID string
  104. BranchName string
  105. FlavorName string
  106. EngineName string
  107. PreVersionName string
  108. TotalVersionCount int
  109. }
  // VersionInfo is the JSON shape {"version":[{"id":<int>,"value":<string>}]}.
  type VersionInfo struct {
  	Version []struct {
  		ID    int    `json:"id"`
  		Value string `json:"value"`
  	} `json:"version"`
  }
  // Flavor is the JSON shape {"flavor":[{"code":<string>,"value":<string>}]}.
  type Flavor struct {
  	Info []struct {
  		Code  string `json:"code"`
  		Value string `json:"value"`
  	} `json:"flavor"`
  }
  // Engine is the JSON shape {"engine":[{"id":<int>,"value":<string>}]}.
  type Engine struct {
  	Info []struct {
  		ID    int    `json:"id"`
  		Value string `json:"value"`
  	} `json:"engine"`
  }
  // ResourcePool is the JSON shape
  // {"resource_pool":[{"id":<string>,"value":<string>}]}.
  type ResourcePool struct {
  	Info []struct {
  		ID    string `json:"id"`
  		Value string `json:"value"`
  	} `json:"resource_pool"`
  }
  // Earlier two-type form of the definition below, kept for reference:
  // type Parameter struct {
  // Label string `json:"label"`
  // Value string `json:"value"`
  // }
  // type Parameters struct {
  // Parameter []Parameter `json:"parameter"`
  // }

  // Parameters is the JSON shape
  // {"parameter":[{"label":<string>,"value":<string>}]}.
  type Parameters struct {
  	Parameter []struct {
  		Label string `json:"label"`
  		Value string `json:"value"`
  	} `json:"parameter"`
  }
  147. func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error {
  148. var dataActualPath string
  149. if uuid != "" {
  150. dataActualPath = setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + "/"
  151. } else {
  152. userPath := setting.UserBasePath + ctx.User.Name + "/"
  153. isExist, err := storage.ObsHasObject(userPath)
  154. if err != nil {
  155. log.Error("ObsHasObject failed:%v", err.Error(), ctx.Data["MsgID"])
  156. return err
  157. }
  158. if !isExist {
  159. if err = storage.ObsCreateObject(userPath); err != nil {
  160. log.Error("ObsCreateObject failed:%v", err.Error(), ctx.Data["MsgID"])
  161. return err
  162. }
  163. }
  164. dataActualPath = setting.Bucket + "/" + userPath
  165. }
  166. if poolInfos == nil {
  167. json.Unmarshal([]byte(setting.PoolInfos), &poolInfos)
  168. }
  169. jobResult, err := CreateJob(models.CreateNotebookParams{
  170. JobName: jobName,
  171. Description: description,
  172. ProfileID: setting.ProfileID,
  173. Flavor: flavor,
  174. Pool: models.Pool{
  175. ID: poolInfos.PoolInfo[0].PoolId,
  176. Name: poolInfos.PoolInfo[0].PoolName,
  177. Type: poolInfos.PoolInfo[0].PoolType,
  178. },
  179. Spec: models.Spec{
  180. Storage: models.Storage{
  181. Type: storageTypeOBS,
  182. Location: models.Location{
  183. Path: dataActualPath,
  184. },
  185. },
  186. AutoStop: models.AutoStop{
  187. Enable: true,
  188. Duration: autoStopDuration,
  189. },
  190. },
  191. })
  192. if err != nil {
  193. log.Error("CreateJob failed: %v", err.Error())
  194. return err
  195. }
  196. err = models.CreateCloudbrain(&models.Cloudbrain{
  197. Status: string(models.JobWaiting),
  198. UserID: ctx.User.ID,
  199. RepoID: ctx.Repo.Repository.ID,
  200. JobID: jobResult.ID,
  201. JobName: jobName,
  202. JobType: string(models.JobTypeDebug),
  203. Type: models.TypeCloudBrainTwo,
  204. Uuid: uuid,
  205. })
  206. if err != nil {
  207. return err
  208. }
  209. return nil
  210. }
  211. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  212. jobResult, err := createTrainJob(models.CreateTrainJobParams{
  213. JobName: req.JobName,
  214. Description: req.Description,
  215. Config: models.Config{
  216. WorkServerNum: req.WorkServerNumber,
  217. AppUrl: req.CodeObsPath,
  218. BootFileUrl: req.BootFileUrl,
  219. DataUrl: req.DataUrl,
  220. EngineID: req.EngineID,
  221. TrainUrl: req.TrainUrl,
  222. LogUrl: req.LogUrl,
  223. PoolID: req.PoolID,
  224. CreateVersion: true,
  225. Flavor: models.Flavor{
  226. Code: req.FlavorCode,
  227. },
  228. Parameter: req.Parameters,
  229. },
  230. })
  231. if err != nil {
  232. log.Error("CreateJob failed: %v", err.Error())
  233. return err
  234. }
  235. attach, err := models.GetAttachmentByUUID(req.Uuid)
  236. if err != nil {
  237. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  238. return err
  239. }
  240. err = models.CreateCloudbrain(&models.Cloudbrain{
  241. Status: TransTrainJobStatus(jobResult.Status),
  242. UserID: ctx.User.ID,
  243. RepoID: ctx.Repo.Repository.ID,
  244. JobID: strconv.FormatInt(jobResult.JobID, 10),
  245. JobName: req.JobName,
  246. JobType: string(models.JobTypeTrain),
  247. Type: models.TypeCloudBrainTwo,
  248. VersionID: jobResult.VersionID,
  249. VersionName: jobResult.VersionName,
  250. Uuid: req.Uuid,
  251. DatasetName: attach.Name,
  252. CommitID: req.CommitID,
  253. IsLatestVersion: req.IsLatestVersion,
  254. ComputeResource: NPUResource,
  255. EngineID: req.EngineID,
  256. TrainUrl: req.TrainUrl,
  257. BranchName: req.BranchName,
  258. Parameters: req.Params,
  259. BootFile: req.BootFile,
  260. DataUrl: req.DataUrl,
  261. LogUrl: req.LogUrl,
  262. FlavorCode: req.FlavorCode,
  263. Description: req.Description,
  264. WorkServerNumber: req.WorkServerNumber,
  265. FlavorName: req.FlavorName,
  266. EngineName: req.EngineName,
  267. VersionCount: req.VersionCount,
  268. TotalVersionCount: req.TotalVersionCount,
  269. })
  270. if err != nil {
  271. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  272. return err
  273. }
  274. return nil
  275. }
  276. func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, jobId string) (err error) {
  277. jobResult, err := createTrainJobVersion(models.CreateTrainJobVersionParams{
  278. Description: req.Description,
  279. Config: models.TrainJobVersionConfig{
  280. WorkServerNum: req.WorkServerNumber,
  281. AppUrl: req.CodeObsPath,
  282. BootFileUrl: req.BootFileUrl,
  283. DataUrl: req.DataUrl,
  284. EngineID: req.EngineID,
  285. TrainUrl: req.TrainUrl,
  286. LogUrl: req.LogUrl,
  287. PoolID: req.PoolID,
  288. Flavor: models.Flavor{
  289. Code: req.FlavorCode,
  290. },
  291. Parameter: req.Parameters,
  292. PreVersionId: req.PreVersionId,
  293. },
  294. }, jobId)
  295. if err != nil {
  296. log.Error("CreateJob failed: %v", err.Error())
  297. return err
  298. }
  299. attach, err := models.GetAttachmentByUUID(req.Uuid)
  300. if err != nil {
  301. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  302. return err
  303. }
  304. repo := ctx.Repo.Repository
  305. VersionTaskList, VersionListCount, err := models.CloudbrainsVersionList(&models.CloudbrainsOptions{
  306. RepoID: repo.ID,
  307. Type: models.TypeCloudBrainTwo,
  308. JobType: string(models.JobTypeTrain),
  309. JobID: strconv.FormatInt(jobResult.JobID, 10),
  310. })
  311. if err != nil {
  312. ctx.ServerError("Cloudbrain", err)
  313. return err
  314. }
  315. //将当前版本的isLatestVersion设置为"1"和任务数量更新,任务数量包括当前版本数VersionCount和历史创建的总版本数TotalVersionCount
  316. err = models.CreateCloudbrain(&models.Cloudbrain{
  317. Status: TransTrainJobStatus(jobResult.Status),
  318. UserID: ctx.User.ID,
  319. RepoID: ctx.Repo.Repository.ID,
  320. JobID: strconv.FormatInt(jobResult.JobID, 10),
  321. JobName: req.JobName,
  322. JobType: string(models.JobTypeTrain),
  323. Type: models.TypeCloudBrainTwo,
  324. VersionID: jobResult.VersionID,
  325. VersionName: jobResult.VersionName,
  326. Uuid: req.Uuid,
  327. DatasetName: attach.Name,
  328. CommitID: req.CommitID,
  329. IsLatestVersion: req.IsLatestVersion,
  330. PreVersionName: req.PreVersionName,
  331. ComputeResource: NPUResource,
  332. EngineID: req.EngineID,
  333. TrainUrl: req.TrainUrl,
  334. BranchName: req.BranchName,
  335. Parameters: req.Params,
  336. BootFile: req.BootFile,
  337. DataUrl: req.DataUrl,
  338. LogUrl: req.LogUrl,
  339. PreVersionId: req.PreVersionId,
  340. FlavorCode: req.FlavorCode,
  341. Description: req.Description,
  342. WorkServerNumber: req.WorkServerNumber,
  343. FlavorName: req.FlavorName,
  344. EngineName: req.EngineName,
  345. TotalVersionCount: VersionTaskList[0].TotalVersionCount + 1,
  346. VersionCount: VersionListCount + 1,
  347. })
  348. if err != nil {
  349. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  350. return err
  351. }
  352. //将训练任务的上一版本的isLatestVersion设置为"0"
  353. err = models.SetVersionCountAndLatestVersion(strconv.FormatInt(jobResult.JobID, 10), VersionTaskList[0].VersionName, VersionCount, NotLatestVersion, TotalVersionCount)
  354. if err != nil {
  355. ctx.ServerError("Update IsLatestVersion failed", err)
  356. return err
  357. }
  358. return err
  359. }
  // TransTrainJobStatus converts a numeric train-job status code into its
  // string name. Codes 0 through 21 have fixed names; any other value is
  // returned as its decimal representation.
  func TransTrainJobStatus(status int) string {
  	switch status {
  	case 0:
  		return "UNKNOWN"
  	case 1:
  		return "INIT"
  	case 2:
  		return "IMAGE_CREATING"
  	case 3:
  		return "IMAGE_FAILED"
  	case 4:
  		return "SUBMIT_TRYING"
  	case 5:
  		return "SUBMIT_FAILED"
  	case 6:
  		return "DELETE_FAILED"
  	case 7:
  		return "WAITING"
  	case 8:
  		return "RUNNING"
  	case 9:
  		return "KILLING"
  	case 10:
  		return "COMPLETED"
  	case 11:
  		return "FAILED"
  	case 12:
  		return "KILLED"
  	case 13:
  		return "CANCELED"
  	case 14:
  		return "LOST"
  	case 15:
  		return "SCALING"
  	case 16:
  		return "SUBMIT_MODEL_FAILED"
  	case 17:
  		return "DEPLOY_SERVICE_FAILED"
  	case 18:
  		return "CHECK_INIT"
  	case 19:
  		return "CHECK_RUNNING"
  	case 20:
  		return "CHECK_RUNNING_COMPLETED"
  	case 21:
  		return "CHECK_FAILED"
  	default:
  		// Unrecognised codes are passed through as their number.
  		return strconv.Itoa(status)
  	}
  }
  // GetVersionOutputPathByTotalVersionCount returns the version directory name
  // for the given count: the letter "V" followed by the count zero-padded to
  // at least four digits (for example 1 -> "V0001"). Wider values are not
  // truncated.
  func GetVersionOutputPathByTotalVersionCount(TotalVersionCount int) (VersionOutputPath string) {
  	return fmt.Sprintf("V%04d", TotalVersionCount)
  }