Thread local と 実効ユーザIDの変更クラスのつづき

先日のエントリで書いた実効ユーザID変更クラスはマルチスレッドではうまく動かない.それだとちょっと困るので書き直してみた.そのために,threading の localを使った.

threading local

threading.localは各スレッドに固有の値を持たせためのしかけ.Java でいうところのjava.lang.ThreadLocalに相当する.一見ただの変数にみえるのだけど,アクセスするスレッドによって読みだされる値が違う.threading.local はPythonのオブジェクトなので任意のkey-valueペアを属性として登録することができる.

localData = threading.local()
localData.key = 1

簡単な例としてスレッドごとに作ったインスタンスの数を覚えているクラスを書いてみよう.こんな感じ.

import threading

class testLocal():
    localData = threading.local()
    @classmethod
    def _makeSureCounter(cls):
        if not cls.localData.__dict__.has_key('counter'):
            cls.localData.counter = 0
    def __init__(self):
        self._makeSureCounter()
        self.localData.counter += 1
        print self.localData.counter, ' th instance'

使ってみて分かるのは,初期化が面倒なこと.初めて使うスレッドでは初期値を設定しないといけないのだけど,初めてなのかどうかをいちいち判断しないといけない.この辺,javaのThreadLocalのほうが遥かに洗練されている.

Thread固有の値を持たせる方法としてはThreadそのものを拡張してしまうという方法がある.Javaだとこれは結構たいへんだけど,Pythonだと簡単(というか属性をつけるだけ)なのでどっちが得なのか考えてしまう.Threadに付ける方針で書き直すとこんな感じ.ほとんど同じだ.勿論,別のクラスが_counterという属性を使うとおかしくなってしまうので,threading.localを使った方がいいのだけど.

class testLocal():
    def _makeSureCounter(cls):
        if not threading.currentThread().__dict__.has_key('_counter'):
            threading.currentThread()._counter = 0

    def __init__(self):
        self._makeSureCounter()
        threading.currentThread()._counter += 1
        print threading.currentThread()._counter, ' th instance'

実効ユーザID変更クラス.

複数のスレッドが,個別のIDでなにかしたい場合にseteuidできるようにしてみた.前提として,

  • 各スレッドはデフォルトではどの実効IDで動いていても気にしない
  • 特定の操作(ファイルを作るとか,ジョブを起動するとか)をする際だけsetuidを行う
  • 複数のスレッドが同時に特定の異なるIDで動くことはできない
  • 複数のスレッドが同時に特定の同じIDで動くことはできる

,とする.

  • デフォルトでは,nobodyで動作する

使い方は先日のエントリのものと同じでwithを使う.ユーザIDでなくユーザ名でも動く.

with Euid('root'):
  ...
  with Euid('hidemon'):
    ...

ソースはこんな感じ.

from __future__ import with_statement
import os
import threading
import pwd
import time
import random
import sys

if os.getuid() != 0:
    sys.stderr.write('module "' + __name__ + '" should be used by root only.\n')
    sys.stderr.write('exitting...\n')
    sys.exit(3)

_default_user_id = -1
def setDefaultUid(user):
    global _default_user_id
    if type(user) == str:
        _default_user_id = pwd.getpwnam(user).pw_uid
    else:
        _default_user_id = user
    os.seteuid(_default_user_id)

setDefaultUid('nobody')

class _shared():
    def __init__(self, default_user_id):
        self.cv    = threading.Condition()
        self.count = {}
        self.allCount = 0
        self.default_user_id = default_user_id

    def _is_other_user(self, uid):
        if self.allCount == 0:
            return False
        for k in self.count:
            if uid != k:
                return True
        return False

    def _incCount(self, uid):
        if not self.count.has_key(uid):
            self.count[uid] = 0
        self.count[uid] += 1
        self.allCount += 1
        
    def _decCount(self, uid):
        if not self.count.has_key(uid):
            self.count[uid] = 0
        self.count[uid] -= 1
        self.allCount -= 1

    def enterUser(self, uid):
        with self.cv:
            while self._is_other_user(uid):
                self.cv.wait()
            self._incCount(uid)
            if os.geteuid() != uid:
                os.seteuid(0)
                os.seteuid(uid)
    def exitUser(self, uid):
        with self.cv:
            self._decCount(uid)
            if self.allCount == 0:
                os.seteuid(0)
                os.seteuid(self.default_user_id)
            self.cv.notifyAll()

_shared_obj = _shared(_default_user_id)

class Euid():
    _mydata = threading.local()

    def __init__(self, id):
        if type(id) == str:
            self.id = pwd.getpwnam(id).pw_uid
        else:
            self.id = id
        if not self._mydata.__dict__.has_key('uid_stack'):
            self._mydata.uid_stack = []

    def __enter__(self):
        if len(self._mydata.uid_stack) != 0:
            _shared_obj.exitUser(self._mydata.uid_stack[-1])
        self._mydata.uid_stack.append(self.id)
        _shared_obj.enterUser(self.id)

    def __exit__(self, exc_type, exc_value, traceback):
        _shared_obj.exitUser(self.id)
        self._mydata.uid_stack.pop()
        if len(self._mydata.uid_stack) != 0:
            _shared_obj.enterUser(self._mydata.uid_stack[-1])
        if exc_type:
            return False
        return True

注意: 当然ですがrootで動かさないと動きません.