You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ml.py 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  def loglik(X, y, w):
      """Return the negative log-likelihood of a logistic-regression model.

      Computes ``sum_i [ -y_i * (x_i . w) + log(1 + exp(x_i . w)) ]``, i.e. the
      quantity to be minimised when fitting (lower is better, despite the name).

      Parameters
      ----------
      X : array_like, shape (N, d)
          Design matrix (no intercept column is added here).
      y : array_like, shape (N,)
          Binary labels in {0, 1}; plain lists are accepted.
      w : array_like, shape (d,)
          Coefficient vector.

      Returns
      -------
      float
          The summed negative log-likelihood.
      """
      import numpy as np
      X = np.asarray(X)
      y = np.asarray(y)
      Xw = X @ w
      # log(1 + exp(Xw)) written as logaddexp(0, Xw): mathematically identical,
      # but does not overflow exp() to inf for large scores.
      return np.sum(-y * Xw + np.logaddexp(0.0, Xw))
  def reg_log(X, y, ite_max=100, lbd=1e-12, pos_contraint=False, tol=1e-4):
      """Fit a logistic regression by damped Newton-Raphson (IRLS).

      Labels ``y`` are expected in {0, 1}. Each iteration solves

          (X^T W X + lbd * I) descent = X^T (y - p),   W = diag(p * (1 - p)),

      with p = sigmoid(X w). This is the Newton update of ESL II section 4.4.
      ``lbd * I`` only damps the linear system for numerical stability; it is
      not added to the objective. The full step is then shrunk by a factor
      of 10 until the objective ``loglik`` no longer increases (backtracking
      line search).

      Parameters
      ----------
      X : array_like, shape (N, d)
          Design matrix (no intercept column is added here).
      y : array_like, shape (N,)
          Binary labels in {0, 1}.
      ite_max : int
          Maximum number of Newton iterations (at least one is always run).
      lbd : float
          Damping added to the diagonal of the Hessian.
      pos_contraint : bool
          If True, every iterate is projected onto the non-negative orthant
          (negative coefficients are clipped to 0).
      tol : float
          Stop when the objective decreases by less than ``tol`` in one
          iteration.

      Returns
      -------
      w : ndarray, shape (d,)
          Final coefficients.
      J : list of float
          Objective value (``loglik``) at the start and after each iteration.
      weights : list of ndarray
          Coefficients at the start and after each iteration.
      """
      import numpy as np

      def proj_on_pos(v):
          # Clip negative coefficients to 0; always yields a float array.
          return np.maximum(v, 0.0)

      X = np.asarray(X, dtype=float)
      y = np.asarray(y, dtype=float)
      _, d = X.shape
      regul = lbd * np.identity(d)  # loop-invariant damping term
      shrink = 0.1  # line-search step reduction factor

      w = np.zeros(d)  # starting point, see 4.4 of ESLII
      weights = [w]
      J = [loglik(X, y, w)]
      old_J = J[0] + 1  # ensures the first iteration never counts as converged
      i = 0
      while True:
          i += 1
          Xw = X @ w
          # Stable sigmoid: 1 / (1 + exp(-Xw)) without overflow in exp().
          p = np.exp(-np.logaddexp(0.0, -Xw))
          # Hessian weights of the negative log-likelihood are p * (1 - p).
          # Scaling the rows of X avoids materialising an N x N diagonal matrix.
          s = p * (1.0 - p)
          hessian = X.T @ (X * s[:, None]) + regul
          descent = np.linalg.solve(hessian, X.T @ (y - p))

          step = 1.0
          cur_w = w + step * descent
          if pos_contraint:
              cur_w = proj_on_pos(cur_w)
          cur_J = loglik(X, y, cur_w)
          # Backtrack until the objective does not increase.
          while cur_J > J[-1]:
              step *= shrink
              cur_w = w + step * descent
              if pos_contraint:
                  cur_w = proj_on_pos(cur_w)
              cur_J = loglik(X, y, cur_w)

          w = cur_w
          J.append(cur_J)  # reuse the value computed by the line search
          weights.append(w)
          if i >= ite_max:
              break
          if old_J - J[-1] < tol:
              break
          old_J = J[-1]
      return w, J, weights

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.