# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint

# ===========================================================
# 常微分方程式を解くクラス
# ===========================================================
class ODE(object):

    # -------------------------------------------------------
    # コンストラクター
    # -------------------------------------------------------
    def __init__(self, diff_eq, init_con):

        self.diff_eq  = diff_eq
        self.init_con = init_con

        
    # -------------------------------------------------------
    # 常微分方程式の計算
    # -------------------------------------------------------
    def cal_euation(self, x_min, x_max, N):

        x = np.linspace(x_min, x_max, N)             # x の配列の生成
        y = odeint(self.diff_eq, self.init_con, x)   # 方程式の計算

        return x, y



# -------------------------------------------------------
# 解くべき常微分方程式
# -------------------------------------------------------
def diff_eq(y, x):
    dydx = np.cos(x)
    return dydx


# -------------------------------------------------------
# プロット
# ------------------------------------------------------- 
def plot(x, y, x_range, y_range):
    fig = plt.figure()

    # ----- プロットの準備 -----
    sol = fig.add_subplot(1,1,1)
    sol.set_xlabel("x", fontsize=20, fontname='serif')
    sol.set_ylabel("y", fontsize=20, fontname='serif')
    sol.tick_params(axis='both', length=10, which='major')
    sol.tick_params(axis='both', length=5,  which='minor')
    sol.set_xlim(x_range)
    sol.set_ylim(y_range)
    sol.minorticks_on()
    sol.plot(x, y, 'b-', markersize=5)

    # ----- スクリーン表示 -----
    fig.tight_layout()
    plt.show()
        
    # ----- pdf 作成 -----
    fig.savefig('ode_solve.pdf', orientation='portrait', \
                transparent=False, bbox_inches=None, frameon=None)
    fig.clf()



# -------------------------------------------------------
# メイン関数
# -------------------------------------------------------
if __name__ == "__main__":

    N = 1000                              # 分割数
    min_x = 0                             # x の最小
    max_x = 4*np.pi                       # x の最大
    initial_condition = np.array([0])     # 初期条件

    ode = ODE(diff_eq, initial_condition)
    x, y = ode.cal_euation(min_x, max_x, N)

    plot(x, y, (min_x, max_x), (-1.2, 1.2))
    
