You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ex1.m 3.5 kB

8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  %% Machine Learning Online Class - Exercise 1: Linear Regression
  %
  % Instructions
  % ------------
  % This driver script walks through the linear regression exercise.
  % The functions you need to complete for this exercise live in
  % their own files:
  %
  %   warmUpExercise.m        plotData.m
  %   gradientDescent.m       computeCost.m
  %   gradientDescentMulti.m  computeCostMulti.m
  %   featureNormalize.m      normalEqn.m
  %
  % Only those files need editing; leave this script and every other file
  % unchanged.
  %
  % Data conventions:
  %   x - population size, in units of 10,000 people
  %   y - profit, in units of $10,000
  %
  %% Initialization
  % Start from a clean state: empty workspace, no open figures, cleared console.
  clear;
  close all;
  clc;
  %% ==================== Part 1: Basic Function ====================
  % Exercise warmUpExercise (implemented in warmUpExercise.m). The call has no
  % trailing semicolon on purpose, so any value it returns is echoed to the
  % console beneath the heading printed here.
  fprintf('Running warmUpExercise ... \n');
  fprintf('5x5 Identity Matrix: \n');
  warmUpExercise
  % Wait for the user before moving on to the next part.
  fprintf('Program paused. Press enter to continue.\n');
  pause();
  %% ======================= Part 2: Plotting =======================
  fprintf('Plotting Data ...\n');
  % Load the training set. Column 1 is the input feature, column 2 the target.
  data = load('ex1data1.txt');
  X = data(:, 1);    % population, in 10,000s
  y = data(:, 2);    % profit, in $10,000s
  m = numel(y);      % number of training examples
  % Draw the raw data (plotData.m must be completed for this to work).
  plotData(X, y);
  % Wait for the user before moving on to the next part.
  fprintf('Program paused. Press enter to continue.\n');
  pause;
  %% =================== Part 3: Gradient descent ===================
  fprintf('Running Gradient Descent ...\n');
  % Design matrix: prepend an intercept column of ones to the raw feature.
  X = [ones(m, 1), data(:, 1)];
  theta = zeros(2, 1);   % parameters [theta_0; theta_1], initialised to zero
  % Gradient descent settings
  alpha = 0.01;          % learning rate
  iterations = 1500;     % number of update steps
  % Show the cost at the initial parameters (no semicolon, so it is echoed).
  computeCost(X, y, theta)
  % Fit the parameters and report them.
  theta = gradientDescent(X, y, theta, alpha, iterations);
  fprintf('Theta found by gradient descent: ');
  fprintf('%f %f \n', theta(1), theta(2));
  % Overlay the fitted line on the current figure (the data plot from Part 2),
  % then release the hold so later plots do not stack onto it.
  hold on;
  plot(X(:, 2), X * theta, '-')
  legend('Training data', 'Linear regression')
  hold off
  % Predict for populations of 35,000 and 70,000. Inputs are in units of
  % 10,000 (hence 3.5 and 7); predictions are in $10,000s, so scale by 10000.
  predict1 = [1, 3.5] * theta;
  fprintf('For population = 35,000, we predict a profit of %f\n', predict1 * 10000);
  predict2 = [1, 7] * theta;
  fprintf('For population = 70,000, we predict a profit of %f\n', predict2 * 10000);
  % Wait for the user before moving on to the next part.
  fprintf('Program paused. Press enter to continue.\n');
  pause;
  %% ============= Part 4: Visualizing J(theta_0, theta_1) =============
  fprintf('Visualizing J(theta_0, theta_1) ...\n');
  % Grid of parameter values over which the cost J is evaluated.
  theta0_vals = linspace(-10, 10, 100);  % 100 evenly spaced values in [-10, 10]
  theta1_vals = linspace(-1, 4, 100);    % 100 evenly spaced values in [-1, 4]
  % J_vals(i, j) holds the cost at theta = [theta0_vals(i); theta1_vals(j)],
  % i.e. rows are indexed by theta_0 and columns by theta_1.
  J_vals = zeros(length(theta0_vals), length(theta1_vals));
  for i = 1:length(theta0_vals)
      for j = 1:length(theta1_vals)
          t = [theta0_vals(i); theta1_vals(j)];
          J_vals(i, j) = computeCost(X, y, t);
      end
  end
  % surf/contour with vector x and y expect Z(row, col) to correspond to
  % (y(row), x(col)): rows follow theta_1 and columns follow theta_0. J_vals
  % was filled the other way round, so transpose it; otherwise the axes would
  % be flipped.
  J_vals = J_vals';
  % Surface plot of the cost, in a new figure.
  figure;
  surf(theta0_vals, theta1_vals, J_vals)
  xlabel('\theta_0'); ylabel('\theta_1');
  % Contour plot of the cost, in another new figure.
  figure;
  % Plot J_vals at 20 contour levels spaced logarithmically between 0.01 and
  % 1000 (logspace(-2, 3, 20) returns 20 points from 10^-2 to 10^3).
  contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20))
  xlabel('\theta_0'); ylabel('\theta_1');
  % Mark the parameters found by gradient descent with a red cross.
  hold on;
  plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);

机器学习

Contributors (1)